| import argparse |
| import json |
| import tempfile |
| from pathlib import Path |
| import fasttext |
| import numpy as np |
| from sklearn.metrics import f1_score, classification_report |
| from src.models.augment import augment, MAXLEN_TO_WINDOW |
| from src.models.dataset import deduplicate_positions, flatten_to_examples, split_data |
| from src.schemas.labels import SENTIMENT_LABELS |
|
|
# Value handed to ``flatten_to_examples`` as its ``mode`` argument.
MODE: str = "marker"
# Marker placed in front of the label name in training lines and stripped
# from the labels returned by ``model.predict``.
LABEL_PREFIX: str = "__label__"
|
|
|
|
def _to_fasttext_line(example: dict) -> str:
    """Render one example as a single fastText supervised-training line.

    The line is ``LABEL_PREFIX`` followed by the label name, one space, and
    the ``seg_a`` text with every newline replaced by a space so the example
    stays on one line.

    Args:
        example: Mapping with a ``"seg_a"`` text and a ``"label"`` id that
            ``SENTIMENT_LABELS.id2label`` can resolve.

    Returns:
        The formatted line, without a trailing newline.
    """
    flattened = " ".join(example["seg_a"].split("\n"))
    tag = f"{LABEL_PREFIX}{SENTIMENT_LABELS.id2label[example['label']]}"
    return f"{tag} {flattened}"
|
|
|
|
def _write_fasttext_file(examples: list[dict], path: Path) -> None:
    """Write ``examples`` to ``path`` in fastText format, one per line.

    The file is opened in write mode as UTF-8 text, so any existing content
    is replaced.

    Args:
        examples: Examples accepted by ``_to_fasttext_line``.
        path: Destination file.
    """
    with open(path, mode="w", encoding="utf-8") as out:
        out.writelines(f"{_to_fasttext_line(item)}\n" for item in examples)
|
|
|
|
def prepare_data(
    data_path: str = "data/data_augmented_256.jsonl",
    val_split: float = 0.1,
    test_split: float = 0.1,
    seed: int = 42,
) -> tuple[list[dict], list[dict], list[dict]]:
    """Load a JSONL dataset, flatten it to examples and split it three ways.

    Args:
        data_path: Path to a UTF-8 JSON Lines file with one JSON value per
            line.
        val_split: Forwarded to ``split_data`` as the validation split.
        test_split: Forwarded to ``split_data`` as the test split.
        seed: Forwarded to ``split_data`` as the seed.

    Returns:
        The ``(train, val, test)`` example lists produced by ``split_data``.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(data_path, "r", encoding="utf-8") as f:
        # Blank or whitespace-only lines (for example a stray empty line at
        # the end of the file) carry no sample; skipping them avoids a
        # JSONDecodeError on an otherwise valid file.
        samples = [json.loads(line) for line in f if line.strip()]

    examples = flatten_to_examples(samples, mode=MODE)
    train_ex, val_ex, test_ex = split_data(examples, val_split, test_split, seed)

    print(f"Train: {len(train_ex)}, Val: {len(val_ex)}, Test: {len(test_ex)}")
    return train_ex, val_ex, test_ex
|
|
|
|
def train(
    train_examples: list[dict],
    val_examples: list[dict],
    output_dir: str = "models/fasttext",
    lr: float = 0.5,
    epoch: int = 25,
    word_ngrams: int = 2,
    dim: int = 100,
    min_count: int = 1,
) -> fasttext.FastText._FastText:
    """Train a supervised fastText classifier and report validation metrics.

    The training examples are written to ``train.txt`` inside ``output_dir``
    (created if missing), a model is fitted with softmax loss, saved as
    ``model.bin`` in the same directory, and then scored on ``val_examples``
    through ``evaluate``.

    Args:
        train_examples: Examples used to build the training file.
        val_examples: Examples scored after training.
        output_dir: Directory receiving ``train.txt`` and ``model.bin``.
        lr: Passed to ``fasttext.train_supervised`` as ``lr``.
        epoch: Passed as ``epoch``.
        word_ngrams: Passed as ``wordNgrams``.
        dim: Passed as ``dim``.
        min_count: Passed as ``minCount``.

    Returns:
        The trained fastText model.
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    corpus_path = out_dir / "train.txt"
    model_path = out_dir / "model.bin"

    _write_fasttext_file(train_examples, corpus_path)

    hyperparams = {
        "lr": lr,
        "epoch": epoch,
        "wordNgrams": word_ngrams,
        "dim": dim,
        "minCount": min_count,
        "loss": "softmax",
    }
    classifier = fasttext.train_supervised(input=str(corpus_path), **hyperparams)

    classifier.save_model(str(model_path))
    print(f"Model saved to {model_path}")

    evaluate(classifier, val_examples, split_name="val")
    return classifier
