"""fastText baseline for entity-level sentiment classification.

Flattens annotated samples into per-position (text segment, label) examples,
trains a supervised fastText classifier on them, and reports macro F1 both
per position and aggregated to the entity level (highest-confidence vote).
"""

import argparse
import json
import tempfile  # NOTE(review): appears unused in this module — confirm before removing
from pathlib import Path

import fasttext
import numpy as np  # NOTE(review): appears unused in this module — confirm before removing
from sklearn.metrics import classification_report, f1_score

from src.models.augment import MAXLEN_TO_WINDOW, augment  # NOTE(review): MAXLEN_TO_WINDOW unused
from src.models.dataset import deduplicate_positions, flatten_to_examples, split_data
from src.schemas.labels import SENTIMENT_LABELS

# Flattening mode used when turning samples into per-position examples.
MODE = "marker"
# Label prefix required by fastText's supervised text-file format.
LABEL_PREFIX = "__label__"


def _clean_text(text: str) -> str:
    """Collapse newlines to spaces — fastText treats each line as one example."""
    return text.replace("\n", " ")


def _predict_one(model: fasttext.FastText._FastText, text: str) -> tuple[str, float]:
    """Return ``(label, confidence)`` for the model's top prediction on *text*."""
    labels, probs = model.predict(text)
    return labels[0].removeprefix(LABEL_PREFIX), float(probs[0])


def _to_fasttext_line(example: dict) -> str:
    """Format one example as a ``__label__<name> <text>`` training line."""
    text = _clean_text(example["seg_a"])
    label = SENTIMENT_LABELS.id2label[example["label"]]
    return f"{LABEL_PREFIX}{label} {text}"


