# AVeri/src/features_statistical.py
# salirafi's picture
# Upload 14 files
# 66242b8 verified
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
import pandas as pd
from tqdm.auto import tqdm
from function_words import FUNCTION_WORDS
# Lowercased copy of FUNCTION_WORDS as a set; the functions below test
# record["token_lower"] entries for membership in it.
FUNCTION_WORD_SET = {word.lower() for word in FUNCTION_WORDS} # lowercase function words
# Regex source (plain string, not compiled) matching a whole "<...>" token with
# no nested angle brackets.
# NOTE(review): not referenced in the visible part of this file; the placeholder
# check in _avg_word_length uses startswith/endswith instead -- confirm whether
# this constant is used elsewhere.
PLACEHOLDER_RE = r"^<[^<>]+>$" # for masking placeholder
@dataclass(slots=True)
class Config:
verbose: bool = True
include_function_word_rate: bool = True
exclude_placeholders_from_avg_word_length: bool = True
phrase_role_dependency_labels: tuple[str, ...] = ("acl", "advcl", "ccomp", "pcomp", "relcl", "xcomp") # for clausal phrase signals
pos_roles: dict[str, tuple[str, ...]] = None
dep_roles: dict[str, tuple[str, ...]] = None
# covering tags for en_core_web_lg spacy model
def __post_init__(self) -> None:
if self.pos_roles is None:
self.pos_roles = {
"adjective": ("ADJ",),
"adposition": ("ADP",),
"adverb": ("ADV",),
"auxiliary": ("AUX",),
"conjunction": ("CONJ",),
"coordinating_conjunction": ("CCONJ",),
"determiner": ("DET",),
"interjection": ("INTJ",),
"noun": ("NOUN",),
"numeral": ("NUM",),
"particle": ("PART",),
"pronoun": ("PRON",),
"proper_noun": ("PROPN",),
"punctuation": ("PUNCT",),
"subordinating_conjunction": ("SCONJ",),
"symbol": ("SYM",),
"verb": ("VERB",),
"other": ("X",),
"space": ("SPACE",),
}
if self.dep_roles is None:
self.dep_roles = {
"root": ("ROOT",),
"adjectival_clause": ("acl",),
"adjectival_complement": ("acomp",),
"adverbial_clause": ("advcl",),
"adverbial_modifier": ("advmod",),
"agent": ("agent",),
"adjectival_modifier": ("amod",),
"apposition": ("appos",),
"attribute": ("attr",),
"auxiliary": ("aux",),
"passive_auxiliary": ("auxpass",),
"case_marker": ("case",),
"coordinating_conjunction": ("cc",),
"clausal_complement": ("ccomp",),
"compound": ("compound",),
"conjunct": ("conj",),
"clausal_subject": ("csubj",),
"passive_clausal_subject": ("csubjpass",),
"dative": ("dative",),
"dependency_unspecified": ("dep",),
"determiner": ("det",),
"direct_object": ("dobj",),
"expletive": ("expl",),
"indirect_object": ("iobj",),
"interjection": ("intj",),
"marker": ("mark",),
"meta": ("meta",),
"negation": ("neg",),
"nominal_modifier": ("nmod",),
"noun_phrase_adverbial_modifier": ("npadvmod",),
"nominal_subject": ("nsubj",),
"passive_nominal_subject": ("nsubjpass",),
"numeric_modifier": ("nummod",),
"object": ("obj",),
"object_predicate": ("oprd",),
"parataxis": ("parataxis",),
"prepositional_complement": ("pcomp",),
"object_of_preposition": ("pobj",),
"possessive_modifier": ("poss",),
"preconjunct": ("preconj",),
"predeterminer": ("predet",),
"prepositional_modifier": ("prep",),
"particle": ("prt",),
"punctuation": ("punct",),
"quantifier_modifier": ("quantmod",),
"relative_clause_modifier": ("relcl",),
"open_clausal_complement": ("xcomp",),
}
# Module-wide default settings. The functions below bind this instance as the
# default value of their ``config`` parameter (at function definition time).
config = Config()
def _safe_mean(values: list[int]) -> float:
if not values:
return 0.0
return round(sum(values) / len(values), 3)
def _safe_rate(count: int, total: int) -> float:
if total == 0:
return 0.0
return round(count / total, 3)
# flagging non-word tokens including both punctuations and white-space
def _word_token_indices(record: dict[str, Any]) -> list[int]:
return [
index for index, (is_punct, is_space) in enumerate(zip(record["token_is_punct"], record["token_is_space"], strict=False))
if not is_punct and not is_space
]
# average number of characters per word token
def _avg_word_length(record: dict[str, Any], config: Config = config) -> float:
    """Return the mean character length of the record's word tokens.

    Punctuation and white-space tokens are always skipped. When
    ``config.exclude_placeholders_from_avg_word_length`` is set, tokens that
    start with "<" and end with ">" (masking placeholders) are skipped too.
    The result is rounded to 3 decimals and is 0.0 when nothing is left.
    """
    skip_placeholders = config.exclude_placeholders_from_avg_word_length
    word_lengths = [
        len(text)
        for position, text in enumerate(record["tokens"])
        if not (record["token_is_punct"][position] or record["token_is_space"][position])
        and not (skip_placeholders and text.startswith("<") and text.endswith(">"))
    ]
    return _safe_mean(word_lengths)
