| """Training loop for CLIPSeg fine-tuning.""" |
|
|
| import json |
| import time |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import yaml |
| from torch.optim import AdamW |
| from torch.optim.lr_scheduler import CosineAnnealingLR |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
|
|
| from src.data.dataset import DrywallSegDataset, collate_fn |
| from src.model.clipseg_wrapper import load_model_and_processor |
| from src.model.losses import BCEDiceLoss |
|
|
| PROJECT_ROOT = Path(__file__).resolve().parents[1] |
|
|
|
|
def compute_metrics(logits: torch.Tensor, targets: torch.Tensor, threshold: float = 0.5):
    """Compute batch-mean IoU and Dice for binary segmentation.

    Args:
        logits: Raw (pre-sigmoid) predictions of shape (B, ...) — e.g.
            (B, H, W) or (B, 1, H, W); all non-batch dims are reduced.
        targets: Ground-truth masks, same shape as ``logits``; values are
            binarized at 0.5.
        threshold: Probability cutoff applied after the sigmoid.

    Returns:
        dict with float entries ``"miou"`` and ``"dice"``, each averaged
        over the batch dimension.
    """
    preds = (torch.sigmoid(logits) > threshold).float()
    targets = (targets > 0.5).float()

    # Reduce over every dimension except the batch axis, so the metric
    # works for (B, H, W) as well as (B, 1, H, W) layouts (the original
    # hard-coded dims (1, 2) and silently mis-reduced rank-4 input).
    dims = tuple(range(1, preds.ndim))
    intersection = (preds * targets).sum(dim=dims)
    total = preds.sum(dim=dims) + targets.sum(dim=dims)
    union = total - intersection

    # The 1e-6 smoothing makes empty prediction + empty target score as a
    # perfect match (1.0) instead of 0/0.
    iou = (intersection + 1e-6) / (union + 1e-6)
    dice = (2 * intersection + 1e-6) / (total + 1e-6)

    return {"miou": iou.mean().item(), "dice": dice.mean().item()}
|
|
|
|
def get_device():
    """Return the most capable available device: MPS, then CUDA, then CPU."""
    # Probe backends in preference order; fall through to CPU.
    for name, is_available in (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    ):
        if is_available():
            return torch.device(name)
    return torch.device("cpu")
