| from __future__ import annotations

import re
from dataclasses import dataclass
from typing import Any

import pandas as pd
from tqdm.auto import tqdm

from function_words import FUNCTION_WORDS
|
|
| # Lower-cased lookup set so function-word membership tests are O(1) and case-insensitive.
| FUNCTION_WORD_SET = {word.lower() for word in FUNCTION_WORDS}
| # Regex for a whole placeholder token such as "<NAME>" (no nested angle brackets allowed).
| PLACEHOLDER_RE = r"^<[^<>]+>$"
|
|
|
|
|
|
| @dataclass(slots=True) |
| class Config: |
| verbose: bool = True |
| include_function_word_rate: bool = True |
| exclude_placeholders_from_avg_word_length: bool = True |
| phrase_role_dependency_labels: tuple[str, ...] = ("acl", "advcl", "ccomp", "pcomp", "relcl", "xcomp") |
| pos_roles: dict[str, tuple[str, ...]] = None |
| dep_roles: dict[str, tuple[str, ...]] = None |
|
|
| |
| def __post_init__(self) -> None: |
| if self.pos_roles is None: |
| self.pos_roles = { |
| "adjective": ("ADJ",), |
| "adposition": ("ADP",), |
| "adverb": ("ADV",), |
| "auxiliary": ("AUX",), |
| "conjunction": ("CONJ",), |
| "coordinating_conjunction": ("CCONJ",), |
| "determiner": ("DET",), |
| "interjection": ("INTJ",), |
| "noun": ("NOUN",), |
| "numeral": ("NUM",), |
| "particle": ("PART",), |
| "pronoun": ("PRON",), |
| "proper_noun": ("PROPN",), |
| "punctuation": ("PUNCT",), |
| "subordinating_conjunction": ("SCONJ",), |
| "symbol": ("SYM",), |
| "verb": ("VERB",), |
| "other": ("X",), |
| "space": ("SPACE",), |
| } |
| if self.dep_roles is None: |
| self.dep_roles = { |
| "root": ("ROOT",), |
| "adjectival_clause": ("acl",), |
| "adjectival_complement": ("acomp",), |
| "adverbial_clause": ("advcl",), |
| "adverbial_modifier": ("advmod",), |
| "agent": ("agent",), |
| "adjectival_modifier": ("amod",), |
| "apposition": ("appos",), |
| "attribute": ("attr",), |
| "auxiliary": ("aux",), |
| "passive_auxiliary": ("auxpass",), |
| "case_marker": ("case",), |
| "coordinating_conjunction": ("cc",), |
| "clausal_complement": ("ccomp",), |
| "compound": ("compound",), |
| "conjunct": ("conj",), |
| "clausal_subject": ("csubj",), |
| "passive_clausal_subject": ("csubjpass",), |
| "dative": ("dative",), |
| "dependency_unspecified": ("dep",), |
| "determiner": ("det",), |
| "direct_object": ("dobj",), |
| "expletive": ("expl",), |
| "indirect_object": ("iobj",), |
| "interjection": ("intj",), |
| "marker": ("mark",), |
| "meta": ("meta",), |
| "negation": ("neg",), |
| "nominal_modifier": ("nmod",), |
| "noun_phrase_adverbial_modifier": ("npadvmod",), |
| "nominal_subject": ("nsubj",), |
| "passive_nominal_subject": ("nsubjpass",), |
| "numeric_modifier": ("nummod",), |
| "object": ("obj",), |
| "object_predicate": ("oprd",), |
| "parataxis": ("parataxis",), |
| "prepositional_complement": ("pcomp",), |
| "object_of_preposition": ("pobj",), |
| "possessive_modifier": ("poss",), |
| "preconjunct": ("preconj",), |
| "predeterminer": ("predet",), |
| "prepositional_modifier": ("prep",), |
| "particle": ("prt",), |
| "punctuation": ("punct",), |
| "quantifier_modifier": ("quantmod",), |
| "relative_clause_modifier": ("relcl",), |
| "open_clausal_complement": ("xcomp",), |
| } |
| config = Config() |
|
|
|
|
|
|
| def _safe_mean(values: list[int]) -> float: |
| if not values: |
| return 0.0 |
| return round(sum(values) / len(values), 3) |
|
|
| def _safe_rate(count: int, total: int) -> float: |
| if total == 0: |
| return 0.0 |
| return round(count / total, 3) |
|
|
| |
| def _word_token_indices(record: dict[str, Any]) -> list[int]: |
| return [ |
| index for index, (is_punct, is_space) in enumerate(zip(record["token_is_punct"], record["token_is_space"], strict=False)) |
| if not is_punct and not is_space |
| ] |
|
|
| |
def _avg_word_length(record: dict[str, Any], config: Config = config) -> float:
    """Average character length of word tokens, rounded to 3 decimals.

    Punctuation and whitespace tokens are always skipped.  When
    ``config.exclude_placeholders_from_avg_word_length`` is set, tokens
    matching the module-level ``PLACEHOLDER_RE`` (e.g. "<NAME>") are skipped
    as well.  Using the declared regex keeps the placeholder definition in
    one place; the previous startswith("<")/endswith(">") check also excluded
    malformed tokens such as "<>" or "<<x>>", which now count as words.

    Returns 0.0 when no qualifying tokens exist.
    """
    lengths: list[int] = []
    for index, token_text in enumerate(record["tokens"]):
        if record["token_is_punct"][index] or record["token_is_space"][index]:
            continue
        # PLACEHOLDER_RE is anchored, so fullmatch == match here.
        if config.exclude_placeholders_from_avg_word_length and re.fullmatch(PLACEHOLDER_RE, token_text):
            continue
        lengths.append(len(token_text))
    return _safe_mean(lengths)
