| """ |
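Training script for the text classifier model.

Trains a DistilBERT-based classifier on the preprocessed Suicide-Watch dataset.

Usage:
    python train_text_model.py [options]

Options:
    --epochs: Number of training epochs (default: 3)
    --batch-size: Batch size (default: 32)
    --lr: Learning rate (default: 2e-5)
    --max-length: Max token length (default: 256)
    --model-name: Base model name (default: distilbert-base-uncased)
    --data-dir: Directory containing train.csv, val.csv and test.csv
        (default: data/suicide_watch/processed under the parent of this script's directory)
    --output-dir: Directory for the model checkpoint and results JSON
        (default: models/ under the parent of this script's directory)
    --subset: Use only N training samples per class for fast iteration; the validation
        and test splits are reduced to roughly N/4 samples per class (default: 0 = all)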
| """ |
|
|
| import argparse |
| import json |
| import logging |
| from pathlib import Path |
|
|
| import pandas as pd |
| import torch |
| import torch.nn as nn |
| from sklearn.metrics import ( |
| accuracy_score, |
| classification_report, |
| confusion_matrix, |
| precision_recall_fscore_support, |
| roc_auc_score, |
| ) |
| from torch.optim import AdamW |
| from torch.utils.data import DataLoader, Dataset |
| from tqdm import tqdm |
| from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup |
|
|
# Module-level logger for this script. basicConfig sets the root logger to INFO
# so progress messages are shown when the script runs; it has no effect if the
# root logger already has handlers configured.
logger = logging.getLogger(__name__)
logging.basicConfig(level="INFO")
|
|
|
|
class TextClassifier(nn.Module):
    """Pretrained transformer encoder with a linear classification head.

    The encoder is loaded with ``AutoModel.from_pretrained(model_name)``. The
    hidden state at sequence position 0 of ``last_hidden_state`` goes through
    dropout and then a linear layer that produces ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
        model_name: Hugging Face model id or local path for the encoder.
        dropout: Dropout probability applied before the linear head.
    """

    def __init__(
        self,
        num_classes: int = 2,
        model_name: str = "distilbert-base-uncased",
        dropout: float = 0.3,
    ):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        # Size the head from the encoder's config instead of hard-coding 768, so
        # encoders of other widths (e.g. 1024 for *-large models) work too.
        hidden_size = self.encoder.config.hidden_size
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return unnormalized class logits.

        Args:
            input_ids: Token ids; presumably (batch, seq_len) as built by
                ``collate_fn`` — TODO confirm for other callers.
            attention_mask: Same shape as ``input_ids``; 0 marks padding.

        Returns:
            Logits with a last dimension of size ``num_classes``.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Pool by taking the representation at position 0 (the first token).
        pooled = outputs.last_hidden_state[:, 0]
        return self.classifier(self.dropout(pooled))
|
|
|
|
class TextDataset(Dataset):
    """Map-style dataset that tokenizes one text per item on demand.

    Each item is a dict with unpadded 1-D ``input_ids`` and ``attention_mask``
    tensors (batch padding is left to the collate function) and a 0-d
    ``torch.long`` ``label``.

    Args:
        texts: Sequence of input strings.
        labels: Integer class ids aligned with ``texts``.
        tokenizer: Callable with the Hugging Face tokenizer call signature. It
            must accept ``truncation``, ``max_length`` and ``return_tensors="pt"``
            and return tensors under ``"input_ids"`` and ``"attention_mask"``.
        max_length: Maximum number of tokens kept per text.

    Raises:
        ValueError: If ``texts`` and ``labels`` differ in length.
    """

    def __init__(self, texts: list, labels: list, tokenizer, max_length: int = 256):
        if len(texts) != len(labels):
            raise ValueError(f"texts and labels differ in length: {len(texts)} != {len(labels)}")
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            self.texts[idx], truncation=True, max_length=self.max_length, return_tensors="pt"
        )
        # squeeze(0) removes only the leading batch dimension added by
        # return_tensors="pt". A bare squeeze() would also collapse a length-1
        # sequence into a 0-d tensor, which pad_sequence cannot pad.
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "label": torch.tensor(self.labels[idx], dtype=torch.long),
        }
|
|
|
|
def collate_fn(batch):
    """Collate dataset items into a batch padded to its own longest sequence.

    This is dynamic padding: sequences are right-padded with 0 up to the
    longest sequence in this batch, not up to ``max_length``. Padded positions
    get an attention-mask value of 0.

    Args:
        batch: List of dicts holding 1-D ``input_ids`` / ``attention_mask``
            tensors and a scalar ``label`` tensor.

    Returns:
        Dict with padded ``input_ids`` and ``attention_mask`` of shape
        (batch, longest_len), and ``label`` stacked along a new first dim.
    """
    pad = torch.nn.utils.rnn.pad_sequence
    targets = torch.stack([sample["label"] for sample in batch])
    # NOTE(review): 0 matches the [PAD] id of distilbert-base-uncased; other
    # tokenizers may use a different pad id (those positions are masked anyway).
    padded_ids = pad([sample["input_ids"] for sample in batch], batch_first=True, padding_value=0)
    padded_masks = pad([sample["attention_mask"] for sample in batch], batch_first=True, padding_value=0)
    return {"input_ids": padded_ids, "attention_mask": padded_masks, "label": targets}
|
|
|
|
def train_epoch(model, dataloader, optimizer, scheduler, criterion, device):
    """Train ``model`` for one pass over ``dataloader``.

    For each batch: zero gradients, run forward and backward, clip the gradient
    norm to 1.0, then step the optimizer and the scheduler.

    The returned loss is a per-sample average: each batch's loss is weighted by
    its size, so a smaller final batch is not over-weighted. This assumes
    ``criterion`` returns the mean loss over the batch, as the default
    ``nn.CrossEntropyLoss()`` does.

    Args:
        model: Classifier called as ``model(input_ids, attention_mask)``.
        dataloader: Yields dicts with ``input_ids``, ``attention_mask`` and ``label``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Learning-rate scheduler stepped once per batch.
        criterion: Loss function applied to ``(logits, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(avg_loss, accuracy)``. Accuracy is computed from predictions made in
        train mode (dropout active).

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    num_samples = 0
    all_preds = []
    all_labels = []

    progress_bar = tqdm(dataloader, desc="Training")
    for batch in progress_bar:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        optimizer.zero_grad()
        logits = model(input_ids, attention_mask)
        loss = criterion(logits, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        # .item() forces a device sync, so read the loss once per batch.
        batch_loss = loss.item()
        batch_size = labels.size(0)
        total_loss += batch_loss * batch_size
        num_samples += batch_size

        preds = torch.argmax(logits, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

        progress_bar.set_postfix({"loss": batch_loss})

    if num_samples == 0:
        raise ValueError("train_epoch received a dataloader with no samples")

    avg_loss = total_loss / num_samples
    accuracy = accuracy_score(all_labels, all_preds)
    return avg_loss, accuracy
|
|
|
|
def evaluate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` in eval mode without gradients.

    Metrics treat class 1 as the positive class. Precision, recall and F1 use
    ``average="binary"``, and ROC-AUC is computed from the softmax probability
    of class 1. The loss is averaged per sample (batch losses are weighted by
    batch size), which assumes ``criterion`` returns the batch mean.

    Args:
        model: Classifier called as ``model(input_ids, attention_mask)``.
        dataloader: Yields dicts with ``input_ids``, ``attention_mask`` and ``label``.
        criterion: Loss function applied to ``(logits, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(metrics, preds, labels, probs)``:
        - ``metrics`` is a dict with ``loss``, ``accuracy``, ``precision``,
          ``recall``, ``f1`` and ``roc_auc``; ``roc_auc`` is 0.0 when it is undefined.
        - ``preds`` and ``labels`` are per-sample numpy integer scalars.
        - ``probs`` holds the per-sample probability of class 1.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    num_samples = 0
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            logits = model(input_ids, attention_mask)
            loss = criterion(logits, labels)

            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            num_samples += batch_size

            probs = torch.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs[:, 1].cpu().numpy())

    if num_samples == 0:
        raise ValueError("evaluate received a dataloader with no samples")

    avg_loss = total_loss / num_samples
    accuracy = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average="binary")

    try:
        roc_auc = roc_auc_score(all_labels, all_probs)
    except ValueError as err:
        # roc_auc_score raises ValueError e.g. when only one class is present in
        # the labels. Keep the numeric 0.0 placeholder (callers format and
        # serialize it as a float), but log it so the fallback is not silent.
        logger.warning("ROC-AUC is undefined, reporting 0.0: %s", err)
        roc_auc = 0.0

    metrics = {
        "loss": avg_loss,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "roc_auc": roc_auc,
    }
    return metrics, all_preds, all_labels, all_probs
|
|
|
|
def _parse_args() -> argparse.Namespace:
    """Parse and validate the command-line options."""
    parser = argparse.ArgumentParser(description="Train text classifier")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--model-name", type=str, default="distilbert-base-uncased")
    parser.add_argument("--max-length", type=int, default=256)
    parser.add_argument("--data-dir", type=str, default=None)
    parser.add_argument("--output-dir", type=str, default=None)
    parser.add_argument(
        "--subset", type=int, default=0, help="Use N samples per class for fast iteration (0 = all data)"
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for subsetting and torch (default: 42)"
    )
    args = parser.parse_args()
    # The test phase reloads the checkpoint written during training, so at
    # least one epoch must run.
    if args.epochs < 1:
        parser.error("--epochs must be at least 1")
    return args


def _select_device() -> torch.device:
    """Pick Apple MPS if available, then CUDA, otherwise CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def _stratified_subset(df: pd.DataFrame, n_per_class: int, seed: int) -> pd.DataFrame:
    """Sample up to ``n_per_class`` rows from each ``label_id`` group of ``df``."""
    if df.empty:
        return df.reset_index(drop=True)
    parts = [
        group.sample(n=min(n_per_class, len(group)), random_state=seed)
        for _, group in df.groupby("label_id")
    ]
    return pd.concat(parts).reset_index(drop=True)


def _make_dataset(df: pd.DataFrame, tokenizer, max_length: int) -> "TextDataset":
    """Wrap a split's ``clean_text`` / ``label_id`` columns in a TextDataset."""
    # pd.read_csv parses empty fields as NaN (a float), but the tokenizer
    # expects strings, so map missing text to "".
    texts = df["clean_text"].fillna("").astype(str).tolist()
    return TextDataset(texts, df["label_id"].tolist(), tokenizer, max_length)


def main():
    """Train the classifier, keep the best epoch by validation F1, and evaluate on test.

    Reads ``train.csv``, ``val.csv`` and ``test.csv`` (using the ``clean_text``
    and ``label_id`` columns) from the data directory and fine-tunes
    ``TextClassifier``. The state_dict of the epoch with the highest validation
    F1 is saved to ``<output_dir>/text_classifier.pt`` and reloaded to evaluate
    the test split. A classification report and confusion matrix are printed,
    and hyperparameters, metrics and per-epoch history are written to
    ``<output_dir>/training_results.json``.
    """
    args = _parse_args()
    torch.manual_seed(args.seed)

    base_dir = Path(__file__).parent.parent
    data_dir = Path(args.data_dir) if args.data_dir else base_dir / "data" / "suicide_watch" / "processed"
    output_dir = Path(args.output_dir) if args.output_dir else base_dir / "models"
    output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = output_dir / "text_classifier.pt"
    results_path = output_dir / "training_results.json"

    device = _select_device()
    logger.info(f"Using device: {device}")

    logger.info("Loading data...")
    train_df = pd.read_csv(data_dir / "train.csv")
    val_df = pd.read_csv(data_dir / "val.csv")
    test_df = pd.read_csv(data_dir / "test.csv")

    if args.subset > 0:
        logger.info(f"Subsetting to {args.subset} samples per class...")
        # Evaluation splits get a quarter as many samples per class, but never
        # zero: an empty val/test split leaves evaluate() with no batches.
        eval_per_class = max(1, args.subset // 4)
        train_df = _stratified_subset(train_df, args.subset, args.seed)
        val_df = _stratified_subset(val_df, eval_per_class, args.seed)
        test_df = _stratified_subset(test_df, eval_per_class, args.seed)

    logger.info(f"Train: {len(train_df)}, Val: {len(val_df)}, Test: {len(test_df)}")

    logger.info(f"Loading tokenizer: {args.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    train_dataset = _make_dataset(train_df, tokenizer, args.max_length)
    val_dataset = _make_dataset(val_df, tokenizer, args.max_length)
    test_dataset = _make_dataset(test_df, tokenizer, args.max_length)

    # Worker processes are disabled on MPS; other devices use 2 workers.
    num_workers = 0 if device.type == "mps" else 2
    loader_kwargs = {
        "batch_size": args.batch_size,
        "collate_fn": collate_fn,
        "num_workers": num_workers,
        "pin_memory": False,
    }
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, **loader_kwargs)
    test_loader = DataLoader(test_dataset, **loader_kwargs)

    logger.info("Creating model...")
    num_classes = len(train_df["label_id"].unique())
    model = TextClassifier(num_classes=num_classes, model_name=args.model_name)
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr=args.lr)

    total_steps = len(train_loader) * args.epochs
    # Linear warmup over the first 10% of steps, then linear decay to zero.
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=total_steps // 10, num_training_steps=total_steps
    )

    logger.info("Starting training...")
    best_val_f1 = 0.0
    best_epoch = None
    training_history = []

    for epoch in range(1, args.epochs + 1):
        logger.info(f"\nEpoch {epoch}/{args.epochs}")

        train_loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, criterion, device)
        logger.info(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")

        val_metrics, _, _, _ = evaluate(model, val_loader, criterion, device)
        logger.info(f"Val Loss: {val_metrics['loss']:.4f}, Val F1: {val_metrics['f1']:.4f}")

        training_history.append(
            {
                "epoch": epoch,
                "train_loss": train_loss,
                "train_acc": train_acc,
                "val_loss": val_metrics["loss"],
                "val_f1": val_metrics["f1"],
                "val_roc_auc": val_metrics["roc_auc"],
            }
        )

        # Always checkpoint the first epoch. With only a strict "> 0" test, a
        # run whose val F1 never rises above 0 would save nothing, and the test
        # phase would load a missing or stale checkpoint.
        if best_epoch is None or val_metrics["f1"] > best_val_f1:
            best_val_f1 = val_metrics["f1"]
            best_epoch = epoch
            torch.save(model.state_dict(), checkpoint_path)
            logger.info(f"Saved best model with F1: {best_val_f1:.4f}")

    logger.info("\nEvaluating on test set...")
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    test_metrics, test_preds, test_labels, _ = evaluate(model, test_loader, criterion, device)

    logger.info("\nTest Results:")
    logger.info(f"  Accuracy:  {test_metrics['accuracy']:.4f}")
    logger.info(f"  Precision: {test_metrics['precision']:.4f}")
    logger.info(f"  Recall:    {test_metrics['recall']:.4f}")
    logger.info(f"  F1 Score:  {test_metrics['f1']:.4f}")
    logger.info(f"  ROC-AUC:   {test_metrics['roc_auc']:.4f}")

    # Pin labels so the report and matrix always cover both classes. Otherwise,
    # when only one class appears in labels and predictions, classification_report
    # rejects the two target_names and the confusion matrix shrinks to 1x1.
    print("\nClassification Report:")
    print(
        classification_report(
            test_labels, test_preds, labels=[0, 1], target_names=["low_risk", "high_risk"]
        )
    )

    print("\nConfusion Matrix:")
    print(confusion_matrix(test_labels, test_preds, labels=[0, 1]))

    results = {
        "model_name": args.model_name,
        "epochs": args.epochs,
        "batch_size": args.batch_size,
        "learning_rate": args.lr,
        "seed": args.seed,
        "best_epoch": best_epoch,
        "best_val_f1": best_val_f1,
        "test_metrics": test_metrics,
        "training_history": training_history,
        "label_map": {"low_risk": 0, "high_risk": 1},
    }

    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    logger.info(f"\nModel saved to: {checkpoint_path}")
    logger.info(f"Results saved to: {results_path}")


# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|