# sv-task / src/models/fasttext.py
# Author: lamossta — commit 51620d3 ("models and inference classes")
import argparse
import json
import tempfile
from pathlib import Path
import fasttext
import numpy as np
from sklearn.metrics import f1_score, classification_report
from src.models.augment import augment, MAXLEN_TO_WINDOW
from src.models.dataset import deduplicate_positions, flatten_to_examples, split_data
from src.schemas.labels import SENTIMENT_LABELS
MODE = "marker"
LABEL_PREFIX = "__label__"
def _to_fasttext_line(example: dict) -> str:
    """Format one example as a fastText supervised-training line.

    The integer label id is mapped to its string name and prefixed with
    ``__label__``; newlines in the text are flattened to spaces so the
    whole example stays on a single line.
    """
    label_name = SENTIMENT_LABELS.id2label[example["label"]]
    flat_text = example["seg_a"].replace("\n", " ")
    return LABEL_PREFIX + label_name + " " + flat_text
def _write_fasttext_file(examples: list[dict], path: Path) -> None:
    """Write *examples* to *path*, one fastText-formatted line per example."""
    lines = (_to_fasttext_line(ex) + "\n" for ex in examples)
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(lines)
def prepare_data(
    data_path: str = "data/data_augmented_256.jsonl",
    val_split: float = 0.1,
    test_split: float = 0.1,
    seed: int = 42,
) -> tuple[list[dict], list[dict], list[dict]]:
    """Load JSONL samples, flatten them to per-position examples, and split.

    Returns ``(train, val, test)`` example lists and prints their sizes.
    """
    with open(data_path, "r", encoding="utf-8") as fh:
        samples = [json.loads(raw_line) for raw_line in fh]

    examples = flatten_to_examples(samples, mode=MODE)
    splits = split_data(examples, val_split, test_split, seed)
    train_ex, val_ex, test_ex = splits

    print(f"Train: {len(train_ex)}, Val: {len(val_ex)}, Test: {len(test_ex)}")
    return train_ex, val_ex, test_ex
def train(
    train_examples: list[dict],
    val_examples: list[dict],
    output_dir: str = "models/fasttext",
    lr: float = 0.5,
    epoch: int = 25,
    word_ngrams: int = 2,
    dim: int = 100,
    min_count: int = 1,
) -> fasttext.FastText._FastText:
    """Train a supervised fastText classifier and report validation F1.

    Training data is serialized to ``<output_dir>/train.txt`` first, because
    fastText reads its training corpus from disk.  The trained model is saved
    to ``<output_dir>/model.bin`` and validation metrics are printed via
    :func:`evaluate` before the model is returned.
    """
    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    train_file = out_path / "train.txt"
    _write_fasttext_file(train_examples, train_file)

    model = fasttext.train_supervised(
        input=str(train_file),
        lr=lr,
        epoch=epoch,
        wordNgrams=word_ngrams,
        dim=dim,
        minCount=min_count,
        loss="softmax",
    )

    model_path = out_path / "model.bin"
    model.save_model(str(model_path))
    print(f"Model saved to {model_path}")

    evaluate(model, val_examples, split_name="val")
    return model
def evaluate(
    model: fasttext.FastText._FastText,
    examples: list[dict],
    split_name: str = "test",
) -> float:
    """Compute per-position macro F1 on *examples* and print a report.

    Each example is scored independently (no entity-level aggregation);
    see :func:`evaluate_entity_level` for the aggregated metric.
    """
    sentiments = list(SENTIMENT_LABELS.classes)

    texts = [ex["seg_a"].replace("\n", " ") for ex in examples]
    true_labels = [SENTIMENT_LABELS.id2label[ex["label"]] for ex in examples]
    pred_labels = [
        model.predict(text)[0][0].replace(LABEL_PREFIX, "") for text in texts
    ]

    macro_f1 = f1_score(true_labels, pred_labels, average="macro", labels=sentiments)
    print(f"\n{split_name} (per-position) macro F1: {macro_f1:.4f}")
    print(classification_report(true_labels, pred_labels, labels=sentiments, digits=4))
    return macro_f1