|
|
|
|
def train(config_path: str | None = None):
    """Run the full fine-tuning loop: load config, train, validate, checkpoint.

    Args:
        config_path: Path to a YAML config file; defaults to
            ``configs/train_config.yaml`` under the project root.

    Returns:
        Tuple of (model, history). NOTE(review): the returned model holds the
        LAST epoch's weights, not the best checkpoint saved to disk — confirm
        callers reload ``best_model.pt`` if they want the best weights.

    Side effects:
        Writes ``outputs/checkpoints/best_model.pt``,
        ``outputs/logs/training_history.json`` and
        ``outputs/logs/training_summary.json``; prints progress to stdout.
    """
    config_path = config_path or str(PROJECT_ROOT / "configs" / "train_config.yaml")
    with open(config_path) as f:
        config = yaml.safe_load(f)

    # Seed torch and numpy for reproducibility (recorded in the summary below).
    seed = config["seed"]
    torch.manual_seed(seed)
    np.random.seed(seed)

    device = get_device()
    print(f"Device: {device}")

    # Model + processor come from the project wrapper; backbone freezing is
    # controlled by the config flag.
    model, processor = load_model_and_processor(
        config["model"]["name"],
        config["model"]["freeze_backbone"],
    )
    model = model.to(device)

    # Datasets are driven by pre-generated split manifests under data/splits.
    splits_dir = PROJECT_ROOT / "data" / "splits"
    train_ds = DrywallSegDataset(str(splits_dir / "train.json"), processor, config["data"]["image_size"])
    val_ds = DrywallSegDataset(str(splits_dir / "val.json"), processor, config["data"]["image_size"])

    tc = config["training"]
    train_loader = DataLoader(train_ds, batch_size=tc["batch_size"], shuffle=True,
                              collate_fn=collate_fn, num_workers=tc["num_workers"])
    val_loader = DataLoader(val_ds, batch_size=tc["batch_size"], shuffle=False,
                            collate_fn=collate_fn, num_workers=tc["num_workers"])

    # Combined BCE + Dice loss; only unfrozen parameters are optimized.
    criterion = BCEDiceLoss(tc["bce_weight"], tc["dice_weight"])
    optimizer = AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=tc["lr"],
        weight_decay=tc["weight_decay"],
    )
    # Cosine decay over the full configured epoch budget (stepped per epoch).
    scheduler = CosineAnnealingLR(optimizer, T_max=tc["epochs"])

    # Early-stopping / checkpointing state.
    best_miou = 0.0
    patience_counter = 0
    history = {"train_loss": [], "val_loss": [], "val_miou": [], "val_dice": []}
    ckpt_dir = PROJECT_ROOT / "outputs" / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    log_dir = PROJECT_ROOT / "outputs" / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)

    start_time = time.time()

    for epoch in range(1, tc["epochs"] + 1):
        # ---- training pass ----
        model.train()
        train_losses = []
        for batch in tqdm(train_loader, desc=f"Epoch {epoch}/{tc['epochs']} [train]", leave=False):
            pixel_values = batch["pixel_values"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            logits = outputs.logits
            loss = criterion(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        # LR is annealed once per epoch, matching T_max=epochs above.
        scheduler.step()
        avg_train_loss = np.mean(train_losses)

        # ---- validation pass (no grad) ----
        model.eval()
        val_losses, val_mious, val_dices = [], [], []
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch}/{tc['epochs']} [val]", leave=False):
                pixel_values = batch["pixel_values"].to(device)
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)

                outputs = model(
                    pixel_values=pixel_values,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                )
                logits = outputs.logits
                loss = criterion(logits, labels)
                # NOTE(review): compute_metrics reduces dims (1, 2), so this
                # assumes logits/labels are (B, H, W) — confirm against the
                # dataset's collate_fn output.
                metrics = compute_metrics(logits, labels)

                val_losses.append(loss.item())
                val_mious.append(metrics["miou"])
                val_dices.append(metrics["dice"])

        avg_val_loss = np.mean(val_losses)
        avg_val_miou = np.mean(val_mious)
        avg_val_dice = np.mean(val_dices)

        history["train_loss"].append(float(avg_train_loss))
        history["val_loss"].append(float(avg_val_loss))
        history["val_miou"].append(float(avg_val_miou))
        history["val_dice"].append(float(avg_val_dice))

        print(f"Epoch {epoch:3d} | train_loss={avg_train_loss:.4f} | val_loss={avg_val_loss:.4f} | "
              f"val_mIoU={avg_val_miou:.4f} | val_Dice={avg_val_dice:.4f}")

        # ---- checkpoint on improvement, early-stop on stagnation ----
        if avg_val_miou > best_miou:
            best_miou = avg_val_miou
            patience_counter = 0
            torch.save(model.state_dict(), ckpt_dir / "best_model.pt")
            print(f"  -> New best mIoU: {best_miou:.4f}, saved checkpoint")
        else:
            patience_counter += 1
            if patience_counter >= tc["patience"]:
                print(f"  Early stopping at epoch {epoch} (patience={tc['patience']})")
                break

    total_time = time.time() - start_time

    # Persist per-epoch curves and a run summary for later analysis.
    with open(log_dir / "training_history.json", "w") as f:
        json.dump(history, f, indent=2)

    summary = {
        "total_epochs": epoch,  # last epoch actually run (early stop aware)
        "best_val_miou": float(best_miou),
        "total_time_seconds": round(total_time, 1),
        "total_time_minutes": round(total_time / 60, 1),
        "device": str(device),
        "train_samples": len(train_ds),
        "val_samples": len(val_ds),
        "seed": seed,
    }
    with open(log_dir / "training_summary.json", "w") as f:
        json.dump(summary, f, indent=2)

    print(f"\nTraining complete in {summary['total_time_minutes']} min")
    print(f"Best val mIoU: {best_miou:.4f}")
    return model, history
|
|
|
|
| if __name__ == "__main__": |
| train() |
|
|