# ========== SENTENCE STATISTICS ============
# defining a sentence
def _sentence_spans(record: dict[str, Any]) -> list[tuple[int, int]]:
spans = record["sentence_token_spans"]
# either compute from sentence_token_spans or number of tokens
if spans: return spans
if record["tokens"]: return [(0, len(record["tokens"]))]
return [] # else no sentence
def _sentence_word_lengths(record: dict[str, Any]) -> list[int]:
    """Return, for each sentence span, how many of its tokens are word tokens."""
    word_positions = set(_word_token_indices(record))
    # A range holds each index once, so the intersection size equals the
    # number of word tokens that fall inside [start, end).
    return [
        len(word_positions.intersection(range(start, end)))
        for start, end in _sentence_spans(record)
    ]
def _sentence_function_word_counts(record: dict[str, Any]) -> list[int]:
    """Return, for each sentence span, the number of function-word tokens.

    A token counts when it is neither punctuation nor white-space and its
    lowercased text is in ``FUNCTION_WORD_SET``.
    """

    def _is_function_word(position: int) -> bool:
        # Punctuation and white-space tokens are never function words.
        if record["token_is_punct"][position] or record["token_is_space"][position]:
            return False
        return record["token_lower"][position] in FUNCTION_WORD_SET

    return [
        sum(1 for position in range(start, end) if _is_function_word(position))
        for start, end in _sentence_spans(record)
    ]
# ============ PHRASE STATISTICS ==============
def _phrase_role_features(record: dict[str, Any], config: Config = config) -> dict[str, float]:
    """Return the share of noun, prepositional and clausal phrase units.

    Noun phrases are the entries of ``record["noun_chunk_spans"]``,
    prepositional phrases are approximated by tokens labelled "prep", and
    clausal phrases are tokens whose dependency label is listed in
    ``config.phrase_role_dependency_labels``. Each count is divided by the sum
    of the three counts; all rates are 0.0 when that sum is zero.
    """
    noun_total = len(record["noun_chunk_spans"])
    clausal_labels = config.phrase_role_dependency_labels
    prep_total = 0
    clausal_total = 0
    for label in record["token_dep"]:
        # Two independent checks: a label is tallied under both groups if it
        # qualifies for both.
        if label == "prep":
            prep_total += 1
        if label in clausal_labels:
            clausal_total += 1

    unit_total = noun_total + prep_total + clausal_total
    return {
        "phrase_noun_phrase_rate": _safe_rate(noun_total, unit_total),
        "phrase_prepositional_phrase_rate": _safe_rate(prep_total, unit_total),
        "phrase_clausal_phrase_rate": _safe_rate(clausal_total, unit_total),
    }
