stefhooy
Updating the data model: IBM Debater vs. CMV — testing out CMV to check whether it is good or not
94abcb8
Raw
History Blame Contribute Delete
6.83 kB
from __future__ import annotations
import argparse
import json
import os
from collections import Counter
from pathlib import Path
from data import load_ibm, clean_debates
# Number of label classes; sizes the class-weight vector built in
# _class_weights().
NUM_LABELS = 4
def _class_weights(dataset, device):
    """Build per-class loss weights from inverse label frequency.

    Every label found in ``dataset.examples`` (read via ``ex["label"]``) is
    assigned ``total / (NUM_LABELS * count)``, so rarer labels weigh more.
    Labels that never occur keep the neutral weight 1.0. The resulting
    tensor is moved to ``device`` before being returned.
    """
    import torch

    label_counts = Counter(example["label"] for example in dataset.examples)
    n_examples = sum(label_counts.values())

    class_weights = torch.ones(NUM_LABELS)
    for label_id in label_counts:
        denominator = NUM_LABELS * label_counts[label_id]
        class_weights[label_id] = n_examples / denominator
    return class_weights.to(device)
def evaluate(model, loader, device):
    """Score ``model`` on every batch of ``loader`` without gradient tracking.

    Each batch must be a mapping with ``"input_ids"``, ``"attention_mask"``
    and ``"label"`` tensors, and the model output must expose ``.logits``.

    Returns:
        A ``(mean_loss, accuracy, macro_f1)`` tuple. The loss is the
        unweighted cross-entropy averaged per example, so a short final
        batch does not skew it. An empty loader yields ``(0.0, 0.0, 0.0)``
        instead of raising ``ZeroDivisionError``.

    Note:
        The model is left in eval mode; a caller that keeps training has to
        call ``model.train()`` again.
    """
    import torch

    model.eval()
    all_preds, all_labels = [], []
    summed_loss = 0.0
    # Sum inside each batch so the final division is by the number of
    # examples rather than the number of batches.
    loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
    with torch.no_grad():
        for batch in loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)
            logits = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).logits
            summed_loss += loss_fn(logits, labels).item()
            all_preds.extend(logits.argmax(dim=-1).cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    # Nothing was evaluated (e.g. a validation split that rounded down to
    # zero examples): report zeros, matching the zero_division=0 convention
    # used for the F1 score below.
    if not all_labels:
        return 0.0, 0.0, 0.0

    # Imported lazily: only needed once there is something to score.
    from sklearn.metrics import accuracy_score, f1_score

    return (
        summed_loss / len(all_labels),
        accuracy_score(all_labels, all_preds),
        f1_score(all_labels, all_preds, average="macro", zero_division=0),
    )
def _load_training_debates():
    """Load the cleaned IBM Debater training debates, plus CMV if available.

    CMV is strictly optional. The file-based loader is tried first; on
    ``FileNotFoundError`` the ConvoKit loader is tried instead. Any other
    failure is reported on stdout and training proceeds on IBM data only.
    """
    debates = clean_debates(load_ibm("train"))
    print(f" IBM Debater: {len(debates)} debates")
    try:
        from data import load_cmv

        cmv = clean_debates(load_cmv("train"))
        debates += cmv
        print(f" CMV (file): {len(cmv)} debates (total: {len(debates)})")
    except FileNotFoundError:
        try:
            from data import load_convokit_cmv

            print(" CMV file missing — loading ConvoKit corpus…")
            cmv = clean_debates(load_convokit_cmv("train"))
            if cmv:
                debates += cmv
                print(
                    f" CMV (ConvoKit): {len(cmv)} debates"
                    f" (total: {len(debates)})"
                )
            else:
                print(" ConvoKit returned 0 debates — IBM only")
        except Exception as e2:
            # Deliberately broad: a missing optional corpus must not abort
            # the run.
            print(
                f" CMV not available"
                f" ({e2.__class__.__name__}: {e2})"
                f" — training on IBM only"
            )
    except Exception as e:
        # Deliberately broad for the same best-effort reason as above.
        print(f" CMV load error ({e}) — training on IBM only")
    return debates


def train(
    epochs: int = 3,
    batch_size: int = 16,
    lr: float = 2e-5,
    max_length: int = 256,
    val_split: float = 0.15,
    output_dir: str = "models",
) -> str:
    """Fine-tune the classifier and keep the checkpoint with the best macro-F1.

    Args:
        epochs: Number of passes over the training split.
        batch_size: Batch size for both the training and validation loaders.
        lr: Peak AdamW learning rate (linear warmup over the first 10% of
            steps, then linear decay).
        max_length: Passed to ``ArgumentDataset`` together with the tokenizer.
        val_split: Fraction of the dataset held out for validation; the
            split is seeded, so it is the same on every run.
        output_dir: Directory (created if missing) that receives the
            ``best`` checkpoint folder and ``train_config.json``.

    Returns:
        Path of the ``best`` checkpoint directory inside ``output_dir``.
        The first evaluated epoch always writes a checkpoint, so the path
        exists whenever ``epochs >= 1``; later epochs overwrite it only when
        validation macro-F1 strictly improves.
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # --- load data first, before touching CUDA ---
    print("Loading data...")
    debates = _load_training_debates()

    # Heavy imports after data loading so torch/CUDA init
    # doesn't compete with dataset memory usage
    import torch
    from torch.utils.data import DataLoader, random_split
    from torch.optim import AdamW
    from transformers import (
        RobertaTokenizerFast,
        get_linear_schedule_with_warmup,
    )
    from tqdm import tqdm
    from .dataset import ArgumentDataset
    from .model import MODEL_NAME, build_model

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    tokenizer = RobertaTokenizerFast.from_pretrained(MODEL_NAME)
    dataset = ArgumentDataset(debates, tokenizer, max_length)
    print(f" {len(dataset)} total examples")

    # Fixed seed keeps the train/validation partition reproducible.
    val_size = int(len(dataset) * val_split)
    train_set, val_set = random_split(
        dataset,
        [len(dataset) - val_size, val_size],
        generator=torch.Generator().manual_seed(42),
    )
    train_loader = DataLoader(
        train_set, batch_size=batch_size, shuffle=True, num_workers=0,
    )
    val_loader = DataLoader(
        val_set, batch_size=batch_size, shuffle=False, num_workers=0,
    )

    model = build_model().to(device)
    # NOTE(review): the weights are derived from the full dataset, i.e. they
    # also see the validation examples.
    weights = _class_weights(dataset, device)
    loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    total_steps = len(train_loader) * epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps,
    )

    best_f1 = 0.0
    best_ckpt = os.path.join(output_dir, "best")
    # Tracks whether anything has been written to best_ckpt yet, so a run
    # whose validation F1 never rises above 0.0 still produces a checkpoint
    # for the returned path.
    checkpoint_saved = False

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            optimizer.zero_grad()
            logits = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).logits
            loss = loss_fn(logits, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()

        val_loss, val_acc, val_f1 = evaluate(model, val_loader, device)
        print(
            f"Epoch {epoch}: "
            f"train_loss={total_loss / len(train_loader):.4f} "
            f"val_loss={val_loss:.4f} "
            f"val_acc={val_acc:.4f} "
            f"val_f1={val_f1:.4f}"
        )

        if not checkpoint_saved or val_f1 > best_f1:
            best_f1 = val_f1
            model.save_pretrained(best_ckpt)
            tokenizer.save_pretrained(best_ckpt)
            checkpoint_saved = True
            print(
                f" → Saved best checkpoint "
                f"(F1={best_f1:.4f}) to {best_ckpt}/"
            )

    config_path = os.path.join(output_dir, "train_config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(
            {
                "epochs": epochs,
                "batch_size": batch_size,
                "lr": lr,
                "max_length": max_length,
                "best_val_f1": best_f1,
            },
            f,
            indent=2,
        )

    print(f"\nTraining complete. Best val F1: {best_f1:.4f}")
    return best_ckpt
if __name__ == "__main__":
    # CLI entry point: every flag mirrors a keyword argument of train(),
    # with the same default.
    p = argparse.ArgumentParser()
    p.add_argument("--epochs", type=int, default=3)
    p.add_argument("--batch-size", type=int, default=16)
    p.add_argument("--lr", type=float, default=2e-5)
    p.add_argument("--max-length", type=int, default=256)
    # Previously train()'s val_split could not be set from the command line.
    p.add_argument("--val-split", type=float, default=0.15)
    p.add_argument("--output-dir", type=str, default="models")
    args = p.parse_args()
    train(
        epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.lr,
        max_length=args.max_length,
        val_split=args.val_split,
        output_dir=args.output_dir,
    )