"""Training loop for CLIPSeg fine-tuning.""" import json import time from pathlib import Path import numpy as np import torch import yaml from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingLR from torch.utils.data import DataLoader from tqdm import tqdm from src.data.dataset import DrywallSegDataset, collate_fn from src.model.clipseg_wrapper import load_model_and_processor from src.model.losses import BCEDiceLoss PROJECT_ROOT = Path(__file__).resolve().parents[1] def compute_metrics(logits: torch.Tensor, targets: torch.Tensor, threshold: float = 0.5): """Compute mIoU and Dice for a batch.""" preds = (torch.sigmoid(logits) > threshold).float() targets = (targets > 0.5).float() intersection = (preds * targets).sum(dim=(1, 2)) union = preds.sum(dim=(1, 2)) + targets.sum(dim=(1, 2)) - intersection iou = (intersection + 1e-6) / (union + 1e-6) dice = (2 * intersection + 1e-6) / (preds.sum(dim=(1, 2)) + targets.sum(dim=(1, 2)) + 1e-6) return {"miou": iou.mean().item(), "dice": dice.mean().item()} def get_device(): """Select best available device.""" if torch.backends.mps.is_available(): return torch.device("mps") if torch.cuda.is_available(): return torch.device("cuda") return torch.device("cpu") def train(config_path: str | None = None): config_path = config_path or str(PROJECT_ROOT / "configs" / "train_config.yaml") with open(config_path) as f: config = yaml.safe_load(f) # Seed seed = config["seed"] torch.manual_seed(seed) np.random.seed(seed) device = get_device() print(f"Device: {device}") # Model model, processor = load_model_and_processor( config["model"]["name"], config["model"]["freeze_backbone"], ) model = model.to(device) # Data splits_dir = PROJECT_ROOT / "data" / "splits" train_ds = DrywallSegDataset(str(splits_dir / "train.json"), processor, config["data"]["image_size"]) val_ds = DrywallSegDataset(str(splits_dir / "val.json"), processor, config["data"]["image_size"]) tc = config["training"] train_loader = DataLoader(train_ds, batch_size=tc["batch_size"], shuffle=True, collate_fn=collate_fn, num_workers=tc["num_workers"]) val_loader = DataLoader(val_ds, batch_size=tc["batch_size"], shuffle=False, collate_fn=collate_fn, num_workers=tc["num_workers"]) # Loss, optimizer, scheduler criterion = BCEDiceLoss(tc["bce_weight"], tc["dice_weight"]) optimizer = AdamW( [p for p in model.parameters() if p.requires_grad], lr=tc["lr"], weight_decay=tc["weight_decay"], ) scheduler = CosineAnnealingLR(optimizer, T_max=tc["epochs"]) # Training state best_miou = 0.0 patience_counter = 0 history = {"train_loss": [], "val_loss": [], "val_miou": [], "val_dice": []} ckpt_dir = PROJECT_ROOT / "outputs" / "checkpoints" ckpt_dir.mkdir(parents=True, exist_ok=True) log_dir = PROJECT_ROOT / "outputs" / "logs" log_dir.mkdir(parents=True, exist_ok=True) start_time = time.time() for epoch in range(1, tc["epochs"] + 1): # ---- Train ---- model.train() train_losses = [] for batch in tqdm(train_loader, desc=f"Epoch {epoch}/{tc['epochs']} [train]", leave=False): pixel_values = batch["pixel_values"].to(device) input_ids = batch["input_ids"].to(device) attention_mask = batch["attention_mask"].to(device) labels = batch["labels"].to(device) outputs = model( pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, ) logits = outputs.logits loss = criterion(logits, labels) optimizer.zero_grad() loss.backward() optimizer.step() train_losses.append(loss.item()) scheduler.step() avg_train_loss = np.mean(train_losses) # ---- Validate ---- model.eval() val_losses, val_mious, val_dices = [], [], [] with torch.no_grad(): for batch in tqdm(val_loader, desc=f"Epoch {epoch}/{tc['epochs']} [val]", leave=False): pixel_values = batch["pixel_values"].to(device) input_ids = batch["input_ids"].to(device) attention_mask = batch["attention_mask"].to(device) labels = batch["labels"].to(device) outputs = model( pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, ) logits = outputs.logits loss = criterion(logits, labels) metrics = compute_metrics(logits, labels) val_losses.append(loss.item()) val_mious.append(metrics["miou"]) val_dices.append(metrics["dice"]) avg_val_loss = np.mean(val_losses) avg_val_miou = np.mean(val_mious) avg_val_dice = np.mean(val_dices) history["train_loss"].append(float(avg_train_loss)) history["val_loss"].append(float(avg_val_loss)) history["val_miou"].append(float(avg_val_miou)) history["val_dice"].append(float(avg_val_dice)) print(f"Epoch {epoch:3d} | train_loss={avg_train_loss:.4f} | val_loss={avg_val_loss:.4f} | " f"val_mIoU={avg_val_miou:.4f} | val_Dice={avg_val_dice:.4f}") # Checkpoint if avg_val_miou > best_miou: best_miou = avg_val_miou patience_counter = 0 torch.save(model.state_dict(), ckpt_dir / "best_model.pt") print(f" -> New best mIoU: {best_miou:.4f}, saved checkpoint") else: patience_counter += 1 if patience_counter >= tc["patience"]: print(f" Early stopping at epoch {epoch} (patience={tc['patience']})") break total_time = time.time() - start_time # Save history & summary with open(log_dir / "training_history.json", "w") as f: json.dump(history, f, indent=2) summary = { "total_epochs": epoch, "best_val_miou": float(best_miou), "total_time_seconds": round(total_time, 1), "total_time_minutes": round(total_time / 60, 1), "device": str(device), "train_samples": len(train_ds), "val_samples": len(val_ds), "seed": seed, } with open(log_dir / "training_summary.json", "w") as f: json.dump(summary, f, indent=2) print(f"\nTraining complete in {summary['total_time_minutes']} min") print(f"Best val mIoU: {best_miou:.4f}") return model, history if __name__ == "__main__": train()