AspectBERT / src /train.py
itismeTithi's picture
Deploy AspectBERT Streamlit app
31f6bcb
raw
history blame contribute delete
10.2 kB
"""Training loop for AspectBERT.
- Optimizer: AdamW (lr=2e-5, weight_decay=0.01)
- Scheduler: OneCycleLR (10% warmup, cosine decay)
- Epochs: 4, batch size 16 (CPU) / 32 (GPU) by default
- Metric: macro F1, best checkpoint saved by val F1
- Logs per-epoch history to results/training_history.json
- Final test evaluation: macro F1, accuracy, per-class F1, confusion matrix,
plus a VADER baseline comparison.
"""
import argparse
import json
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score
from transformers import DistilBertModel, DistilBertTokenizerFast
from constants import ID2LABEL, LABEL2ID, MAX_LENGTH, MODEL_NAME, format_input # noqa: E402
from model import AspectBERT # noqa: E402
class AspectDataset(Dataset):
    """Dataset backed by a JSON Lines file of {text, aspect, label, ...} rows.

    Every non-blank line of the file is parsed as one JSON object when the
    dataset is constructed; tokenization happens lazily, per item, on access.
    """

    def __init__(self, path, tokenizer, max_length=MAX_LENGTH):
        """Read all rows from ``path`` and remember the tokenizer settings."""
        with open(path, "r", encoding="utf-8") as handle:
            stripped = (raw.strip() for raw in handle)
            # Blank (whitespace-only) lines are skipped rather than parsed.
            self.examples = [json.loads(entry) for entry in stripped if entry]
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of rows loaded from the file."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Tokenize row ``idx`` and return its tensors plus a ``labels`` id."""
        record = self.examples[idx]
        encoded = self.tokenizer(
            format_input(record["text"], record["aspect"]),
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        # squeeze(0) drops a leading size-1 axis -- presumably the batch axis
        # the tokenizer adds for a single input when return_tensors="pt".
        sample = {name: tensor.squeeze(0) for name, tensor in encoded.items()}
        sample["labels"] = torch.tensor(LABEL2ID[record["label"]], dtype=torch.long)
        return sample
@torch.no_grad()
def evaluate(model, loader, device, criterion=None):
    """Score ``model`` on every batch yielded by ``loader``.

    Args:
        model: Called as ``model(input_ids, attention_mask)``; its output is
            treated as per-class logits.
        loader: Iterable of batches holding "input_ids", "attention_mask" and
            "labels" tensors. ``len(loader.dataset)`` is read when
            ``criterion`` is given.
        device: Device the batch tensors are moved to before the forward pass.
        criterion: Optional loss callable ``criterion(logits, labels)``. When
            omitted, no loss is computed.

    Returns:
        Dict with "loss" (mean loss per dataset example, or None without a
        criterion), "f1" (macro F1), "accuracy", and the flat "preds" and
        "labels" lists of class ids.
    """
    model.eval()
    track_loss = criterion is not None
    predictions, targets = [], []
    loss_sum = 0.0

    for batch in loader:
        logits = model(
            batch["input_ids"].to(device),
            batch["attention_mask"].to(device),
        )
        labels = batch["labels"].to(device)
        if track_loss:
            # Weight by batch size so the final division yields a per-example
            # figure (assumes the criterion returns a batch mean).
            loss_sum += criterion(logits, labels).item() * labels.size(0)
        predictions += logits.argmax(dim=-1).cpu().tolist()
        targets += labels.cpu().tolist()

    if track_loss:
        mean_loss = loss_sum / len(loader.dataset)
    else:
        mean_loss = None

    return {
        "loss": mean_loss,
        "f1": f1_score(targets, predictions, average="macro", zero_division=0),
        "accuracy": accuracy_score(targets, predictions),
        "preds": predictions,
        "labels": targets,
    }
def save_checkpoint(model, tokenizer, output_dir):
    """Write the backbone, tokenizer and classifier head into ``output_dir``.

    The backbone (``model.distilbert``) and the tokenizer are written with
    their own ``save_pretrained`` methods; the classifier head's state dict
    goes to a separate ``classifier_head.pt`` file in the same directory.
    The directory is created if it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)
    head_path = os.path.join(output_dir, "classifier_head.pt")
    for component in (model.distilbert, tokenizer):
        component.save_pretrained(output_dir)
    torch.save(model.classifier.state_dict(), head_path)
def load_checkpoint(output_dir, device):
    """Rebuild an AspectBERT model from a directory written by save_checkpoint.

    Args:
        output_dir: Directory holding the saved backbone files and
            ``classifier_head.pt``.
        device: Device the assembled model is moved to.

    Returns:
        The AspectBERT instance with the saved backbone and classifier head
        loaded. The model's train/eval mode is not changed here.
    """
    model = AspectBERT()
    model.distilbert = DistilBertModel.from_pretrained(output_dir)
    head_path = os.path.join(output_dir, "classifier_head.pt")
    # The file only contains a state dict, so restrict unpickling to tensors
    # and plain containers instead of allowing arbitrary pickled objects.
    state_dict = torch.load(head_path, map_location="cpu", weights_only=True)
    model.classifier.load_state_dict(state_dict)
    model.to(device)
    return model
def run_vader_baseline(test_file):
    """Score a VADER sentiment baseline on the rows of ``test_file``.

    Each non-blank JSON line's "text" is scored with VADER; a compound score
    of at least 0.05 maps to "positive", at most -0.05 to "negative", and
    anything in between to "neutral". Only the text is used, not the aspect.

    Returns:
        Dict with "macro_f1" and "accuracy" of the baseline predictions.

    Raises:
        ImportError: If the vaderSentiment package is not installed.
    """
    # Imported lazily so the package is only required when the baseline runs.
    from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer

    analyzer = SentimentIntensityAnalyzer()
    gold_ids, predicted_ids = [], []
    with open(test_file, "r", encoding="utf-8") as handle:
        for raw in handle:
            entry = raw.strip()
            if entry:
                record = json.loads(entry)
                score = analyzer.polarity_scores(record["text"])["compound"]
                guess = (
                    "positive" if score >= 0.05
                    else "negative" if score <= -0.05
                    else "neutral"
                )
                gold_ids.append(LABEL2ID[record["label"]])
                predicted_ids.append(LABEL2ID[guess])

    macro_f1 = f1_score(gold_ids, predicted_ids, average="macro", zero_division=0)
    accuracy = accuracy_score(gold_ids, predicted_ids)
    return {"macro_f1": macro_f1, "accuracy": accuracy}
def run_test_evaluation(args, tokenizer, device, batch_size):
    """Evaluate the saved checkpoint on the test split and write a report.

    Loads the checkpoint from ``args.output_dir``, scores it on
    ``args.test_file``, and writes macro F1, accuracy, per-class F1, the
    confusion matrix and the classification report to
    ``results/test_metrics.json`` (relative to the working directory). A VADER
    baseline is added to the report when the vaderSentiment package is
    importable; otherwise it is skipped with a message.

    Args:
        args: Namespace providing ``output_dir`` and ``test_file``.
        tokenizer: Tokenizer used to build the test dataset.
        device: Device the model is evaluated on.
        batch_size: Batch size for the test DataLoader.
    """
    print("\nLoading best checkpoint for test evaluation...")
    model = load_checkpoint(args.output_dir, device)
    test_ds = AspectDataset(args.test_file, tokenizer)
    test_loader = DataLoader(test_ds, batch_size=batch_size)
    metrics = evaluate(model, test_loader, device)
    y_true, y_pred = metrics["labels"], metrics["preds"]

    # Fix the label set once so every metric reports all three classes in the
    # same order, even when a class is absent from the test data.
    label_ids = [0, 1, 2]
    label_names = [ID2LABEL[i] for i in label_ids]

    per_class_f1 = f1_score(y_true, y_pred, average=None,
                            labels=label_ids, zero_division=0)
    cm = confusion_matrix(y_true, y_pred, labels=label_ids).tolist()
    report = classification_report(
        y_true, y_pred,
        labels=label_ids, target_names=label_names,
        output_dict=True, zero_division=0,
    )
    results = {
        "macro_f1": metrics["f1"],
        "accuracy": metrics["accuracy"],
        "per_class_f1": {
            name: float(score) for name, score in zip(label_names, per_class_f1)
        },
        "confusion_matrix": cm,
        "confusion_matrix_labels": label_names,
        "classification_report": report,
    }

    # The baseline is optional: only a missing package is tolerated here.
    try:
        vader_metrics = run_vader_baseline(args.test_file)
    except ImportError:
        print("vaderSentiment not installed; skipping VADER baseline comparison.")
    else:
        results["vader_baseline"] = vader_metrics

    os.makedirs("results", exist_ok=True)
    with open("results/test_metrics.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    print(f"\nTest macro F1: {results['macro_f1']:.4f}")
    print(f"Test accuracy: {results['accuracy']:.4f}")
    print(f"Per-class F1: {results['per_class_f1']}")
    if "vader_baseline" in results:
        print(f"VADER baseline macro F1: {results['vader_baseline']['macro_f1']:.4f} "
              f"(AspectBERT vs VADER on the same test set)")
    print("Saved detailed results to results/test_metrics.json")
def _train_one_epoch(model, loader, optimizer, scheduler, criterion, device):
    """Run one optimization pass over ``loader``; return the mean example loss."""
    model.train()
    running_loss = 0.0
    for batch in loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        optimizer.zero_grad()
        logits = model(input_ids, attention_mask)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        # The scheduler is sized in batches (see total_steps in train), so it
        # advances once per optimizer step rather than once per epoch.
        scheduler.step()
        running_loss += loss.item() * labels.size(0)
    return running_loss / len(loader.dataset)


def train(args):
    """Train AspectBERT, keep the best checkpoint, and log per-epoch history.

    After every epoch the model is scored on the validation set; whenever the
    validation macro F1 improves, the checkpoint in ``args.output_dir`` is
    overwritten. The per-epoch history is written to ``args.history_file``
    once training finishes. If ``args.test_file`` exists and at least one
    checkpoint was saved, the test evaluation is run afterwards.

    Args:
        args: Namespace with ``train_file``, ``val_file``, ``test_file``,
            ``output_dir``, ``history_file``, ``epochs`` and ``batch_size``
            (a falsy batch size selects 32 on CUDA, 16 otherwise).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = args.batch_size or (32 if device.type == "cuda" else 16)
    print(f"Device: {device}, batch size: {batch_size}")

    tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_NAME)
    train_ds = AspectDataset(args.train_file, tokenizer)
    val_ds = AspectDataset(args.val_file, tokenizer)
    print(f"Train examples: {len(train_ds)}, Val examples: {len(val_ds)}")
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size)

    model = AspectBERT().to(device)
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    # One value feeds both the optimizer and the schedule's peak so the two
    # cannot drift apart.
    peak_lr = 2e-5
    optimizer = AdamW(trainable_params, lr=peak_lr, weight_decay=0.01)
    # max(1, ...) keeps the scheduler constructible when there is nothing to
    # train on (no batches or zero epochs).
    total_steps = max(1, len(train_loader) * args.epochs)
    scheduler = OneCycleLR(
        optimizer,
        max_lr=peak_lr,
        total_steps=total_steps,
        pct_start=0.1,
        anneal_strategy="cos",
    )
    criterion = torch.nn.CrossEntropyLoss()

    history = []
    # Starts below any achievable F1 so the first epoch always saves a
    # checkpoint; it stays negative only if no epoch ran.
    best_f1 = -1.0
    os.makedirs(os.path.dirname(args.history_file) or ".", exist_ok=True)

    for epoch in range(1, args.epochs + 1):
        train_loss = _train_one_epoch(
            model, train_loader, optimizer, scheduler, criterion, device
        )
        val_metrics = evaluate(model, val_loader, device, criterion)
        print(f"Epoch {epoch}/{args.epochs} - "
              f"train_loss: {train_loss:.4f} - "
              f"val_loss: {val_metrics['loss']:.4f} - "
              f"val_f1: {val_metrics['f1']:.4f} - "
              f"val_acc: {val_metrics['accuracy']:.4f}")
        history.append({
            "epoch": epoch,
            "train_loss": train_loss,
            "val_loss": val_metrics["loss"],
            "val_f1": val_metrics["f1"],
            "val_accuracy": val_metrics["accuracy"],
            "lr": scheduler.get_last_lr()[0],
        })
        if val_metrics["f1"] > best_f1:
            best_f1 = val_metrics["f1"]
            save_checkpoint(model, tokenizer, args.output_dir)
            print(f" -> New best val F1: {best_f1:.4f}, checkpoint saved to {args.output_dir}")

    with open(args.history_file, "w", encoding="utf-8") as f:
        json.dump(history, f, indent=2)
    print(f"\nSaved training history to {args.history_file}")

    # best_f1 >= 0 means this run saved a checkpoint, so a stale one from an
    # earlier run is never evaluated by mistake.
    if args.test_file and os.path.exists(args.test_file) and best_f1 >= 0:
        run_test_evaluation(args, tokenizer, device, batch_size)
def parse_args(argv=None):
    """Parse the command-line options for a training run.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) argparse reads the process's command line, so existing
            ``parse_args()`` callers are unaffected.

    Returns:
        argparse.Namespace with ``train_file``, ``val_file``, ``test_file``,
        ``output_dir``, ``history_file``, ``epochs`` and ``batch_size``
        (None unless given explicitly).
    """
    parser = argparse.ArgumentParser(description="Train AspectBERT.")
    parser.add_argument("--train_file", default="data/train.jsonl")
    parser.add_argument("--val_file", default="data/val.jsonl")
    parser.add_argument("--test_file", default="data/test.jsonl")
    parser.add_argument("--output_dir", default="models/aspectbert")
    parser.add_argument("--history_file", default="results/training_history.json")
    parser.add_argument("--epochs", type=int, default=4)
    parser.add_argument("--batch_size", type=int, default=None,
                        help="Defaults to 32 on GPU, 16 on CPU.")
    return parser.parse_args(argv)
# Script entry point: parse the command line, then run training.
if __name__ == "__main__":
    cli_args = parse_args()
    train(cli_args)