|
|
|
|
def evaluate(
    model: fasttext.FastText._FastText,
    examples: list[dict],
    split_name: str = "test",
) -> float:
    """Score the model per example and print a classification report.

    Each example's ``seg_a`` text (newlines replaced by spaces) is classified
    and the first returned label is compared with the example's gold label.

    Args:
        model: Trained fastText model.
        examples: Examples with ``"seg_a"`` text and a ``"label"`` id.
        split_name: Name shown in the printed header.

    Returns:
        Macro F1 computed over ``SENTIMENT_LABELS.classes``.
    """
    class_names = list(SENTIMENT_LABELS.classes)
    gold: list = []
    predicted: list = []

    for item in examples:
        sentence = item["seg_a"].replace("\n", " ")
        top_labels = model.predict(sentence)[0]
        predicted.append(top_labels[0].replace(LABEL_PREFIX, ""))
        gold.append(SENTIMENT_LABELS.id2label[item["label"]])

    macro_f1 = f1_score(gold, predicted, average="macro", labels=class_names)
    print(f"\n{split_name} (per-position) macro F1: {macro_f1:.4f}")
    print(classification_report(gold, predicted, labels=class_names, digits=4))

    return macro_f1
|
|
|
|
def evaluate_entity_level(
    model: fasttext.FastText._FastText,
    examples: list[dict],
    split_name: str = "test",
) -> float:
    """Score the model per entity, keeping the most confident prediction.

    Examples are grouped by ``(sample_id, entity_id)``. For each group the
    prediction with the highest probability is kept (the first one wins on
    ties) and compared with the group's gold label.

    Args:
        model: Trained fastText model.
        examples: Examples with ``"sample_id"``, ``"entity_id"``, ``"seg_a"``
            and ``"label"`` entries.
        split_name: Name shown in the printed header.

    Returns:
        Macro F1 computed over ``SENTIMENT_LABELS.classes``.
    """
    class_names = list(SENTIMENT_LABELS.classes)

    best_by_entity: dict[tuple, tuple[str, float]] = {}
    gold_by_entity: dict[tuple, str] = {}

    for item in examples:
        entity_key = (item["sample_id"], item["entity_id"])
        sentence = item["seg_a"].replace("\n", " ")
        top_labels, top_probs = model.predict(sentence)
        candidate = (top_labels[0].replace(LABEL_PREFIX, ""), float(top_probs[0]))

        incumbent = best_by_entity.get(entity_key)
        if incumbent is None or candidate[1] > incumbent[1]:
            best_by_entity[entity_key] = candidate
        # NOTE(review): the gold label is recorded for every example, so the
        # last one seen wins; presumably all examples of one entity share the
        # same gold label -- confirm against the dataset builder.
        gold_by_entity[entity_key] = SENTIMENT_LABELS.id2label[item["label"]]

    entity_keys = list(best_by_entity)
    true = [gold_by_entity[k] for k in entity_keys]
    pred = [best_by_entity[k][0] for k in entity_keys]

    macro_f1 = f1_score(true, pred, average="macro", labels=class_names)
    print(f"\n{split_name} (entity-level) macro F1: {macro_f1:.4f}")
    print(classification_report(true, pred, labels=class_names, digits=4))

    return macro_f1
