Spaces:
Sleeping
Sleeping
| """Training loop for AspectBERT. | |
| - Optimizer: AdamW (lr=2e-5, weight_decay=0.01) | |
| - Scheduler: OneCycleLR (10% warmup, cosine decay) | |
| - Epochs: 4, batch size 16 (CPU) / 32 (GPU) by default | |
| - Metric: macro F1, best checkpoint saved by val F1 | |
| - Logs per-epoch history to results/training_history.json | |
| - Final test evaluation: macro F1, accuracy, per-class F1, confusion matrix, | |
| plus a VADER baseline comparison. | |
| """ | |
| import argparse | |
| import json | |
| import os | |
| import sys | |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) | |
| import torch | |
| from torch.optim import AdamW | |
| from torch.optim.lr_scheduler import OneCycleLR | |
| from torch.utils.data import DataLoader, Dataset | |
| from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score | |
| from transformers import DistilBertModel, DistilBertTokenizerFast | |
| from constants import ID2LABEL, LABEL2ID, MAX_LENGTH, MODEL_NAME, format_input # noqa: E402 | |
| from model import AspectBERT # noqa: E402 | |
class AspectDataset(Dataset):
    """Dataset over a JSONL file, one example per non-blank line.

    Each parsed row is read through the keys ``text``, ``aspect`` and
    ``label``; any additional keys stay on the row but are not used here.
    Tokenization happens lazily in ``__getitem__``.
    """

    def __init__(self, path, tokenizer, max_length=MAX_LENGTH):
        self.tokenizer = tokenizer
        self.max_length = max_length
        with open(path, "r", encoding="utf-8") as handle:
            stripped_lines = (raw.strip() for raw in handle)
            # Blank (whitespace-only) lines are skipped rather than parsed.
            self.examples = [json.loads(entry) for entry in stripped_lines if entry]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        """Return the tokenized example at ``idx`` plus its label id.

        The result holds whatever keys the tokenizer returns, each with
        ``squeeze(0)`` applied, and a ``labels`` entry containing the
        ``LABEL2ID`` id of the row's label as a ``torch.long`` tensor.
        """
        example = self.examples[idx]
        encoded = self.tokenizer(
            format_input(example["text"], example["aspect"]),
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        # squeeze(0) drops the leading axis of each returned tensor
        # (presumably the size-1 batch axis for a single input -- confirm).
        sample = {}
        for key, tensor in encoded.items():
            sample[key] = tensor.squeeze(0)
        label_id = LABEL2ID[example["label"]]
        sample["labels"] = torch.tensor(label_id, dtype=torch.long)
        return sample
def evaluate(model, loader, device, criterion=None):
    """Run ``model`` over ``loader`` and compute classification metrics.

    The model is switched to eval mode (and left in it) and the whole pass
    runs under ``torch.no_grad()`` so no autograd graph is built for the
    forward passes.

    Args:
        model: Callable as ``model(input_ids, attention_mask)`` returning
            logits whose last axis indexes the classes.
        loader: Iterable of batches with ``input_ids``, ``attention_mask``
            and ``labels`` tensors; must expose ``loader.dataset`` when a
            ``criterion`` is given.
        device: Device the batch tensors are moved to.
        criterion: Optional loss function ``criterion(logits, labels)``.
            When ``None`` no loss is computed.

    Returns:
        A dict with keys ``loss`` (sample-weighted mean over
        ``len(loader.dataset)``, or ``None`` without a criterion), ``f1``
        (macro F1), ``accuracy``, ``preds`` and ``labels`` (lists of ints).
    """
    model.eval()
    all_preds, all_labels = [], []
    total_loss = 0.0
    # Evaluation never calls backward(), so tracking gradients would only
    # waste memory and time on every batch.
    with torch.no_grad():
        for batch in loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            logits = model(input_ids, attention_mask)
            if criterion is not None:
                loss = criterion(logits, labels)
                # Weight by batch size so a smaller final batch does not
                # skew the per-sample average.
                total_loss += loss.item() * labels.size(0)
            preds = torch.argmax(logits, dim=-1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())
    avg_loss = total_loss / len(loader.dataset) if criterion is not None else None
    f1 = f1_score(all_labels, all_preds, average="macro", zero_division=0)
    acc = accuracy_score(all_labels, all_preds)
    return {"loss": avg_loss, "f1": f1, "accuracy": acc, "preds": all_preds, "labels": all_labels}
def save_checkpoint(model, tokenizer, output_dir):
    """Write the backbone, tokenizer and classifier head into ``output_dir``.

    ``model.distilbert`` and ``tokenizer`` are saved through their own
    ``save_pretrained`` methods; the state dict of ``model.classifier`` is
    stored separately as ``classifier_head.pt`` in the same directory.
    The directory is created if it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)
    for component in (model.distilbert, tokenizer):
        component.save_pretrained(output_dir)
    head_weights = model.classifier.state_dict()
    head_path = os.path.join(output_dir, "classifier_head.pt")
    torch.save(head_weights, head_path)
def load_checkpoint(output_dir, device):
    """Rebuild an ``AspectBERT`` from the files written by ``save_checkpoint``.

    A fresh ``AspectBERT`` is created, its ``distilbert`` attribute is
    replaced by the backbone loaded from ``output_dir``, and the classifier
    head weights are read from ``classifier_head.pt`` (mapped to CPU first)
    before the whole model is moved to ``device``.
    """
    head_path = os.path.join(output_dir, "classifier_head.pt")
    model = AspectBERT()
    model.distilbert = DistilBertModel.from_pretrained(output_dir)
    head_weights = torch.load(head_path, map_location="cpu")
    model.classifier.load_state_dict(head_weights)
    model.to(device)
    return model
def run_vader_baseline(test_file):
    """Score a VADER rule-based baseline on the rows of ``test_file``.

    Every non-blank JSONL row's ``text`` is scored with VADER and its
    ``compound`` value is bucketed: ``>= 0.05`` is positive, ``<= -0.05`` is
    negative, anything else neutral. The row's ``aspect`` is not consulted.

    Returns:
        A dict with ``macro_f1`` and ``accuracy`` against the gold labels.

    Raises:
        ImportError: If ``vaderSentiment`` is not installed (the import is
            deliberately local so callers can catch this and skip).
    """
    from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer

    def _bucket(compound):
        # Map a compound score onto one of the three label names.
        if compound >= 0.05:
            return "positive"
        if compound <= -0.05:
            return "negative"
        return "neutral"

    analyzer = SentimentIntensityAnalyzer()
    gold, predicted = [], []
    with open(test_file, "r", encoding="utf-8") as handle:
        for raw in handle:
            raw = raw.strip()
            if not raw:
                continue
            row = json.loads(raw)
            score = analyzer.polarity_scores(row["text"])["compound"]
            verdict = _bucket(score)
            gold.append(LABEL2ID[row["label"]])
            predicted.append(LABEL2ID[verdict])
    macro_f1 = f1_score(gold, predicted, average="macro", zero_division=0)
    accuracy = accuracy_score(gold, predicted)
    return {
        "macro_f1": macro_f1,
        "accuracy": accuracy,
    }
def run_test_evaluation(args, tokenizer, device, batch_size):
    """Evaluate the saved best checkpoint on the test set and persist results.

    Loads the checkpoint from ``args.output_dir``, evaluates it on
    ``args.test_file``, and writes macro F1, accuracy, per-class F1, the
    confusion matrix and the full classification report to
    ``results/test_metrics.json`` (relative to the working directory). A
    VADER baseline is added under ``vader_baseline`` when ``vaderSentiment``
    can be imported; otherwise it is skipped with a message.

    Args:
        args: Namespace providing ``output_dir`` and ``test_file``.
        tokenizer: Tokenizer handed to ``AspectDataset``.
        device: Device the model and batches are placed on.
        batch_size: Batch size for the test ``DataLoader``.
    """
    print("\nLoading best checkpoint for test evaluation...")
    model = load_checkpoint(args.output_dir, device)
    test_ds = AspectDataset(args.test_file, tokenizer)
    test_loader = DataLoader(test_ds, batch_size=batch_size)
    metrics = evaluate(model, test_loader, device)

    # Always report all three classes, even if one is absent from the test
    # labels or predictions, so the output layout is stable across runs.
    class_ids = [0, 1, 2]
    class_names = [ID2LABEL[i] for i in range(3)]

    per_class_f1 = f1_score(metrics["labels"], metrics["preds"], average=None,
                            labels=class_ids, zero_division=0)
    cm = confusion_matrix(metrics["labels"], metrics["preds"], labels=class_ids).tolist()
    report = classification_report(
        metrics["labels"], metrics["preds"],
        labels=class_ids, target_names=class_names,
        output_dict=True, zero_division=0,
    )
    results = {
        "macro_f1": metrics["f1"],
        "accuracy": metrics["accuracy"],
        "per_class_f1": {ID2LABEL[i]: float(per_class_f1[i]) for i in range(3)},
        "confusion_matrix": cm,
        "confusion_matrix_labels": list(class_names),
        "classification_report": report,
    }
    try:
        results["vader_baseline"] = run_vader_baseline(args.test_file)
    except ImportError:
        # The baseline is optional; a missing package must not fail the run.
        print("vaderSentiment not installed; skipping VADER baseline comparison.")
    os.makedirs("results", exist_ok=True)
    with open("results/test_metrics.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print(f"\nTest macro F1: {results['macro_f1']:.4f}")
    print(f"Test accuracy: {results['accuracy']:.4f}")
    print(f"Per-class F1: {results['per_class_f1']}")
    if "vader_baseline" in results:
        print(f"VADER baseline macro F1: {results['vader_baseline']['macro_f1']:.4f} "
              f"(AspectBERT vs VADER on the same test set)")
    print("Saved detailed results to results/test_metrics.json")
def train(args):
    """Fine-tune AspectBERT and keep the checkpoint with the best val macro F1.

    Uses AdamW (lr 2e-5, weight decay 0.01) with a OneCycleLR schedule
    stepped once per batch. After every epoch the model is evaluated on the
    validation set and saved to ``args.output_dir`` whenever macro F1
    improves. The per-epoch history is written to ``args.history_file`` once
    training finishes; if a test file exists and at least one checkpoint was
    saved, the test evaluation runs afterwards.

    Args:
        args: Namespace with ``train_file``, ``val_file``, ``test_file``,
            ``output_dir``, ``history_file``, ``epochs`` and ``batch_size``
            (a falsy ``batch_size`` selects 32 on CUDA, 16 otherwise).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = args.batch_size
    if not batch_size:
        batch_size = 32 if device.type == "cuda" else 16
    print(f"Device: {device}, batch size: {batch_size}")

    tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_NAME)
    train_ds = AspectDataset(args.train_file, tokenizer)
    val_ds = AspectDataset(args.val_file, tokenizer)
    print(f"Train examples: {len(train_ds)}, Val examples: {len(val_ds)}")
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size)

    model = AspectBERT().to(device)
    # Only parameters that require gradients are handed to the optimizer.
    params_to_update = [p for p in model.parameters() if p.requires_grad]
    optimizer = AdamW(params_to_update, lr=2e-5, weight_decay=0.01)
    # One scheduler step per optimizer step; never fewer than one in total.
    planned_steps = max(1, len(train_loader) * args.epochs)
    scheduler = OneCycleLR(
        optimizer,
        max_lr=2e-5,
        total_steps=planned_steps,
        pct_start=0.1,
        anneal_strategy="cos",
    )
    criterion = torch.nn.CrossEntropyLoss()

    history = []
    best_f1 = -1.0  # below any achievable F1, so the first epoch always saves
    history_dir = os.path.dirname(args.history_file) or "."
    os.makedirs(history_dir, exist_ok=True)

    batch_keys = ("input_ids", "attention_mask", "labels")
    for epoch in range(1, args.epochs + 1):
        model.train()
        running_loss = 0.0
        for batch in train_loader:
            input_ids, attention_mask, labels = (
                batch[key].to(device) for key in batch_keys
            )
            optimizer.zero_grad()
            logits = model(input_ids, attention_mask)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()
            # Accumulate a per-sample sum so the epoch mean is exact even
            # when the last batch is smaller.
            running_loss += loss.item() * labels.size(0)
        train_loss = running_loss / len(train_loader.dataset)

        val_metrics = evaluate(model, val_loader, device, criterion)
        print(f"Epoch {epoch}/{args.epochs} - "
              f"train_loss: {train_loss:.4f} - "
              f"val_loss: {val_metrics['loss']:.4f} - "
              f"val_f1: {val_metrics['f1']:.4f} - "
              f"val_acc: {val_metrics['accuracy']:.4f}")
        epoch_record = {
            "epoch": epoch,
            "train_loss": train_loss,
            "val_loss": val_metrics["loss"],
            "val_f1": val_metrics["f1"],
            "val_accuracy": val_metrics["accuracy"],
            "lr": scheduler.get_last_lr()[0],
        }
        history.append(epoch_record)

        improved = val_metrics["f1"] > best_f1
        if improved:
            best_f1 = val_metrics["f1"]
            save_checkpoint(model, tokenizer, args.output_dir)
            print(f" -> New best val F1: {best_f1:.4f}, checkpoint saved to {args.output_dir}")

    with open(args.history_file, "w", encoding="utf-8") as handle:
        json.dump(history, handle, indent=2)
    print(f"\nSaved training history to {args.history_file}")

    # best_f1 stays negative only when no epoch ran, i.e. nothing was saved.
    checkpoint_saved = best_f1 >= 0
    if args.test_file and os.path.exists(args.test_file) and checkpoint_saved:
        run_test_evaluation(args, tokenizer, device, batch_size)
def parse_args():
    """Parse the command-line options for a training run.

    Returns:
        An ``argparse.Namespace`` with ``train_file``, ``val_file``,
        ``test_file``, ``output_dir``, ``history_file``, ``epochs`` and
        ``batch_size`` (``None`` unless given on the command line).
    """
    parser = argparse.ArgumentParser(description="Train AspectBERT.")
    path_defaults = (
        ("--train_file", "data/train.jsonl"),
        ("--val_file", "data/val.jsonl"),
        ("--test_file", "data/test.jsonl"),
        ("--output_dir", "models/aspectbert"),
        ("--history_file", "results/training_history.json"),
    )
    for flag, default in path_defaults:
        parser.add_argument(flag, default=default)
    parser.add_argument("--epochs", type=int, default=4)
    parser.add_argument(
        "--batch_size",
        type=int,
        default=None,
        help="Defaults to 32 on GPU, 16 on CPU.",
    )
    return parser.parse_args()
# Script entry point: parse the CLI options, then run training.
if __name__ == "__main__":
    cli_args = parse_args()
    train(cli_args)