def evaluate_entity_level(
    model: fasttext.FastText._FastText,
    examples: list[dict],
    split_name: str = "test",
) -> float:
    """Aggregate per-position predictions per entity and print macro F1.

    For each (sample_id, entity_id) pair, the prediction with the highest
    model confidence across all of its positions wins.
    """
    sentiments = list(SENTIMENT_LABELS.classes)

    best: dict[tuple, tuple[str, float]] = {}   # key -> (label, confidence)
    gold: dict[tuple, str] = {}                 # key -> true label name
    for ex in examples:
        key = (ex["sample_id"], ex["entity_id"])
        flat = ex["seg_a"].replace("\n", " ")
        labels, probs = model.predict(flat)
        candidate = (labels[0].replace(LABEL_PREFIX, ""), float(probs[0]))
        # Keep only the most confident prediction seen for this entity.
        if key not in best or candidate[1] > best[key][1]:
            best[key] = candidate
        gold[key] = SENTIMENT_LABELS.id2label[ex["label"]]

    true = [gold[k] for k in best]
    pred = [best[k][0] for k in best]

    macro_f1 = f1_score(true, pred, average="macro", labels=sentiments)
    print(f"\n{split_name} (entity-level) macro F1: {macro_f1:.4f}")
    print(classification_report(true, pred, labels=sentiments, digits=4))
    return macro_f1
def predict_samples(
    model: fasttext.FastText._FastText,
    samples: list[dict],
    window_words: int = 70,
    deduplicate: bool = False,
) -> list[dict]:
    """Run inference on raw samples, producing one classification per entity.

    Samples are windowed via :func:`augment`, flattened to per-position
    examples, and each entity keeps the prediction with the highest model
    confidence across its positions.  Entities that receive no prediction
    fall back to "neutral".
    """
    windowed = augment(samples, window_words)
    if deduplicate:
        windowed = deduplicate_positions(windowed)

    # Most confident prediction per (sample_id, entity_id).
    best: dict[tuple, tuple[str, float]] = {}
    for ex in flatten_to_examples(windowed, mode=MODE):
        key = (ex["sample_id"], ex["entity_id"])
        labels, probs = model.predict(ex["seg_a"].replace("\n", " "))
        conf = float(probs[0])
        if key not in best or conf > best[key][1]:
            best[key] = (labels[0].replace(LABEL_PREFIX, ""), conf)

    results = []
    for sample in samples:
        entities_out = [
            {
                "entity_id": ent["entity_id"],
                "entity_text": ent["entity_text"],
                "classification": best.get(
                    (sample["id"], ent["entity_id"]), ("neutral", 0.0)
                )[0],
            }
            for ent in sample["entities"]
        ]
        results.append({"id": sample["id"], "entities": entities_out})
    return results
def main():
    """CLI entry point: prepare data, train, and evaluate the fastText baseline."""
    parser = argparse.ArgumentParser(description="fastText baseline for entity sentiment")
    parser.add_argument("--data", default="data/data_augmented_256.jsonl")
    parser.add_argument("--output-dir", default="models/fasttext")
    parser.add_argument("--lr", type=float, default=0.5)
    parser.add_argument("--epoch", type=int, default=25)
    parser.add_argument("--word-ngrams", type=int, default=2)
    parser.add_argument("--dim", type=int, default=100)
    # Previously unreachable from the CLI even though train() accepts it;
    # default matches train()'s own default, so behavior is unchanged.
    parser.add_argument("--min-count", type=int, default=1)
    parser.add_argument("--val-split", type=float, default=0.1)
    parser.add_argument("--test-split", type=float, default=0.1)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    train_ex, val_ex, test_ex = prepare_data(
        args.data, args.val_split, args.test_split, args.seed,
    )
    model = train(
        train_ex, val_ex,
        output_dir=args.output_dir,
        lr=args.lr,
        epoch=args.epoch,
        word_ngrams=args.word_ngrams,
        dim=args.dim,
        min_count=args.min_count,
    )
    # Report both per-position and entity-aggregated metrics on the test split.
    evaluate(model, test_ex, split_name="test")
    evaluate_entity_level(model, test_ex, split_name="test")


if __name__ == "__main__":
    main()