def _write_fasttext_file(examples: list[dict], path: Path) -> None:
    """Write *examples* to *path* in fastText supervised-training format."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(_to_fasttext_line(ex) + "\n" for ex in examples)


def prepare_data(
    data_path: str = "data/data_augmented_256.jsonl",
    val_split: float = 0.1,
    test_split: float = 0.1,
    seed: int = 42,
) -> tuple[list[dict], list[dict], list[dict]]:
    """Load a JSONL dataset, flatten to examples, and split train/val/test.

    Args:
        data_path: Path to a JSONL file, one sample per line.
        val_split: Fraction of examples held out for validation.
        test_split: Fraction of examples held out for testing.
        seed: RNG seed forwarded to the split for reproducibility.

    Returns:
        ``(train_examples, val_examples, test_examples)`` lists of dicts.
    """
    with open(data_path, "r", encoding="utf-8") as f:
        samples = [json.loads(line) for line in f]

    examples = flatten_to_examples(samples, mode=MODE)
    train_ex, val_ex, test_ex = split_data(examples, val_split, test_split, seed)
    print(f"Train: {len(train_ex)}, Val: {len(val_ex)}, Test: {len(test_ex)}")
    return train_ex, val_ex, test_ex


def train(
    train_examples: list[dict],
    val_examples: list[dict],
    output_dir: str = "models/fasttext",
    lr: float = 0.5,
    epoch: int = 25,
    word_ngrams: int = 2,
    dim: int = 100,
    min_count: int = 1,
) -> fasttext.FastText._FastText:
    """Train a supervised fastText model and save it under *output_dir*.

    Writes ``train.txt`` and ``model.bin`` into *output_dir* (created if
    missing), then evaluates on *val_examples*.

    Returns:
        The trained fastText model.
    """
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    train_file = out / "train.txt"
    _write_fasttext_file(train_examples, train_file)

    model = fasttext.train_supervised(
        input=str(train_file),
        lr=lr,
        epoch=epoch,
        wordNgrams=word_ngrams,
        dim=dim,
        minCount=min_count,
        loss="softmax",
    )

    model.save_model(str(out / "model.bin"))
    print(f"Model saved to {out / 'model.bin'}")

    evaluate(model, val_examples, split_name="val")
    return model


def evaluate(
    model: fasttext.FastText._FastText,
    examples: list[dict],
    split_name: str = "test",
) -> float:
    """Compute per-position macro F1 over *examples* and print a report.

    Returns:
        Macro-averaged F1 across the sentiment classes.
    """
    sentiments = list(SENTIMENT_LABELS.classes)
    true_labels = []
    pred_labels = []

    for ex in examples:
        label, _ = _predict_one(model, _clean_text(ex["seg_a"]))
        pred_labels.append(label)
        true_labels.append(SENTIMENT_LABELS.id2label[ex["label"]])

    macro_f1 = f1_score(true_labels, pred_labels, average="macro", labels=sentiments)
    print(f"\n{split_name} (per-position) macro F1: {macro_f1:.4f}")
    print(classification_report(true_labels, pred_labels, labels=sentiments, digits=4))
    return macro_f1


def evaluate_entity_level(
    model: fasttext.FastText._FastText,
    examples: list[dict],
    split_name: str = "test",
) -> float:
    """Compute entity-level macro F1: one prediction per (sample, entity).

    When several examples map to the same entity, the highest-confidence
    prediction wins.

    Returns:
        Macro-averaged F1 across the sentiment classes.
    """
    sentiments = list(SENTIMENT_LABELS.classes)
    entity_preds: dict[tuple, tuple[str, float]] = {}
    entity_labels: dict[tuple, str] = {}

    for ex in examples:
        key = (ex["sample_id"], ex["entity_id"])
        label, conf = _predict_one(model, _clean_text(ex["seg_a"]))
        # Keep only the most confident prediction per entity.
        if key not in entity_preds or conf > entity_preds[key][1]:
            entity_preds[key] = (label, conf)
        entity_labels[key] = SENTIMENT_LABELS.id2label[ex["label"]]

    true = [entity_labels[k] for k in entity_preds]
    pred = [entity_preds[k][0] for k in entity_preds]

    macro_f1 = f1_score(true, pred, average="macro", labels=sentiments)
    print(f"\n{split_name} (entity-level) macro F1: {macro_f1:.4f}")
    print(classification_report(true, pred, labels=sentiments, digits=4))
    return macro_f1


def predict_samples(
    model: fasttext.FastText._FastText,
    samples: list[dict],
    window_words: int = 70,
    deduplicate: bool = False,
) -> list[dict]:
    """Predict an entity-level sentiment for every entity in *samples*.

    Samples are augmented and flattened exactly as in training; per-entity
    the highest-confidence prediction wins. Entities with no prediction
    (e.g. no matching example after augmentation) default to ``"neutral"``.

    Returns:
        One dict per sample: ``{"id": ..., "entities": [...]}``.
    """
    augmented = augment(samples, window_words)
    if deduplicate:
        augmented = deduplicate_positions(augmented)
    examples = flatten_to_examples(augmented, mode=MODE)

    entity_preds: dict[tuple, tuple[str, float]] = {}
    for ex in examples:
        key = (ex["sample_id"], ex["entity_id"])
        label, conf = _predict_one(model, _clean_text(ex["seg_a"]))
        # Keep only the most confident prediction per entity.
        if key not in entity_preds or conf > entity_preds[key][1]:
            entity_preds[key] = (label, conf)

    results = []
    for s in samples:
        entities_out = []
        for e in s["entities"]:
            key = (s["id"], e["entity_id"])
            entities_out.append({
                "entity_id": e["entity_id"],
                "entity_text": e["entity_text"],
                "classification": entity_preds.get(key, ("neutral", 0.0))[0],
            })
        results.append({"id": s["id"], "entities": entities_out})
    return results


def main():
    """CLI entry point: prepare data, train, and evaluate on the test split."""
    parser = argparse.ArgumentParser(description="fastText baseline for entity sentiment")
    parser.add_argument("--data", default="data/data_augmented_256.jsonl")
    parser.add_argument("--output-dir", default="models/fasttext")
    parser.add_argument("--lr", type=float, default=0.5)
    parser.add_argument("--epoch", type=int, default=25)
    parser.add_argument("--word-ngrams", type=int, default=2)
    parser.add_argument("--dim", type=int, default=100)
    parser.add_argument("--val-split", type=float, default=0.1)
    parser.add_argument("--test-split", type=float, default=0.1)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    train_ex, val_ex, test_ex = prepare_data(
        args.data,
        args.val_split,
        args.test_split,
        args.seed,
    )
    model = train(
        train_ex,
        val_ex,
        output_dir=args.output_dir,
        lr=args.lr,
        epoch=args.epoch,
        word_ngrams=args.word_ngrams,
        dim=args.dim,
    )
    evaluate(model, test_ex, split_name="test")
    evaluate_entity_level(model, test_ex, split_name="test")


if __name__ == "__main__":
    main()