# drywall-qa-clipseg / src/train.py
# Uploaded by youngPhilosopher using huggingface_hub (commit b891e61, verified).
"""Training loop for CLIPSeg fine-tuning."""
import json
import random
import time
from pathlib import Path

import numpy as np
import torch
import yaml
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.data.dataset import DrywallSegDataset, collate_fn
from src.model.clipseg_wrapper import load_model_and_processor
from src.model.losses import BCEDiceLoss
# Directory two levels above this (symlink-resolved) file, i.e. the folder that
# contains the ``src`` package; configs/, data/ and outputs/ are resolved from it.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
def compute_metrics(logits: torch.Tensor, targets: torch.Tensor, threshold: float = 0.5):
    """Compute batch-mean IoU and Dice for binary segmentation.

    Predictions are binarised as ``sigmoid(logits) > threshold`` and targets as
    ``targets > 0.5``. IoU and Dice are computed per sample over every
    non-batch dimension and then averaged over the batch. A ``1e-6`` term in
    both numerator and denominator avoids division by zero, so a sample whose
    prediction and target are both empty scores 1.0.

    Args:
        logits: Raw (pre-sigmoid) model outputs with the batch on dim 0,
            e.g. ``(B, H, W)`` or ``(B, 1, H, W)``.
        targets: Ground-truth masks with exactly the same shape as ``logits``.
        threshold: Probability cut-off applied after the sigmoid.

    Returns:
        ``{"miou": float, "dice": float}`` -- means over the batch.

    Raises:
        ValueError: If the shapes differ or the inputs have fewer than 2 dims.
    """
    if logits.shape != targets.shape:
        # Without this check, broadcasting (e.g. (B,H,W) vs (B,1,H,W)) would
        # silently yield meaningless scores instead of failing.
        raise ValueError(
            f"logits and targets must have the same shape, got "
            f"{tuple(logits.shape)} and {tuple(targets.shape)}"
        )
    if logits.dim() < 2:
        # Need a batch dim plus at least one pixel dim; sum(dim=()) would
        # otherwise collapse the whole tensor.
        raise ValueError(f"expected at least 2 dims (batch, ...), got {logits.dim()}")

    eps = 1e-6
    pixel_dims = tuple(range(1, logits.dim()))  # everything except the batch axis

    preds = (torch.sigmoid(logits) > threshold).float()
    targets = (targets > 0.5).float()

    intersection = (preds * targets).sum(dim=pixel_dims)
    total = preds.sum(dim=pixel_dims) + targets.sum(dim=pixel_dims)
    union = total - intersection

    iou = (intersection + eps) / (union + eps)
    dice = (2 * intersection + eps) / (total + eps)
    return {"miou": iou.mean().item(), "dice": dice.mean().item()}
def get_device():
    """Return the preferred torch device.

    Candidates are tried in priority order -- Apple MPS first, then CUDA --
    and the first one reported available wins; CPU is the fallback.
    """
    candidates = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for device_name, is_available in candidates:
        if is_available():
            return torch.device(device_name)
    return torch.device("cpu")
def _forward_batch(model, batch, device):
    """Move a collated batch onto ``device`` and run one forward pass.

    Returns:
        ``(logits, labels)``: the model output's ``.logits`` and the batch's
        ``"labels"`` tensor, both on ``device``.
    """
    outputs = model(
        pixel_values=batch["pixel_values"].to(device),
        input_ids=batch["input_ids"].to(device),
        attention_mask=batch["attention_mask"].to(device),
    )
    return outputs.logits, batch["labels"].to(device)