|
|
|
|
def predict_samples(
    model: fasttext.FastText._FastText,
    samples: list[dict],
    window_words: int = 70,
    deduplicate: bool = False,
) -> list[dict]:
    """Predict one sentiment label per entity for raw samples.

    The samples are augmented with ``augment``, optionally passed through
    ``deduplicate_positions``, flattened to examples and classified. For each
    ``(sample_id, entity_id)`` the most confident prediction is kept (the
    first one wins on ties). Entities that received no prediction are
    reported as ``"neutral"``.

    Args:
        model: Trained fastText model.
        samples: Samples with an ``"id"`` and an ``"entities"`` list whose
            items carry ``"entity_id"`` and ``"entity_text"``.
        window_words: Forwarded to ``augment``.
        deduplicate: Whether to apply ``deduplicate_positions`` after
            augmentation.

    Returns:
        One dict per input sample with its ``"id"`` and an ``"entities"`` list
        of ``{"entity_id", "entity_text", "classification"}`` dicts.
    """
    windows = augment(samples, window_words)
    if deduplicate:
        windows = deduplicate_positions(windows)

    best_by_entity: dict[tuple, tuple[str, float]] = {}
    for item in flatten_to_examples(windows, mode=MODE):
        entity_key = (item["sample_id"], item["entity_id"])
        sentence = item["seg_a"].replace("\n", " ")
        top_labels, top_probs = model.predict(sentence)
        candidate = (top_labels[0].replace(LABEL_PREFIX, ""), float(top_probs[0]))

        incumbent = best_by_entity.get(entity_key)
        if incumbent is None or candidate[1] > incumbent[1]:
            best_by_entity[entity_key] = candidate

    # Used for entities that produced no example after augmentation.
    fallback = ("neutral", 0.0)

    output = []
    for sample in samples:
        entity_rows = [
            {
                "entity_id": ent["entity_id"],
                "entity_text": ent["entity_text"],
                "classification": best_by_entity.get(
                    (sample["id"], ent["entity_id"]), fallback
                )[0],
            }
            for ent in sample["entities"]
        ]
        output.append({"id": sample["id"], "entities": entity_rows})

    return output
|
|
|
|
def main():
    """Command-line entry point: prepare data, train, then score the test split.

    Parses the command-line options, builds the train/val/test splits, trains
    and saves the model (which also reports validation metrics), and finally
    prints per-position and entity-level metrics for the test split.
    """
    parser = argparse.ArgumentParser(description="fastText baseline for entity sentiment")
    parser.add_argument("--data", default="data/data_augmented_256.jsonl")
    parser.add_argument("--output-dir", default="models/fasttext")

    # Numeric options share the same shape: flag, converter, default.
    numeric_options = (
        ("--lr", float, 0.5),
        ("--epoch", int, 25),
        ("--word-ngrams", int, 2),
        ("--dim", int, 100),
        ("--val-split", float, 0.1),
        ("--test-split", float, 0.1),
        ("--seed", int, 42),
    )
    for flag, converter, default in numeric_options:
        parser.add_argument(flag, type=converter, default=default)

    cli = parser.parse_args()

    train_ex, val_ex, test_ex = prepare_data(
        cli.data,
        cli.val_split,
        cli.test_split,
        cli.seed,
    )

    model = train(
        train_ex,
        val_ex,
        output_dir=cli.output_dir,
        lr=cli.lr,
        epoch=cli.epoch,
        word_ngrams=cli.word_ngrams,
        dim=cli.dim,
    )

    for report in (evaluate, evaluate_entity_level):
        report(model, test_ex, split_name="test")


if __name__ == "__main__":
    main()
|
|