# ============ POS TAGS STATISTICS ==============
def _pos_role_features(record: dict[str, Any], config: Config = config) -> dict[str, float]:
    """Return one ``pos_<role>_rate`` feature per role in ``config.pos_roles``.

    Only word tokens (neither punctuation nor white-space) are tallied. Each
    rate is the role's tally divided by the sum of all role tallies, so tokens
    whose tag matches no role do not contribute to the denominator.
    """
    tallies = dict.fromkeys(config.pos_roles, 0)
    for position in _word_token_indices(record):
        tag = record["token_pos"][position]
        matching_roles = [role for role, labels in config.pos_roles.items() if tag in labels]
        for role in matching_roles:
            tallies[role] += 1

    denominator = sum(tallies.values())
    rates: dict[str, float] = {}
    for role, tally in tallies.items():
        rates[f"pos_{role}_rate"] = _safe_rate(tally, denominator)
    return rates
# ============ DEPENDENCY LABEL STATISTICS ==============
def _dep_role_features(record: dict[str, Any], config: Config = config) -> dict[str, float]:
    """Return one ``dep_<role>_rate`` feature per role in ``config.dep_roles``.

    Only word tokens (neither punctuation nor white-space) are tallied. Each
    rate is the role's tally divided by the sum of all role tallies, so tokens
    whose label matches no role do not contribute to the denominator.
    """
    role_table = config.dep_roles
    hits = {role: 0 for role in role_table}
    for position in _word_token_indices(record):
        label = record["token_dep"][position]
        for role, accepted in role_table.items():
            # A membership test yields a bool, which adds as 0 or 1.
            hits[role] += label in accepted

    hit_total = sum(hits.values())
    rates: dict[str, float] = {}
    for role, hit_count in hits.items():
        rates[f"dep_{role}_rate"] = _safe_rate(hit_count, hit_total)
    return rates
# ======================================================
def extract_document_statistics(record: dict[str, Any], config: Config = config) -> dict[str, float]:
    """Compute every statistical feature for one document record.

    The result holds, in this order: the sentence-level averages, the
    punctuation rate, the average word length, optionally the function word
    rate (when ``config.include_function_word_rate`` is set), then the phrase,
    POS and dependency rate features.
    """
    word_positions = _word_token_indices(record)  # tokens that are neither punctuation nor white-space
    word_total = len(word_positions)
    # Punctuation rate is taken over all non-space tokens, punctuation included.
    non_space_total = sum(not is_space for is_space in record["token_is_space"])
    punct_total = sum(bool(is_punct) for is_punct in record["token_is_punct"])
    function_word_total = len(
        [position for position in word_positions if record["token_lower"][position] in FUNCTION_WORD_SET]
    )

    words_per_sentence = _sentence_word_lengths(record)
    function_words_per_sentence = _sentence_function_word_counts(record)

    features: dict[str, float] = {}
    features["avg_sentence_length_words"] = _safe_mean(words_per_sentence)
    features["avg_function_words_per_sentence"] = _safe_mean(function_words_per_sentence)
    features["punctuation_rate"] = _safe_rate(punct_total, non_space_total)
    features["avg_word_length"] = _avg_word_length(record, config=config)
    if config.include_function_word_rate:
        features["function_word_rate"] = _safe_rate(function_word_total, word_total)

    # Role-based rate groups, merged in a fixed order: phrases, POS tags, dependency labels.
    for extractor in (_phrase_role_features, _pos_role_features, _dep_role_features):
        features.update(extractor(record, config=config))
    return features
