# Source path: depscreen/ml/scripts/deprecated/train_text_model.py
# Last commit (halsabbah, 3187428): Add CI/CD pipelines and code quality tooling
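"""
Training script for the text classifier model.

Trains a DistilBERT-based classifier on the preprocessed Suicide-Watch dataset.

Usage:
    python train_text_model.py [options]

Options:
    --epochs: Number of training epochs (default: 3)
    --batch-size: Batch size (default: 32)
    --lr: Learning rate (default: 2e-5)
    --max-length: Max token length (default: 256)
    --model-name: Base model name (default: distilbert-base-uncased)
    --data-dir: Directory containing train.csv, val.csv and test.csv
        (default: data/suicide_watch/processed under the script's grandparent directory)
    --output-dir: Directory for text_classifier.pt and training_results.json
        (default: models under the script's grandparent directory)
    --subset: Use only N samples per class for fast iteration (default: 0 = all)
"""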
import argparse
import json
import logging
from pathlib import Path
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import (
accuracy_score,
classification_report,
confusion_matrix,
precision_recall_fscore_support,
roc_auc_score,
)
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup
# Configure the root logger at INFO level; this runs as soon as the module is loaded.
logging.basicConfig(level=logging.INFO)
# Module-scoped logger used by the functions in this script.
logger = logging.getLogger(__name__)
class TextClassifier(nn.Module):
    """Pretrained transformer encoder with a linear classification head.

    The embedding at sequence position 0 of the encoder's last hidden state
    (the ``[CLS]`` token for BERT-style models) is passed through dropout and
    one linear layer to produce class logits.

    Args:
        num_classes: Number of output classes (width of the logit vector).
        model_name: Model identifier handed to ``AutoModel.from_pretrained``.
    """

    def __init__(self, num_classes: int = 2, model_name: str = "distilbert-base-uncased"):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        # Size the head from the encoder's config instead of assuming 768, so
        # that --model-name may point at an encoder of a different width.
        # 768 (the previously hard-coded value) remains the fallback when the
        # config does not expose ``hidden_size``.
        hidden_size = getattr(self.encoder.config, "hidden_size", 768)
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return unnormalised class logits for a batch.

        Args:
            input_ids: Token ids, forwarded to the encoder unchanged.
            attention_mask: Padding mask, forwarded to the encoder unchanged.

        Returns:
            Logits with ``num_classes`` entries per batch row.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 along the sequence axis is used as the pooled representation.
        pooled = outputs.last_hidden_state[:, 0]
        return self.classifier(self.dropout(pooled))
class TextDataset(Dataset):
    """Map-style dataset that tokenizes one text per item on access.

    Args:
        texts: Input strings.
        labels: Integer class ids, aligned with ``texts`` by position.
        tokenizer: Callable invoked as
            ``tokenizer(text, truncation=True, max_length=..., return_tensors="pt")``
            that returns a mapping holding ``"input_ids"`` and
            ``"attention_mask"`` tensors.
        max_length: Truncation length forwarded to the tokenizer.
    """

    def __init__(self, texts: list, labels: list, tokenizer, max_length: int = 256):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of texts."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize the ``idx``-th text and return unpadded tensors plus its label.

        Returns:
            Dict with ``"input_ids"`` and ``"attention_mask"`` (leading batch
            axis removed) and ``"label"`` as a ``torch.long`` scalar tensor.
        """
        encoding = self.tokenizer(
            self.texts[idx],
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # squeeze(0) removes only the leading batch axis. A bare squeeze()
        # would also collapse a length-1 sequence into a 0-d tensor, which
        # cannot be padded together with other sequences later on.
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "label": torch.tensor(self.labels[idx], dtype=torch.long),
        }
def collate_fn(batch, pad_token_id: int = 0):
    """Collate dataset items with dynamic padding.

    Sequences are padded to the longest sequence in the batch rather than to
    a fixed maximum length.

    Args:
        batch: List of dicts holding 1-D ``"input_ids"`` and
            ``"attention_mask"`` tensors and a scalar ``"label"`` tensor.
        pad_token_id: Value used to pad ``input_ids``. Defaults to 0, the value
            that was previously hard-coded; pass the tokenizer's pad id (for
            example through ``functools.partial``) when it differs.

    Returns:
        Dict with batch-first padded ``"input_ids"`` and ``"attention_mask"``
        tensors and the stacked ``"label"`` tensor.
    """
    pad = torch.nn.utils.rnn.pad_sequence
    input_ids = pad(
        [item["input_ids"] for item in batch],
        batch_first=True,
        padding_value=pad_token_id,
    )
    # The mask is always padded with 0 so padded positions are marked as ignored.
    attention_masks = pad(
        [item["attention_mask"] for item in batch],
        batch_first=True,
        padding_value=0,
    )
    labels = torch.stack([item["label"] for item in batch])
    return {"input_ids": input_ids, "attention_mask": attention_masks, "label": labels}
def train_epoch(model, dataloader, optimizer, scheduler, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    For every batch: forward pass, loss, backward pass, gradient-norm clipping
    at 1.0, then one optimizer step followed by one scheduler step.

    Returns:
        Tuple ``(mean loss per batch, accuracy over all seen samples)``.
    """
    model.train()
    running_loss = 0
    predictions, targets = [], []

    batches = tqdm(dataloader, desc="Training")
    for batch in batches:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        optimizer.zero_grad()
        logits = model(input_ids, attention_mask)
        loss = criterion(logits, labels)
        loss.backward()
        # Clip before stepping so a single batch cannot produce an oversized update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        predictions.extend(logits.argmax(dim=1).cpu().numpy())
        targets.extend(labels.cpu().numpy())
        batches.set_postfix({"loss": batch_loss})

    return running_loss / len(dataloader), accuracy_score(targets, predictions)
def evaluate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without computing gradients.

    Args:
        model: Callable as ``model(input_ids, attention_mask)`` returning logits.
        dataloader: Yields dicts with ``"input_ids"``, ``"attention_mask"`` and
            ``"label"`` tensors.
        criterion: Loss function applied to ``(logits, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(metrics, all_preds, all_labels, all_probs)``. ``metrics`` maps
        ``loss``, ``accuracy``, ``precision``, ``recall``, ``f1`` and
        ``roc_auc`` to plain Python floats. ``all_probs`` holds the softmax
        probability of class index 1 for every sample.
    """
    model.eval()
    total_loss = 0
    all_preds, all_labels, all_probs = [], [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            logits = model(input_ids, attention_mask)
            total_loss += criterion(logits, labels).item()

            probs = torch.softmax(logits, dim=1)
            all_preds.extend(torch.argmax(probs, dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            # Column 1 is treated as the positive class (binary setup).
            all_probs.extend(probs[:, 1].cpu().numpy())

    avg_loss = total_loss / len(dataloader)

    accuracy = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average="binary")
    try:
        roc_auc = roc_auc_score(all_labels, all_probs)
    except ValueError as err:
        # Keep the historical 0.0 fallback, but say so: an unexplained 0.0
        # looks like a real (terrible) score in the saved results.
        logger.warning("ROC-AUC could not be computed (%s); reporting 0.0", err)
        roc_auc = 0.0

    # Normalise to built-in floats so the dict holds no NumPy scalar types.
    metrics = {
        "loss": float(avg_loss),
        "accuracy": float(accuracy),
        "precision": float(precision),
        "recall": float(recall),
        "f1": float(f1),
        "roc_auc": float(roc_auc),
    }
    return metrics, all_preds, all_labels, all_probs
def _parse_args():
    """Define the command-line options and parse them from ``sys.argv``."""
    parser = argparse.ArgumentParser(description="Train text classifier")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--model-name", type=str, default="distilbert-base-uncased")
    parser.add_argument("--max-length", type=int, default=256)
    parser.add_argument("--data-dir", type=str, default=None)
    parser.add_argument("--output-dir", type=str, default=None)
    parser.add_argument(
        "--subset", type=int, default=0, help="Use N samples per class for fast iteration (0 = all data)"
    )
    return parser.parse_args()


def _select_device():
    """Pick the compute device: MPS (Apple Silicon GPU) first, then CUDA, then CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def _load_split(path):
    """Read one split CSV and drop rows whose ``clean_text`` is missing."""
    df = pd.read_csv(path)
    # pandas parses empty CSV fields as NaN; such rows are not strings and
    # cannot be tokenized, so they are removed up front.
    missing = df["clean_text"].isna()
    if missing.any():
        logger.warning("Dropping %d rows without clean_text from %s", int(missing.sum()), path)
        df = df[~missing].reset_index(drop=True)
    return df


def _sample_per_class(df, n):
    """Return at most ``n`` rows per ``label_id`` value, sampled with a fixed seed."""
    parts = [group.sample(n=min(n, len(group)), random_state=42) for _, group in df.groupby("label_id")]
    return pd.concat(parts).reset_index(drop=True)


def main():
    """Train, validate and test the text classifier from the command line.

    Steps: parse the CLI options, load the train/val/test CSV splits,
    optionally subsample them per class, fine-tune for ``--epochs`` epochs
    while checkpointing the weights with the best validation F1, reload that
    checkpoint, report test metrics and write ``training_results.json``.
    """
    args = _parse_args()

    # Setup paths
    # NOTE(review): the defaults resolve against the grandparent directory of
    # this file — confirm that still matches where data/ and models/ live.
    base_dir = Path(__file__).parent.parent
    data_dir = Path(args.data_dir) if args.data_dir else base_dir / "data" / "suicide_watch" / "processed"
    output_dir = Path(args.output_dir) if args.output_dir else base_dir / "models"
    output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = output_dir / "text_classifier.pt"
    results_path = output_dir / "training_results.json"

    device = _select_device()
    logger.info(f"Using device: {device}")

    # Load data
    logger.info("Loading data...")
    train_df = _load_split(data_dir / "train.csv")
    val_df = _load_split(data_dir / "val.csv")
    test_df = _load_split(data_dir / "test.csv")

    # Subset sampling for fast iteration
    if args.subset > 0:
        logger.info(f"Subsetting to {args.subset} samples per class...")
        # subset // 4 is 0 for --subset below 4, which would leave the
        # validation/test loaders empty; keep at least one row per class.
        eval_subset = max(1, args.subset // 4)
        train_df = _sample_per_class(train_df, args.subset)
        val_df = _sample_per_class(val_df, eval_subset)
        test_df = _sample_per_class(test_df, eval_subset)
    logger.info(f"Train: {len(train_df)}, Val: {len(val_df)}, Test: {len(test_df)}")

    # Load tokenizer
    logger.info(f"Loading tokenizer: {args.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # Create datasets
    train_dataset = TextDataset(
        train_df["clean_text"].tolist(), train_df["label_id"].tolist(), tokenizer, args.max_length
    )
    val_dataset = TextDataset(val_df["clean_text"].tolist(), val_df["label_id"].tolist(), tokenizer, args.max_length)
    test_dataset = TextDataset(test_df["clean_text"].tolist(), test_df["label_id"].tolist(), tokenizer, args.max_length)

    # Create dataloaders with dynamic padding; worker processes are disabled on MPS.
    num_workers = 0 if device.type == "mps" else 2
    loader_kwargs = {
        "batch_size": args.batch_size,
        "collate_fn": collate_fn,
        "num_workers": num_workers,
        "pin_memory": False,
    }
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, **loader_kwargs)
    test_loader = DataLoader(test_dataset, **loader_kwargs)

    # Create model
    logger.info("Creating model...")
    num_classes = len(train_df["label_id"].unique())
    model = TextClassifier(num_classes=num_classes, model_name=args.model_name)
    model.to(device)

    # Setup training: linear warm-up over the first 10% of steps, then linear decay.
    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr=args.lr)
    total_steps = len(train_loader) * args.epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=total_steps // 10, num_training_steps=total_steps
    )

    # Training loop
    logger.info("Starting training...")
    # None means "nothing saved yet". Starting at 0 with a strict ">" never
    # wrote a checkpoint when validation F1 stayed at 0, so the reload below
    # either failed or silently picked up a stale file from an earlier run.
    best_val_f1 = None
    training_history = []
    for epoch in range(args.epochs):
        logger.info(f"\nEpoch {epoch + 1}/{args.epochs}")

        # Train
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, criterion, device)
        logger.info(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")

        # Validate
        val_metrics, _, _, _ = evaluate(model, val_loader, criterion, device)
        logger.info(f"Val Loss: {val_metrics['loss']:.4f}, Val F1: {val_metrics['f1']:.4f}")

        training_history.append(
            {
                "epoch": epoch + 1,
                "train_loss": train_loss,
                "train_acc": train_acc,
                "val_loss": val_metrics["loss"],
                "val_f1": val_metrics["f1"],
                "val_roc_auc": val_metrics["roc_auc"],
            }
        )

        # Save best model
        if best_val_f1 is None or val_metrics["f1"] > best_val_f1:
            best_val_f1 = val_metrics["f1"]
            torch.save(model.state_dict(), checkpoint_path)
            logger.info(f"Saved best model with F1: {best_val_f1:.4f}")

    # Final evaluation on test set
    logger.info("\nEvaluating on test set...")
    if best_val_f1 is None:
        # No epoch ran (e.g. --epochs 0): persist the current weights so the
        # checkpoint on disk matches the model that is evaluated below.
        logger.warning("No training epoch ran; saving and evaluating the current weights")
        torch.save(model.state_dict(), checkpoint_path)
        best_val_f1 = 0
    else:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    test_metrics, test_preds, test_labels, test_probs = evaluate(model, test_loader, criterion, device)

    logger.info("\nTest Results:")
    logger.info(f" Accuracy: {test_metrics['accuracy']:.4f}")
    logger.info(f" Precision: {test_metrics['precision']:.4f}")
    logger.info(f" Recall: {test_metrics['recall']:.4f}")
    logger.info(f" F1 Score: {test_metrics['f1']:.4f}")
    logger.info(f" ROC-AUC: {test_metrics['roc_auc']:.4f}")

    # Print classification report
    print("\nClassification Report:")
    print(classification_report(test_labels, test_preds, target_names=["low_risk", "high_risk"]))

    # Print confusion matrix
    print("\nConfusion Matrix:")
    print(confusion_matrix(test_labels, test_preds))

    # Save training results
    results = {
        "model_name": args.model_name,
        "epochs": args.epochs,
        "batch_size": args.batch_size,
        "learning_rate": args.lr,
        "best_val_f1": best_val_f1,
        "test_metrics": test_metrics,
        "training_history": training_history,
        "label_map": {"low_risk": 0, "high_risk": 1},
    }
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    logger.info(f"\nModel saved to: {checkpoint_path}")
    logger.info(f"Results saved to: {results_path}")
# Run the training pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()