|
|
| |
|
|
| |
| def _sentence_spans(record: dict[str, Any]) -> list[tuple[int, int]]: |
| spans = record["sentence_token_spans"] |
| |
| if spans: return spans |
| if record["tokens"]: return [(0, len(record["tokens"]))] |
| return [] |
|
|
def _sentence_word_lengths(record: dict[str, Any]) -> list[int]:
    """Number of word tokens (non-punct, non-space) within each sentence span."""
    word_positions = set(_word_token_indices(record))
    return [
        sum(1 for position in range(start, end) if position in word_positions)
        for start, end in _sentence_spans(record)
    ]
|
|
def _sentence_function_word_counts(record: dict[str, Any]) -> list[int]:
    """Number of function-word tokens within each sentence span."""
    is_punct = record["token_is_punct"]
    is_space = record["token_is_space"]
    lowered = record["token_lower"]
    counts: list[int] = []
    for start, end in _sentence_spans(record):
        matched = sum(
            1
            for position in range(start, end)
            if not is_punct[position]
            and not is_space[position]
            and lowered[position] in FUNCTION_WORD_SET
        )
        counts.append(matched)
    return counts
|
|
| |
|
|
def _phrase_role_features(record: dict[str, Any], config: Config = config) -> dict[str, float]:
    """Relative frequency of noun, prepositional, and clausal phrase units.

    Each count is normalized by the combined number of phrase units, so the
    three returned rates sum to ~1.0 (all 0.0 when no phrases are present).
    """
    dependency_labels = record["token_dep"]
    clausal_labels = config.phrase_role_dependency_labels

    phrase_counts = {
        "phrase_noun_phrase_rate": len(record["noun_chunk_spans"]),
        "phrase_prepositional_phrase_rate": sum(1 for label in dependency_labels if label == "prep"),
        "phrase_clausal_phrase_rate": sum(1 for label in dependency_labels if label in clausal_labels),
    }
    denominator = sum(phrase_counts.values())
    return {name: _safe_rate(count, denominator) for name, count in phrase_counts.items()}
|
|
| |
|
|
def _pos_role_features(record: dict[str, Any], config: Config = config) -> dict[str, float]:
    """Rate of each configured POS role among the document's word tokens.

    Counts are normalized by the total number of role-matched tokens, so the
    returned rates sum to ~1.0 (all 0.0 when nothing matches).
    """
    # Hoist the label -> roles lookup out of the token loop; the original
    # scanned every role for every token (O(tokens * roles)).  A label that
    # appears under several roles credits each of them, exactly as before.
    roles_by_label: dict[str, list[str]] = {}
    for role_name, labels in config.pos_roles.items():
        for label in labels:
            roles_by_label.setdefault(label, []).append(role_name)

    pos_counts = {name: 0 for name in config.pos_roles}
    for index in _word_token_indices(record):
        for role_name in roles_by_label.get(record["token_pos"][index], ()):
            pos_counts[role_name] += 1

    total_pos_units = sum(pos_counts.values())
    return {
        f"pos_{role_name}_rate": _safe_rate(count, total_pos_units)
        for role_name, count in pos_counts.items()
    }
|
|
| |
|
|
def _dep_role_features(record: dict[str, Any], config: Config = config) -> dict[str, float]:
    """Rate of each configured dependency role among the document's word tokens.

    Counts are normalized by the total number of role-matched tokens, so the
    returned rates sum to ~1.0 (all 0.0 when nothing matches).
    """
    # Same hoisting as _pos_role_features: build the label -> roles reverse
    # map once instead of scanning every role per token.  Labels that appear
    # under several roles credit each of them, exactly as before.
    roles_by_label: dict[str, list[str]] = {}
    for role_name, labels in config.dep_roles.items():
        for label in labels:
            roles_by_label.setdefault(label, []).append(role_name)

    dep_counts = {name: 0 for name in config.dep_roles}
    for index in _word_token_indices(record):
        for role_name in roles_by_label.get(record["token_dep"][index], ()):
            dep_counts[role_name] += 1

    total_dep_units = sum(dep_counts.values())
    return {
        f"dep_{role_name}_rate": _safe_rate(count, total_dep_units)
        for role_name, count in dep_counts.items()
    }
