"""
Train DISCO model using PyTorch end-to-end training.

This script trains the CLIP-based classifier directly in PyTorch,
avoiding the sklearn intermediate step.
"""

import json
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import (
    roc_auc_score, average_precision_score, roc_curve, classification_report
)
from transformers import CLIPProcessor
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeElapsedColumn

from src.dataset import get_dataset, ImageDataset
from src.model import DISCO, DISCOConfig


def tune_threshold(y_true: np.ndarray, y_scores: np.ndarray, metric: str = "f1") -> tuple[float, dict]:
    """
    Tune the classification threshold on the validation set.

    Args:
        y_true: Ground-truth binary labels
        y_scores: Predicted probability scores
        metric: Metric to optimize ("f1", "precision", "recall", "balanced_accuracy")

    Returns:
        The best threshold and the metrics achieved at it
    """
    # roc_curve yields one candidate threshold per distinct score. Its first
    # entry is a sentinel (np.inf in recent scikit-learn versions) that
    # predicts all negatives; it scores 0 on every metric here, so it is
    # never selected.
    _, _, thresholds = roc_curve(y_true, y_scores)

    best_threshold = 0.5
    best_score = 0.0
    best_metrics = {}

    for threshold in thresholds:
        y_pred = (y_scores >= threshold).astype(int)

        tp = np.sum((y_pred == 1) & (y_true == 1))
        fp = np.sum((y_pred == 1) & (y_true == 0))
        fn = np.sum((y_pred == 0) & (y_true == 1))
        tn = np.sum((y_pred == 0) & (y_true == 0))

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        # Balanced accuracy is the mean of sensitivity (recall) and
        # specificity, computed from this threshold's own confusion matrix.
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
        balanced_accuracy = (recall + specificity) / 2

        score_map = {
            "f1": f1,
            "precision": precision,
            "recall": recall,
            "balanced_accuracy": balanced_accuracy,
        }

        score = score_map.get(metric, f1)

        if score > best_score:
            best_score = score
            best_threshold = threshold
            best_metrics = {
                "threshold": float(threshold),
                "precision": precision,
                "recall": recall,
                "f1": f1,
                "balanced_accuracy": balanced_accuracy,
                "tp": int(tp),
                "fp": int(fp),
                "tn": int(tn),
                "fn": int(fn),
            }

    return best_threshold, best_metrics
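

# Illustrative only: with y_true = [0, 0, 1, 1] and
# y_scores = [0.1, 0.4, 0.35, 0.8], tune_threshold(y_true, y_scores, "f1")
# selects 0.35, where predictions [0, 1, 1, 1] give F1 = 0.8.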


def train_epoch(model: nn.Module, dataloader: DataLoader, criterion: nn.Module,
                optimizer: optim.Optimizer, device: str) -> float:
    """Train for one epoch."""
    model.train()
    total_loss = 0.0
    num_batches = 0

    for inputs, labels in dataloader:
        pixel_values = inputs["pixel_values"].to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(pixel_values=pixel_values)
        loss = criterion(logits, labels)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches > 0 else 0.0
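
# Note: train_epoch backpropagates through every parameter that requires
# gradients, but only the parameters handed to its optimizer are updated.
# train() below passes just the classifier head, so the CLIP backbone's
# weights stay fixed.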


def evaluate(model: nn.Module, dataloader: DataLoader, device: str) -> tuple[np.ndarray, np.ndarray]:
    """Evaluate the model and return predicted probabilities and true labels."""
    model.eval()
    all_proba = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            pixel_values = inputs["pixel_values"].to(device)

            proba = model.predict_proba(pixel_values)
            all_proba.append(proba.cpu().numpy())
            # Labels are only consumed by numpy metrics, so keep them on CPU.
            all_labels.append(labels.numpy())

    proba = np.vstack(all_proba)
    labels = np.concatenate(all_labels)

    return proba, labels
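
# evaluate() assumes DISCO.predict_proba returns a (batch, 2) tensor of class
# probabilities (see src.model); column 1 is the positive SUGGESTIVE class,
# which is why callers below score with proba[:, 1].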


def train(
    num_epochs: int = 10,
    batch_size: int = 32,
    learning_rate: float = 1e-3,
    weight_decay: float = 1e-4,
    class_weight: str = "balanced",
):
    """
    Train the DISCO model using PyTorch.

    Args:
        num_epochs: Number of training epochs
        batch_size: Batch size for training
        learning_rate: Learning rate for the optimizer
        weight_decay: Weight decay (L2 regularization)
        class_weight: Class-weighting strategy ("balanced" or None)
    """
    print("=" * 60)
    print("DISCO Model Training (PyTorch)")
    print("=" * 60)

    # Prefer Apple-silicon MPS, then CUDA, then fall back to CPU.
    device = "mps" if torch.backends.mps.is_available() else (
        "cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nUsing device: {device}")

    print("\n[1/6] Loading dataset splits...")
    dataset = get_dataset()

    train_paths = [str(Path(img_path)) for img_path in dataset["train"]["image"]]
    val_paths = [str(Path(img_path)) for img_path in dataset["val"]["image"]]
    test_paths = [str(Path(img_path)) for img_path in dataset["test"]["image"]]

    train_labels = np.array(dataset["train"]["label"])
    val_labels = np.array(dataset["val"]["label"])
    test_labels = np.array(dataset["test"]["label"])

    print(f"  Train: {len(train_paths)} images")
    print(f"  Val:   {len(val_paths)} images")
    print(f"  Test:  {len(test_paths)} images")

    print("\n[2/6] Loading CLIP processor...")
    model_name = "openai/clip-vit-base-patch32"
    processor = CLIPProcessor.from_pretrained(model_name)
    print(f"  Model: {model_name}")

    print("\n[3/6] Creating datasets and dataloaders...")
    train_dataset = ImageDataset(train_paths, train_labels, processor)
    val_dataset = ImageDataset(val_paths, val_labels, processor)
    test_dataset = ImageDataset(test_paths, test_labels, processor)

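    # num_workers=0 keeps data loading in the main process. This is a
    # conservative default that sidesteps multiprocessing pickling issues
    # with the processor, at some cost in throughput.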
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    print("\n[4/6] Initializing model...")
    config = DISCOConfig(
        clip_model_name=model_name,
        num_classes=2,
        threshold=0.5,
    )
    model = DISCO(config).to(device)

    # Optimize only the classifier head; the CLIP backbone is not passed to
    # the optimizer and therefore keeps its pretrained weights.
    optimizer = optim.AdamW(
        model.classifier.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay,
    )

    if class_weight == "balanced":
        # Weight each class by n_samples / (n_classes * n_c), the same
        # heuristic scikit-learn uses for class_weight="balanced", so the
        # rarer class contributes proportionally more to the loss.
        class_counts = np.bincount(train_labels)
        total = len(train_labels)
        class_weights = torch.tensor([
            total / (2 * class_counts[0]),
            total / (2 * class_counts[1]),
        ], dtype=torch.float32).to(device)
        criterion = nn.CrossEntropyLoss(weight=class_weights)
        print(f"  Using balanced class weights: {class_weights.cpu().numpy()}")
    else:
        criterion = nn.CrossEntropyLoss()
        print("  Using uniform class weights")

    print("  Trainable parameters: "
          f"{sum(p.numel() for p in model.classifier.parameters() if p.requires_grad):,}")

    print("\n[5/6] Training model...")
    best_val_f1 = 0.0
    best_model_state = None

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
    ) as progress:
        train_task = progress.add_task("Training", total=num_epochs)

        for epoch in range(num_epochs):
            train_loss = train_epoch(
                model, train_loader, criterion, optimizer, device)

            val_proba, val_labels_np = evaluate(model, val_loader, device)
            val_scores = val_proba[:, 1]
            val_roc_auc = roc_auc_score(val_labels_np, val_scores)

            # F1 at the default 0.5 threshold, used only for model selection;
            # the deployment threshold is tuned separately after training.
            val_pred = (val_scores >= 0.5).astype(int)
            tp = np.sum((val_pred == 1) & (val_labels_np == 1))
            fp = np.sum((val_pred == 1) & (val_labels_np == 0))
            fn = np.sum((val_pred == 0) & (val_labels_np == 1))
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
            val_f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

            progress.update(
                train_task, advance=1,
                description=f"Epoch {epoch + 1}/{num_epochs} | Loss: {train_loss:.4f} | "
                            f"Val ROC-AUC: {val_roc_auc:.4f} | Val F1: {val_f1:.4f}")

            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                # Clone the tensors: state_dict() returns live references that
                # subsequent optimizer steps would mutate in place.
                best_model_state = {k: v.detach().clone()
                                    for k, v in model.state_dict().items()}

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    print(f"\n  Best validation F1: {best_val_f1:.4f}")
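
    # The threshold is tuned on the validation split so the test set stays
    # untouched until the final evaluation below.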

    print("\n[6/6] Tuning threshold on validation set...")
    val_proba, val_labels_np = evaluate(model, val_loader, device)
    val_scores = val_proba[:, 1]
    best_threshold, threshold_metrics = tune_threshold(
        val_labels_np, val_scores, metric="f1")
    print(f"  Best threshold: {best_threshold:.4f}")
    print("  Validation metrics at best threshold:")
    print(f"    Precision: {threshold_metrics['precision']:.4f}")
    print(f"    Recall:    {threshold_metrics['recall']:.4f}")
    print(f"    F1:        {threshold_metrics['f1']:.4f}")
    print(f"    Balanced Accuracy: {threshold_metrics['balanced_accuracy']:.4f}")

    # Persist the tuned threshold on both the model and its config so it
    # ships with the saved artifacts.
    model.threshold = best_threshold
    config.threshold = best_threshold

    print("\n" + "=" * 60)
    print("Test Set Evaluation")
    print("=" * 60)

    test_proba, test_labels_np = evaluate(model, test_loader, device)
    test_scores = test_proba[:, 1]
    test_roc_auc = roc_auc_score(test_labels_np, test_scores)
    test_pr_auc = average_precision_score(test_labels_np, test_scores)

    print("\nTest Set Metrics (probability scores):")
    print(f"  ROC AUC: {test_roc_auc:.4f}")
    print(f"  PR AUC:  {test_pr_auc:.4f}")

    test_pred = (test_scores >= best_threshold).astype(int)

    print(f"\nTest Set Metrics (with threshold={best_threshold:.4f}):")
    print(classification_report(test_labels_np, test_pred,
                                target_names=["FAMILY_SAFE/UNCERTAIN", "SUGGESTIVE"]))

    tp = np.sum((test_pred == 1) & (test_labels_np == 1))
    fp = np.sum((test_pred == 1) & (test_labels_np == 0))
    tn = np.sum((test_pred == 0) & (test_labels_np == 0))
    fn = np.sum((test_pred == 0) & (test_labels_np == 1))

    print("\nConfusion Matrix:")
    print("                          Predicted")
    print("                   FAMILY_SAFE   SUGGESTIVE")
    print(f"Actual FAMILY_SAFE    {tn:4d}         {fp:4d}")
    print(f"       SUGGESTIVE     {fn:4d}         {tp:4d}")

    print("\n" + "=" * 60)
    print("Saving Model")
    print("=" * 60)

    models_dir = Path(__file__).parent.parent / "models"
    models_dir.mkdir(exist_ok=True)

    # save_pretrained writes the config and weights in Hugging Face format,
    # so the directory can be pushed to the Hub as-is.
    config.save_pretrained(models_dir)
    model.save_pretrained(models_dir)
    print(f"  Saved Hugging Face model to: {models_dir}")

    processor.save_pretrained(models_dir)
    print(f"  Saved processor to: {models_dir}")

    metadata = {
        "model_name": model_name,
        "threshold": float(best_threshold),
        "test_roc_auc": float(test_roc_auc),
        "test_pr_auc": float(test_pr_auc),
        "val_roc_auc": float(roc_auc_score(val_labels_np, val_scores)),
        "val_pr_auc": float(average_precision_score(val_labels_np, val_scores)),
        # Cast numpy scalars to plain floats so the dict is JSON-serializable.
        "threshold_metrics": {
            k: float(v) if isinstance(v, (int, float, np.number)) else v
            for k, v in threshold_metrics.items()
        },
        "embedding_dim": int(model.clip_model.config.projection_dim),
        "model_type": "clip_nsfw_detector",
        "framework": "pytorch",
        "training_config": {
            "num_epochs": num_epochs,
            "batch_size": batch_size,
            "learning_rate": learning_rate,
            "weight_decay": weight_decay,
            "class_weight": class_weight,
        },
    }

    metadata_path = models_dir / "model_metadata.json"
    with open(metadata_path, "w") as f:
        json.dump(metadata, f, indent=2)
    print(f"  Saved metadata to: {metadata_path}")

    print("\nModel saved successfully!")
    print(f"\nModel is ready for Hugging Face upload from: {models_dir}")

    return {
        "model": model,
        "threshold": best_threshold,
        "test_roc_auc": test_roc_auc,
        "test_pr_auc": test_pr_auc,
        "threshold_metrics": threshold_metrics,
        "metadata_path": metadata_path,
    }


if __name__ == "__main__":
    results = train()