def extract_split_statistics(
    df: pd.DataFrame,
    split_cache: dict[str, list[dict[str, Any]]],
    split_name: str = "",
    config: Config = config,
) -> pd.DataFrame:
    """
    Append statistical features for each configured text column in one split.

    For each of the columns "text1" and "text2", the cached records in
    ``split_cache[column]`` are turned into feature rows by
    ``extract_document_statistics`` and appended to a copy of ``df`` as new
    columns prefixed with ``"<column>_"``. The input frame is not modified and
    the returned frame has a fresh 0..n-1 index.

    Args:
        df: Rows of the split; copied before any column is added.
        split_cache: Must contain the keys "text1" and "text2", each mapping
            to the list of records for that column.
        split_name: Label used only in the progress-bar description.
        config: Extraction settings; ``config.verbose`` also controls whether
            the progress bar is displayed.

    Raises:
        KeyError: If ``split_cache`` lacks "text1" or "text2".
    """
    result = df.copy()
    for column in ["text1", "text2"]:
        records = split_cache[column]  # linguistic_cache must contain "text1" and "text2"
        # The bar is always constructed but hidden when verbose is off, so a
        # quiet configuration produces no console output from this function.
        iterator = tqdm(
            records,
            total=len(records),
            desc=f"Stat features [{split_name}:{column}]",
            disable=not config.verbose,
        )
        feature_rows = [extract_document_statistics(record, config=config) for record in iterator]  # loop over rows
        feature_df = pd.DataFrame(feature_rows).add_prefix(f"{column}_")  # make each feature as one column
        # Both sides are re-indexed so rows are joined by position.
        # NOTE(review): assumes len(records) == len(df); with unequal lengths
        # the column-wise concat pads the shorter side with NaN -- confirm
        # callers guarantee matching lengths.
        result = pd.concat([result.reset_index(drop=True), feature_df.reset_index(drop=True)], axis=1)
    return result
def build_feature_summary(
    dict_df: dict[str, pd.DataFrame],
    config: Config = config,
) -> pd.DataFrame:
    """Build a one-row-per-split table of mean values for the headline features.

    Each row carries the split name, its row count, and a
    ``"<column>_<feature>_mean"`` entry (rounded to 6 decimals) for every
    tracked feature column present in that split's frame. Missing columns are
    skipped. ``config`` is accepted but not read here.
    """
    tracked_features = (
        "avg_sentence_length_words",
        "avg_function_words_per_sentence",
        "function_word_rate",
        "punctuation_rate",
        "avg_word_length",
    )
    # Candidate column names, text1 features first, then text2.
    candidate_names = [
        f"{column}_{feature_name}"
        for column in ("text1", "text2")
        for feature_name in tracked_features
    ]

    summary_rows: list[dict[str, Any]] = []
    for split, df in dict_df.items():
        summary: dict[str, Any] = {"split": split, "num_rows": len(df)}
        for name in candidate_names:
            if name not in df.columns:
                continue
            summary[f"{name}_mean"] = round(df[name].mean(), 6)
        summary_rows.append(summary)
    return pd.DataFrame(summary_rows)
def statistical_features_wrapper(
    dict_df: dict[str, pd.DataFrame],
    linguistic_cache: dict[str, dict[str, list[dict[str, Any]]]],
    config: Config = config,
) -> tuple[dict[str, pd.DataFrame], pd.DataFrame]:
    """Run statistical feature extraction on every split and summarise it.

    Args:
        dict_df: Split name -> frame of rows for that split.
        linguistic_cache: Split name -> per-column record lists; indexed with
            the same split names as ``dict_df``.
        config: Extraction settings; ``config.verbose`` gates the messages
            printed here.

    Returns:
        A pair of (split name -> frame with feature columns appended,
        summary frame built by ``build_feature_summary``).
    """
    verbose = config.verbose
    if verbose:
        print("======= STATISTICAL FEATURES START =======")

    def _run_split(split: str, df: pd.DataFrame) -> pd.DataFrame:
        # Announce the split (when verbose), then extract its features.
        if verbose:
            print(f"\nProcessing statistical features for split='{split}' ({len(df):,} rows)")
        return extract_split_statistics(
            df,
            split_cache=linguistic_cache[split],
            split_name=split,
            config=config,
        )

    statistical_dict_df: dict[str, pd.DataFrame] = {
        split: _run_split(split, df) for split, df in dict_df.items()
    }
    statistical_summary_df = build_feature_summary(statistical_dict_df, config=config)

    if verbose:
        closing_output = (
            "\nStatistical feature summary:",
            statistical_summary_df,
            "",
            "======= STATISTICAL FEATURES END =======",
            "",
        )
        for item in closing_output:
            print(item)
    return statistical_dict_df, statistical_summary_df