"""
Train DISCO model using PyTorch end-to-end training.

This script trains the CLIP-based classifier directly in PyTorch, avoiding the
sklearn intermediate step.
"""

import json
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
from sklearn.metrics import (
    average_precision_score,
    classification_report,
    roc_auc_score,
    roc_curve,
)
from torch.utils.data import DataLoader
from transformers import CLIPProcessor

from src.dataset import ImageDataset, get_dataset
from src.model import DISCO, DISCOConfig

# Metric names accepted by ``tune_threshold``.
_TUNABLE_METRICS = ("f1", "precision", "recall", "balanced_accuracy")


def _binary_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict:
    """Compute confusion counts and derived metrics for 0/1 predictions.

    Args:
        y_true: Ground-truth binary labels (0 or 1).
        y_pred: Predicted binary labels (0 or 1), same length as ``y_true``.

    Returns:
        Dict with ``precision``, ``recall``, ``f1``, ``balanced_accuracy`` and
        the integer counts ``tp``, ``fp``, ``tn``, ``fn``. Any ratio whose
        denominator is zero is reported as 0.0.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    tp = int(np.sum((y_pred == 1) & (y_true == 1)))
    fp = int(np.sum((y_pred == 1) & (y_true == 0)))
    fn = int(np.sum((y_pred == 0) & (y_true == 1)))
    tn = int(np.sum((y_pred == 0) & (y_true == 0)))

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
    # Balanced accuracy = mean of sensitivity (TPR / recall) and specificity (TNR).
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    balanced_accuracy = (recall + specificity) / 2

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "balanced_accuracy": balanced_accuracy,
        "tp": tp,
        "fp": fp,
        "tn": tn,
        "fn": fn,
    }


def tune_threshold(y_true: np.ndarray, y_scores: np.ndarray, metric: str = "f1") -> tuple[float, dict]:
    """
    Tune classification threshold on validation set.

    Candidate thresholds are the score cut-offs returned by ``roc_curve``; the
    first (highest) threshold that strictly improves the chosen metric wins.

    Args:
        y_true: Ground truth binary labels
        y_scores: Predicted probability scores
        metric: Metric to optimize ("f1", "precision", "recall", "balanced_accuracy")

    Returns:
        Best threshold and metrics at that threshold. If no threshold scores
        above 0, the threshold defaults to 0.5 and the metrics are computed
        at 0.5, so the returned dict is never empty.

    Raises:
        ValueError: If ``metric`` is not one of the supported names.
    """
    if metric not in _TUNABLE_METRICS:
        raise ValueError(f"Unknown metric {metric!r}; expected one of {_TUNABLE_METRICS}")

    y_true = np.asarray(y_true)
    y_scores = np.asarray(y_scores)
    # Thresholds come back in decreasing order.
    _, _, thresholds = roc_curve(y_true, y_scores)

    best_threshold = 0.5
    best_score = 0.0
    best_metrics: dict = {}

    for threshold in thresholds:
        # roc_curve prepends an "everything negative" threshold (np.inf in recent
        # sklearn). It is never a useful operating point and is not JSON-safe.
        if not np.isfinite(threshold):
            continue
        metrics = _binary_metrics(y_true, (y_scores >= threshold).astype(int))
        score = metrics[metric]
        if score > best_score:
            best_score = score
            best_threshold = float(threshold)
            best_metrics = {"threshold": best_threshold, **metrics}

    if not best_metrics:
        # Nothing beat a score of 0: report the default threshold's metrics
        # instead of an empty dict (callers index into it).
        fallback = _binary_metrics(y_true, (y_scores >= best_threshold).astype(int))
        best_metrics = {"threshold": best_threshold, **fallback}

    return best_threshold, best_metrics


def train_epoch(model: nn.Module, dataloader: DataLoader, criterion: nn.Module, optimizer: optim.Optimizer, device: str) -> float:
    """Train for one epoch.

    Args:
        model: Model called as ``model(pixel_values=...)`` and returning logits.
        dataloader: Yields ``(inputs, labels)`` where ``inputs["pixel_values"]``
            is the image tensor.
        criterion: Loss applied to ``(logits, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        Mean loss over batches (0.0 if the loader yields no batches).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for inputs, labels in dataloader:
        pixel_values = inputs["pixel_values"].to(device)
        labels = labels.to(device)

        # Forward pass
        optimizer.zero_grad()
        logits = model(pixel_values=pixel_values)
        loss = criterion(logits, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches > 0 else 0.0


def evaluate(model: nn.Module, dataloader: DataLoader, device: str) -> tuple[np.ndarray, np.ndarray]:
    """Evaluate model and return predictions and labels.

    Args:
        model: Model exposing ``predict_proba(pixel_values)``.
        dataloader: Yields ``(inputs, labels)`` batches.
        device: Device the pixel tensors are moved to.

    Returns:
        ``(proba, labels)``: per-batch probability arrays stacked row-wise with
        ``np.vstack``, and the concatenated labels.

    Raises:
        ValueError: If the dataloader yields no batches.
    """
    model.eval()
    all_proba = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            pixel_values = inputs["pixel_values"].to(device)
            proba = model.predict_proba(pixel_values)
            all_proba.append(proba.cpu().numpy())
            # Labels are only needed on the host; no round-trip through the device.
            all_labels.append(labels.cpu().numpy())

    if not all_proba:
        raise ValueError("Cannot evaluate: dataloader yielded no batches")

    return np.vstack(all_proba), np.concatenate(all_labels)


def _balanced_class_weights(labels: np.ndarray, num_classes: int) -> torch.Tensor:
    """Return ``n_samples / (num_classes * count_c)`` per class as a float32 tensor.

    Raises:
        ValueError: If some class has no samples, or labels fall outside
            ``range(num_classes)``.
    """
    counts = np.bincount(labels, minlength=num_classes)
    if counts.size != num_classes:
        raise ValueError(f"Labels contain values outside range({num_classes})")
    if np.any(counts == 0):
        raise ValueError(f"Cannot compute balanced class weights; class counts are {counts.tolist()}")
    return torch.tensor(len(labels) / (num_classes * counts), dtype=torch.float32)


def train(
    num_epochs: int = 10,
    batch_size: int = 32,
    learning_rate: float = 1e-3,
    weight_decay: float = 1e-4,
    class_weight: str = "balanced",
):
    """
    Train DISCO model using PyTorch.

    Only the classifier head is optimized; the CLIP backbone is frozen. The
    classifier weights from the epoch with the best validation F1 (at a 0.5
    threshold) are restored. The decision threshold is then tuned on the
    validation set, the test set is evaluated, and model, processor and
    metadata are written to ``<repo>/models``.

    Args:
        num_epochs: Number of training epochs
        batch_size: Batch size for training
        learning_rate: Learning rate for optimizer
        weight_decay: Weight decay (L2 regularization)
        class_weight: Class weighting strategy ("balanced" or None)

    Returns:
        Dict with the trained ``model``, tuned ``threshold``, ``test_roc_auc``,
        ``test_pr_auc``, ``threshold_metrics`` and ``metadata_path``.
    """
    print("=" * 60)
    print("DISCO Model Training (PyTorch)")
    print("=" * 60)

    # Setup device: prefer Apple MPS, then CUDA, then CPU.
    device = "mps" if torch.backends.mps.is_available() else (
        "cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nUsing device: {device}")

    # Load dataset splits
    print("\n[1/6] Loading dataset splits...")
    dataset = get_dataset()

    train_paths = [str(Path(img_path)) for img_path in dataset["train"]["image"]]
    val_paths = [str(Path(img_path)) for img_path in dataset["val"]["image"]]
    test_paths = [str(Path(img_path)) for img_path in dataset["test"]["image"]]

    train_labels = np.array(dataset["train"]["label"])
    val_labels = np.array(dataset["val"]["label"])
    test_labels = np.array(dataset["test"]["label"])

    print(f" Train: {len(train_paths)} images")
    print(f" Val: {len(val_paths)} images")
    print(f" Test: {len(test_paths)} images")

    # Load CLIP processor
    print("\n[2/6] Loading CLIP processor...")
    model_name = "openai/clip-vit-base-patch32"
    processor = CLIPProcessor.from_pretrained(model_name)
    print(f" Model: {model_name}")

    # Create datasets and dataloaders
    print("\n[3/6] Creating datasets and dataloaders...")
    train_dataset = ImageDataset(train_paths, train_labels, processor)
    val_dataset = ImageDataset(val_paths, val_labels, processor)
    test_dataset = ImageDataset(test_paths, test_labels, processor)

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    # Initialize model
    print("\n[4/6] Initializing model...")
    num_classes = 2
    config = DISCOConfig(
        clip_model_name=model_name,
        num_classes=num_classes,
        threshold=0.5
    )
    model = DISCO(config).to(device)

    # Only train the classifier, keep CLIP frozen. Disabling grads on the
    # backbone also stops backward() from computing CLIP gradients that the
    # classifier-only optimizer would never zero (they would accumulate forever).
    for param in model.clip_model.parameters():
        param.requires_grad_(False)

    optimizer = optim.AdamW(
        model.classifier.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay
    )

    # Setup loss function with class weights if needed
    if class_weight == "balanced":
        # Compute class weights from training data
        class_weights = _balanced_class_weights(train_labels, num_classes).to(device)
        criterion = nn.CrossEntropyLoss(weight=class_weights)
        print(f" Using balanced class weights: {class_weights.cpu().numpy()}")
    else:
        criterion = nn.CrossEntropyLoss()
        print(" Using uniform class weights")

    print(
        f" Trainable parameters: {sum(p.numel() for p in model.classifier.parameters() if p.requires_grad):,}")

    # Training loop
    print("\n[5/6] Training model...")
    best_val_f1 = 0.0
    best_classifier_state = None

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
        console=None,
    ) as progress:
        train_task = progress.add_task("Training", total=num_epochs)

        for epoch in range(num_epochs):
            # Train
            train_loss = train_epoch(
                model, train_loader, criterion, optimizer, device)

            # Validate
            val_proba, val_labels_np = evaluate(model, val_loader, device)
            val_scores = val_proba[:, 1]
            val_roc_auc = roc_auc_score(val_labels_np, val_scores)

            # Compute F1 at default threshold
            val_pred = (val_scores >= 0.5).astype(int)
            val_f1 = _binary_metrics(val_labels_np, val_pred)["f1"]

            progress.update(train_task, advance=1,
                            description=f"Epoch {epoch+1}/{num_epochs} | Loss: {train_loss:.4f} | "
                            f"Val ROC-AUC: {val_roc_auc:.4f} | Val F1: {val_f1:.4f}")

            # Save best model. state_dict() returns references to the live
            # tensors that optimizer.step() updates in place, so they must be
            # cloned; only the classifier is trained, so only it is snapshotted.
            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                best_classifier_state = {
                    name: tensor.detach().clone()
                    for name, tensor in model.classifier.state_dict().items()
                }

    # Load best model
    if best_classifier_state is not None:
        model.classifier.load_state_dict(best_classifier_state)
        print(f"\n Best validation F1: {best_val_f1:.4f}")

    # Tune threshold on validation set
    print("\n[6/6] Tuning threshold on validation set...")
    val_proba, val_labels_np = evaluate(model, val_loader, device)
    val_scores = val_proba[:, 1]
    best_threshold, threshold_metrics = tune_threshold(
        val_labels_np, val_scores, metric="f1")

    print(f" Best threshold: {best_threshold:.4f}")
    print(" Validation metrics at best threshold:")
    print(f" Precision: {threshold_metrics['precision']:.4f}")
    print(f" Recall: {threshold_metrics['recall']:.4f}")
    print(f" F1: {threshold_metrics['f1']:.4f}")
    print(
        f" Balanced Accuracy: {threshold_metrics['balanced_accuracy']:.4f}")

    # Update model threshold
    model.threshold = best_threshold
    config.threshold = float(best_threshold)

    # Evaluate on test set
    print("\n" + "=" * 60)
    print("Test Set Evaluation")
    print("=" * 60)

    test_proba, test_labels_np = evaluate(model, test_loader, device)
    test_scores = test_proba[:, 1]

    test_roc_auc = roc_auc_score(test_labels_np, test_scores)
    test_pr_auc = average_precision_score(test_labels_np, test_scores)

    print("\nTest Set Metrics (probability scores):")
    print(f" ROC AUC: {test_roc_auc:.4f}")
    print(f" PR AUC: {test_pr_auc:.4f}")

    # Apply best threshold
    test_pred = (test_scores >= best_threshold).astype(int)

    print(f"\nTest Set Metrics (with threshold={best_threshold:.4f}):")
    print(classification_report(test_labels_np, test_pred,
                                target_names=["FAMILY_SAFE/UNCERTAIN", "SUGGESTIVE"]))

    # Confusion matrix
    test_counts = _binary_metrics(test_labels_np, test_pred)
    tp, fp, tn, fn = test_counts["tp"], test_counts["fp"], test_counts["tn"], test_counts["fn"]

    print("\nConfusion Matrix:")
    print(" Predicted")
    print(" FAMILY_SAFE SUGGESTIVE")
    print(f"Actual FAMILY_SAFE {tn:4d} {fp:4d}")
    print(f" SUGGESTIVE {fn:4d} {tp:4d}")

    # Save model and metadata
    print("\n" + "=" * 60)
    print("Saving Model")
    print("=" * 60)

    models_dir = Path(__file__).parent.parent / "models"
    models_dir.mkdir(exist_ok=True)

    # Save Hugging Face model
    config.save_pretrained(models_dir)
    model.save_pretrained(models_dir)
    print(f" Saved Hugging Face model to: {models_dir}")

    # Save processor
    processor.save_pretrained(models_dir)
    print(f" Saved processor to: {models_dir}")

    # Save metadata
    metadata = {
        "model_name": model_name,
        "threshold": float(best_threshold),
        "test_roc_auc": float(test_roc_auc),
        "test_pr_auc": float(test_pr_auc),
        "val_roc_auc": float(roc_auc_score(val_labels_np, val_scores)),
        "val_pr_auc": float(average_precision_score(val_labels_np, val_scores)),
        "threshold_metrics": {
            k: float(v) if isinstance(v, (int, float, np.number)) else v
            for k, v in threshold_metrics.items()
        },
        "embedding_dim": int(model.clip_model.config.projection_dim),
        "model_type": "clip_nsfw_detector",
        "framework": "pytorch",
        "training_config": {
            "num_epochs": num_epochs,
            "batch_size": batch_size,
            "learning_rate": learning_rate,
            "weight_decay": weight_decay,
            "class_weight": class_weight
        }
    }

    metadata_path = models_dir / "model_metadata.json"
    with open(metadata_path, "w") as f:
        json.dump(metadata, f, indent=2)
    print(f" Saved metadata to: {metadata_path}")

    print("\nModel saved successfully!")
    print(f"\nModel is ready for Hugging Face upload from: {models_dir}")

    return {
        "model": model,
        "threshold": best_threshold,
        "test_roc_auc": test_roc_auc,
        "test_pr_auc": test_pr_auc,
        "threshold_metrics": threshold_metrics,
        "metadata_path": metadata_path
    }


if __name__ == "__main__":
    results = train()