|
|
|
|
| |
|
|
|
|
def extract_document_statistics(record: dict[str, Any], config: Config = config) -> dict[str, float]:
    """Compute all document-level statistical features for one cached record.

    Combines sentence-length, function-word, punctuation, and word-length
    statistics with the phrase / POS / dependency role rates.  The
    function_word_rate entry is only emitted when enabled in *config*.
    """
    word_indices = _word_token_indices(record)
    word_total = len(word_indices)
    non_space_total = sum(1 for flag in record["token_is_space"] if not flag)
    punct_total = sum(1 for flag in record["token_is_punct"] if flag)
    function_word_total = sum(
        1 for index in word_indices if record["token_lower"][index] in FUNCTION_WORD_SET
    )

    features: dict[str, float] = {
        "avg_sentence_length_words": _safe_mean(_sentence_word_lengths(record)),
        "avg_function_words_per_sentence": _safe_mean(_sentence_function_word_counts(record)),
        "punctuation_rate": _safe_rate(punct_total, non_space_total),
        "avg_word_length": _avg_word_length(record, config=config),
    }

    if config.include_function_word_rate:
        features["function_word_rate"] = _safe_rate(function_word_total, word_total)

    for role_features in (
        _phrase_role_features(record, config=config),
        _pos_role_features(record, config=config),
        _dep_role_features(record, config=config),
    ):
        features.update(role_features)

    return features
|
|
|
|
|
|
def extract_split_statistics(
    df: pd.DataFrame,
    split_cache: dict[str, list[dict[str, Any]]],
    split_name: str = "",
    config: Config = config,
) -> pd.DataFrame:
    """Append statistical features for each configured text column in one split.

    For each cached record list ("text1" / "text2") a feature frame is built
    via ``extract_document_statistics`` and concatenated column-wise onto a
    copy of *df*, with columns prefixed by the text-column name.  *df* itself
    is not mutated.

    Raises KeyError if *split_cache* lacks a "text1" or "text2" entry.
    """
    result = df.copy()

    for column in ["text1", "text2"]:
        records = split_cache[column]
        # The original assigned `iterator = records` and then unconditionally
        # overwrote it with tqdm — a dead assignment that was clearly meant to
        # be the quiet path.  The progress bar now honors config.verbose
        # (behavior is unchanged for the default verbose=True).
        iterator = records
        if config.verbose:
            iterator = tqdm(
                records,
                total=len(records),
                desc=f"Stat features [{split_name}:{column}]",
            )

        feature_rows = [extract_document_statistics(record, config=config) for record in iterator]
        feature_df = pd.DataFrame(feature_rows).add_prefix(f"{column}_")
        result = pd.concat([result.reset_index(drop=True), feature_df.reset_index(drop=True)], axis=1)

    return result
|
|
|
|
|
|
def build_feature_summary(
    dict_df: dict[str, pd.DataFrame],
    config: Config = config,
) -> pd.DataFrame:
    """One summary row per split with per-column means of the scalar features.

    Only features actually present in a split's frame are summarized; missing
    columns are silently skipped.  *config* is accepted for signature
    consistency with the other extraction helpers.
    """
    summary_features = (
        "avg_sentence_length_words",
        "avg_function_words_per_sentence",
        "function_word_rate",
        "punctuation_rate",
        "avg_word_length",
    )
    summary_rows: list[dict[str, Any]] = []
    for split_name, split_df in dict_df.items():
        summary_row: dict[str, Any] = {"split": split_name, "num_rows": len(split_df)}
        for text_column in ["text1", "text2"]:
            for feature_name in summary_features:
                full_name = f"{text_column}_{feature_name}"
                if full_name not in split_df.columns:
                    continue
                summary_row[f"{full_name}_mean"] = round(split_df[full_name].mean(), 6)
        summary_rows.append(summary_row)

    return pd.DataFrame(summary_rows)
|
|
|
|
|
|
def statistical_features_wrapper(
    dict_df: dict[str, pd.DataFrame],
    linguistic_cache: dict[str, dict[str, list[dict[str, Any]]]],
    config: Config = config,
) -> tuple[dict[str, pd.DataFrame], pd.DataFrame]:
    """Add statistical features to every split and build a summary frame.

    Runs ``extract_split_statistics`` over each split in *dict_df* (using the
    matching per-split linguistic cache) and returns the augmented frames
    together with the summary from ``build_feature_summary``.  Progress
    banners are printed only when config.verbose is set.
    """
    verbose = config.verbose
    if verbose:
        print("======= STATISTICAL FEATURES START =======")

    statistical_dict_df: dict[str, pd.DataFrame] = {}
    for split_name, split_df in dict_df.items():
        if verbose:
            print(f"\nProcessing statistical features for split='{split_name}' ({len(split_df):,} rows)")
        statistical_dict_df[split_name] = extract_split_statistics(
            split_df,
            split_cache=linguistic_cache[split_name],
            split_name=split_name,
            config=config,
        )

    statistical_summary_df = build_feature_summary(statistical_dict_df, config=config)

    if verbose:
        print("\nStatistical feature summary:")
        print(statistical_summary_df)
        print("")
        print("======= STATISTICAL FEATURES END =======")
        print("")

    return statistical_dict_df, statistical_summary_df
|
|