Spaces:
Sleeping
Sleeping
stefhooy
Updating the Data Model: IBM Debater vs. CMV; testing CMV to check whether it's good or not
94abcb8 | from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| from collections import Counter | |
| from pathlib import Path | |
| from data import load_ibm, clean_debates | |
| NUM_LABELS = 4 | |
| def _class_weights(dataset, device): | |
| """Inverse-frequency weights so rare labels aren't ignored.""" | |
| import torch | |
| counts = Counter(ex["label"] for ex in dataset.examples) | |
| total = sum(counts.values()) | |
| weights = torch.ones(NUM_LABELS) | |
| for label_id, count in counts.items(): | |
| weights[label_id] = total / (NUM_LABELS * count) | |
| return weights.to(device) | |
def evaluate(model, loader, device):
    """Score ``model`` on every batch of ``loader`` without updating weights.

    Puts the model in eval mode and runs inference under
    ``torch.no_grad()``. The model is not switched back to train mode; the
    caller must do that.

    Args:
        model: Classifier called with ``input_ids`` and ``attention_mask``
            keyword arguments; its output must expose ``.logits``.
        loader: Iterable of batches. Each batch is a mapping with
            ``"input_ids"``, ``"attention_mask"`` and ``"label"`` tensors.
        device: Device the batch tensors are moved to.

    Returns:
        ``(mean_loss, accuracy, macro_f1)``. ``mean_loss`` is the plain
        (class-unweighted) cross-entropy averaged over examples, not over
        batches. If ``loader`` yields no examples, returns
        ``(nan, 0.0, 0.0)`` instead of dividing by zero.
    """
    import torch

    model.eval()
    all_preds, all_labels = [], []
    total_loss = 0.0
    n_examples = 0
    # Sum the per-example losses and divide by the example count at the end.
    # Averaging per-batch means would over-weight a short final batch.
    loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
    with torch.no_grad():
        for batch in loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)
            logits = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).logits
            total_loss += loss_fn(logits, labels).item()
            n_examples += labels.size(0)
            all_preds.extend(logits.argmax(dim=-1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if n_examples == 0:
        # Nothing to score, so there is no meaningful loss or metric.
        return float("nan"), 0.0, 0.0

    # Heavy dependency is imported only once there is something to score.
    from sklearn.metrics import accuracy_score, f1_score

    return (
        total_loss / n_examples,
        accuracy_score(all_labels, all_preds),
        f1_score(all_labels, all_preds, average="macro", zero_division=0),
    )
def _load_convokit_cmv():
    """Best-effort load of the ConvoKit CMV train split (used when the CMV file is missing).

    Returns:
        The cleaned debates. Returns an empty list if ConvoKit cannot be
        imported or loaded, and the (falsy) loader result if it returned
        nothing. A status line is printed in every case.
    """
    try:
        from data import load_convokit_cmv

        print(" CMV file missing — loading ConvoKit corpus…")
        cmv = clean_debates(load_convokit_cmv("train"))
    except Exception as e2:  # CMV is optional: never abort training over it
        print(
            f" CMV not available"
            f" ({e2.__class__.__name__}: {e2})"
            f" — training on IBM only"
        )
        return []
    if not cmv:
        print(" ConvoKit returned 0 debates — IBM only")
    return cmv


def _load_training_debates():
    """Load cleaned IBM Debater train debates and append CMV debates when available.

    CMV loading order:
      1. The file-based loader.
      2. ConvoKit, if that file is missing.
      3. IBM-only training, for any other failure.
    """
    print("Loading data...")
    debates = clean_debates(load_ibm("train"))
    print(f" IBM Debater: {len(debates)} debates")
    try:
        from data import load_cmv

        cmv = clean_debates(load_cmv("train"))
    except FileNotFoundError:
        cmv = _load_convokit_cmv()
        if cmv:
            debates += cmv
            print(
                f" CMV (ConvoKit): {len(cmv)} debates"
                f" (total: {len(debates)})"
            )
    except Exception as e:  # optional source; degrade to IBM-only
        print(f" CMV load error ({e}) — training on IBM only")
    else:
        debates += cmv
        print(f" CMV (file): {len(cmv)} debates (total: {len(debates)})")
    return debates


def train(
    epochs: int = 3,
    batch_size: int = 16,
    lr: float = 2e-5,
    max_length: int = 256,
    val_split: float = 0.15,
    output_dir: str = "models",
) -> str:
    """Fine-tune the argument classifier and keep the best checkpoint.

    Steps:
      - Load IBM Debater debates, plus CMV when available.
      - Hold out a seeded random ``val_split`` fraction for validation.
      - Train with class-weighted cross-entropy, AdamW and a linear warmup
        schedule.
      - After every epoch, save model and tokenizer to
        ``<output_dir>/best`` whenever validation macro-F1 improves. The
        first epoch always saves, so the returned path exists.
      - Write the run's hyperparameters to ``train_config.json``.

    Args:
        epochs: Passes over the training split; must be >= 1.
        batch_size: Examples per batch for training and validation.
        lr: Peak AdamW learning rate (reached after the warmup phase).
        max_length: Forwarded to ``ArgumentDataset`` (presumably the token
            truncation length; confirm there).
        val_split: Fraction of examples held out for validation, in (0, 1).
        output_dir: Directory for the checkpoint and config file.

    Returns:
        Path of the best checkpoint directory.

    Raises:
        ValueError: If ``epochs`` or ``val_split`` is out of range, or the
            split would leave the train or validation set empty.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")
    if not 0.0 < val_split < 1.0:
        raise ValueError(f"val_split must be in (0, 1), got {val_split}")

    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # --- load data first, before touching CUDA ---
    debates = _load_training_debates()

    # Heavy imports after data loading so torch/CUDA init
    # doesn't compete with dataset memory usage
    import torch
    from torch.optim import AdamW
    from torch.utils.data import DataLoader, random_split
    from tqdm import tqdm
    from transformers import (
        RobertaTokenizerFast,
        get_linear_schedule_with_warmup,
    )

    # NOTE(review): these relative imports only resolve when this module runs
    # inside a package (e.g. ``python -m <package>.train``). ``data`` is
    # imported as a top-level module, so confirm both hold in your setup.
    from .dataset import ArgumentDataset
    from .model import MODEL_NAME, build_model

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    tokenizer = RobertaTokenizerFast.from_pretrained(MODEL_NAME)
    dataset = ArgumentDataset(debates, tokenizer, max_length)
    print(f" {len(dataset)} total examples")

    val_size = int(len(dataset) * val_split)
    train_size = len(dataset) - val_size
    # An empty split would later divide by zero (a loader of length 0) and
    # leave nothing to select a checkpoint with.
    if val_size == 0 or train_size == 0:
        raise ValueError(
            f"val_split={val_split} gives an empty split for "
            f"{len(dataset)} examples (train={train_size}, val={val_size})"
        )
    train_set, val_set = random_split(
        dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42),  # reproducible split
    )
    train_loader = DataLoader(
        train_set, batch_size=batch_size, shuffle=True, num_workers=0,
    )
    val_loader = DataLoader(
        val_set, batch_size=batch_size, shuffle=False, num_workers=0,
    )

    model = build_model().to(device)
    # NOTE(review): class weights are computed over the full dataset,
    # validation examples included. They only encode label proportions, but
    # deriving them from the training split alone would avoid any leakage.
    weights = _class_weights(dataset, device)
    loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    total_steps = len(train_loader) * epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),  # 10% of steps warm up
        num_training_steps=total_steps,
    )

    best_f1 = 0.0
    saved_any = False
    best_ckpt = os.path.join(output_dir, "best")
    for epoch in range(1, epochs + 1):
        model.train()  # evaluate() leaves the model in eval mode
        total_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)
            optimizer.zero_grad()
            logits = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).logits
            loss = loss_fn(logits, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()

        val_loss, val_acc, val_f1 = evaluate(model, val_loader, device)
        print(
            f"Epoch {epoch}: "
            f"train_loss={total_loss / len(train_loader):.4f} "
            f"val_loss={val_loss:.4f} "
            f"val_acc={val_acc:.4f} "
            f"val_f1={val_f1:.4f}"
        )
        # Always save after the first epoch so the returned ``best_ckpt``
        # exists even if validation F1 never rises above 0.0.
        if not saved_any or val_f1 > best_f1:
            best_f1 = val_f1
            saved_any = True
            model.save_pretrained(best_ckpt)
            tokenizer.save_pretrained(best_ckpt)
            print(
                f" → Saved best checkpoint "
                f"(F1={best_f1:.4f}) to {best_ckpt}/"
            )

    config_path = os.path.join(output_dir, "train_config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(
            {
                "epochs": epochs,
                "batch_size": batch_size,
                "lr": lr,
                "max_length": max_length,
                "val_split": val_split,
                "best_val_f1": best_f1,
            },
            f,
            indent=2,
        )
    print(f"\nTraining complete. Best val F1: {best_f1:.4f}")
    return best_ckpt
| if __name__ == "__main__": | |
| p = argparse.ArgumentParser() | |
| p.add_argument("--epochs", type=int, default=3) | |
| p.add_argument("--batch-size", type=int, default=16) | |
| p.add_argument("--lr", type=float, default=2e-5) | |
| p.add_argument("--max-length", type=int, default=256) | |
| p.add_argument("--output-dir", type=str, default="models") | |
| args = p.parse_args() | |
| train( | |
| epochs=args.epochs, | |
| batch_size=args.batch_size, | |
| lr=args.lr, | |
| max_length=args.max_length, | |
| output_dir=args.output_dir, | |
| ) | |