def _train_one_epoch(model, loader, criterion, optimizer, device, desc):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Returns NaN (without numpy's empty-mean warning) if ``loader`` yields nothing.
    """
    model.train()
    losses = []
    for batch in tqdm(loader, desc=desc, leave=False):
        logits, labels = _forward_batch(model, batch, device)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return float(np.mean(losses)) if losses else float("nan")


def _evaluate(model, loader, criterion, device, desc):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns:
        ``(mean_loss, mean_iou, mean_dice)``. The loss is the mean of
        per-batch losses (same convention as training). IoU and Dice are
        per-image means over the whole loader: ``compute_metrics`` returns a
        per-batch mean, so each batch is weighted by its size to keep a short
        final batch from skewing the epoch score. All three are NaN if the
        loader yields nothing.
    """
    model.eval()
    losses = []
    iou_sum = 0.0
    dice_sum = 0.0
    n_images = 0
    with torch.no_grad():
        for batch in tqdm(loader, desc=desc, leave=False):
            logits, labels = _forward_batch(model, batch, device)
            losses.append(criterion(logits, labels).item())
            metrics = compute_metrics(logits, labels)
            batch_size = labels.shape[0]  # compute_metrics treats dim 0 as the batch
            iou_sum += metrics["miou"] * batch_size
            dice_sum += metrics["dice"] * batch_size
            n_images += batch_size
    if n_images == 0:
        nan = float("nan")
        return nan, nan, nan
    return float(np.mean(losses)), iou_sum / n_images, dice_sum / n_images


def train(config_path: str | None = None):
    """Fine-tune CLIPSeg as configured by a YAML file.

    Reads ``seed``, ``model.{name,freeze_backbone}``, ``data.image_size`` and
    ``training.{batch_size,num_workers,bce_weight,dice_weight,lr,weight_decay,
    epochs,patience}`` from the config. Trains on ``data/splits/train.json``,
    validates on ``data/splits/val.json`` after every epoch, and writes under
    ``outputs/`` in the project root:

    * ``checkpoints/best_model.pt`` -- ``state_dict`` from the epoch with the
      highest validation mIoU so far (only saved once mIoU exceeds 0.0);
    * ``logs/training_history.json`` -- per-epoch losses and metrics;
    * ``logs/training_summary.json`` -- run totals.

    Training stops early after ``training.patience`` consecutive epochs
    without a new best validation mIoU.

    Args:
        config_path: Path to the YAML config. Defaults to
            ``configs/train_config.yaml`` under the project root.

    Returns:
        ``(model, history)`` -- the model as it is after the last epoch run
        (it is not reloaded from the best checkpoint) and the history dict.
    """
    config_path = config_path or str(PROJECT_ROOT / "configs" / "train_config.yaml")
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # Seed every RNG we can reach. Python's `random` is included in case the
    # dataset/augmentation code (not visible here) draws from it.
    seed = config["seed"]
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    device = get_device()
    print(f"Device: {device}")

    model, processor = load_model_and_processor(
        config["model"]["name"],
        config["model"]["freeze_backbone"],
    )
    model = model.to(device)

    splits_dir = PROJECT_ROOT / "data" / "splits"
    image_size = config["data"]["image_size"]
    train_ds = DrywallSegDataset(str(splits_dir / "train.json"), processor, image_size)
    val_ds = DrywallSegDataset(str(splits_dir / "val.json"), processor, image_size)

    tc = config["training"]
    loader_kwargs = {
        "batch_size": tc["batch_size"],
        "collate_fn": collate_fn,
        "num_workers": tc["num_workers"],
    }
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

    criterion = BCEDiceLoss(tc["bce_weight"], tc["dice_weight"])
    # Only parameters left trainable by load_model_and_processor are optimised.
    optimizer = AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=tc["lr"],
        weight_decay=tc["weight_decay"],
    )
    # Stepped once per epoch, so T_max is measured in epochs.
    scheduler = CosineAnnealingLR(optimizer, T_max=tc["epochs"])

    n_epochs = tc["epochs"]
    best_miou = 0.0
    patience_counter = 0
    # Defined up front so the summary is valid even when no epoch runs
    # (previously `epoch` was unbound if training.epochs < 1).
    epochs_run = 0
    history = {"train_loss": [], "val_loss": [], "val_miou": [], "val_dice": []}

    ckpt_dir = PROJECT_ROOT / "outputs" / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    log_dir = PROJECT_ROOT / "outputs" / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)

    start_time = time.perf_counter()  # monotonic clock for elapsed time
    for epoch in range(1, n_epochs + 1):
        epochs_run = epoch
        tag = f"Epoch {epoch}/{n_epochs}"

        avg_train_loss = _train_one_epoch(
            model, train_loader, criterion, optimizer, device, f"{tag} [train]"
        )
        scheduler.step()
        avg_val_loss, avg_val_miou, avg_val_dice = _evaluate(
            model, val_loader, criterion, device, f"{tag} [val]"
        )

        history["train_loss"].append(float(avg_train_loss))
        history["val_loss"].append(float(avg_val_loss))
        history["val_miou"].append(float(avg_val_miou))
        history["val_dice"].append(float(avg_val_dice))
        print(f"Epoch {epoch:3d} | train_loss={avg_train_loss:.4f} | val_loss={avg_val_loss:.4f} | "
              f"val_mIoU={avg_val_miou:.4f} | val_Dice={avg_val_dice:.4f}")

        # Checkpoint on strict improvement; otherwise count toward early stop.
        if avg_val_miou > best_miou:
            best_miou = avg_val_miou
            patience_counter = 0
            torch.save(model.state_dict(), ckpt_dir / "best_model.pt")
            print(f" -> New best mIoU: {best_miou:.4f}, saved checkpoint")
        else:
            patience_counter += 1
            if patience_counter >= tc["patience"]:
                print(f" Early stopping at epoch {epoch} (patience={tc['patience']})")
                break
    total_time = time.perf_counter() - start_time

    with open(log_dir / "training_history.json", "w", encoding="utf-8") as f:
        json.dump(history, f, indent=2)
    summary = {
        "total_epochs": epochs_run,
        "best_val_miou": float(best_miou),
        "total_time_seconds": round(total_time, 1),
        "total_time_minutes": round(total_time / 60, 1),
        "device": str(device),
        "train_samples": len(train_ds),
        "val_samples": len(val_ds),
        "seed": seed,
    }
    with open(log_dir / "training_summary.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)
    print(f"\nTraining complete in {summary['total_time_minutes']} min")
    print(f"Best val mIoU: {best_miou:.4f}")
    return model, history
if __name__ == "__main__":
    # Local import: only the CLI entry point needs argument parsing.
    import argparse

    parser = argparse.ArgumentParser(description="Fine-tune CLIPSeg.")
    parser.add_argument(
        "config",
        nargs="?",
        default=None,
        help="Path to the training YAML config "
             "(default: configs/train_config.yaml under the project root).",
    )
    args = parser.parse_args()
    # With no argument this is train(None), identical to the previous train().
    train(args.config)