Spaces:
Sleeping
Sleeping
functionNormally
Restructurer l'app : backbone préentraîné + ML classique + FC head + CNN de zéro
cdc317a | """ | |
| finetune_backbone.py | |
| Fine-tune ResNet18 (ImageNet) on the local charcoal microscopy dataset. | |
| Goal: produce a domain-adapted backbone for students to use as a frozen | |
| feature extractor. The full dataset is used intentionally — this is a | |
| teaching artifact, not a research model with a held-out test split. | |
| Output (in backbone/): | |
| resnet18_charcoal_backbone.pt — backbone weights, FC replaced by Identity | |
| backbone_meta.json — class names, feature dim, training info | |
| Usage: | |
| python finetune_backbone.py | |
| python finetune_backbone.py --finetune-epochs 40 --batch-size 16 | |
| """ | |
| import argparse | |
| import json | |
| import time | |
| from pathlib import Path | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from PIL import Image | |
| from torch.utils.data import DataLoader, Dataset | |
| from torchvision import models, transforms | |
| # --------------------------------------------------------------------------- | |
| # Paths | |
| # --------------------------------------------------------------------------- | |
# Every path is anchored on this script's directory, not on the working directory.
ROOT = Path(__file__).parent
DATA_DIR = ROOT.joinpath("data")  # input images
OUTPUT_DIR = ROOT.joinpath("backbone")  # exported weights and metadata
# Created at import time so the save steps can write into it directly.
OUTPUT_DIR.mkdir(exist_ok=True)
| # --------------------------------------------------------------------------- | |
| # Defaults | |
| # --------------------------------------------------------------------------- | |
IMAGE_SIZE = 224  # side length, in pixels, of the square network input
SEED = 42  # seed for torch's random number generator

# Phase 1 (warm-up): backbone frozen, only the FC head is trained.
WARMUP_EPOCHS = 10
WARMUP_LR = 1e-3

# Phase 2 (fine-tune): all layers unfrozen, small learning rate.
FINETUNE_EPOCHS = 40
FINETUNE_LR = 5e-5

WEIGHT_DECAY = 1e-4  # optimizer weight decay
| # --------------------------------------------------------------------------- | |
| # Dataset | |
| # --------------------------------------------------------------------------- | |
class CharcoalDataset(Dataset):
    """Flat ImageFolder-style dataset that handles .tif files.

    Expected layout: ``root/<class_name>/<image file>``. Each non-hidden
    sub-directory of ``root`` is one class; class indices follow the sorted
    class names.

    Attributes:
        classes: Sorted list of class (sub-directory) names.
        class_to_idx: Mapping from class name to integer label.
        samples: List of ``(path, label)`` pairs, ordered by class then path.
    """

    # Compared against the lower-cased file suffix, so matching is
    # case-insensitive.
    EXTENSIONS = {".tif", ".tiff", ".jpg", ".jpeg", ".png"}

    def __init__(self, root: Path, transform=None):
        """Index every image file found one level below ``root``.

        Args:
            root: Directory holding one sub-directory per class. A plain
                string is accepted as well.
            transform: Optional callable applied to each RGB PIL image in
                ``__getitem__``.
        """
        root = Path(root)
        self.transform = transform
        self.classes = sorted(
            d.name for d in root.iterdir()
            if d.is_dir() and not d.name.startswith(".")
        )
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}
        self.samples = []
        for cls in self.classes:
            label = self.class_to_idx[cls]
            for p in sorted((root / cls).iterdir()):
                # Hidden entries (e.g. macOS "._image.tif" side-car files)
                # and directories whose name ends in an image suffix would
                # only fail later, when the image is opened.
                if p.name.startswith(".") or not p.is_file():
                    continue
                if p.suffix.lower() in self.EXTENSIONS:
                    self.samples.append((p, label))

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is converted to RGB and passed through ``transform`` when
        one was given.
        """
        path, label = self.samples[idx]
        # The context manager closes the underlying file handle; convert()
        # returns an independent in-memory copy that stays valid afterwards.
        # NOTE(review): how convert("RGB") maps 16-bit microscopy TIFFs
        # should be checked against the actual data.
        with Image.open(path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label
def make_transform():
    """Build the training-time preprocessing and augmentation pipeline.

    Returns:
        A ``transforms.Compose`` that resizes a PIL image to
        ``IMAGE_SIZE x IMAGE_SIZE``, applies random flips, rotation, colour
        jitter and a small affine perturbation, then converts the result to
        a normalised tensor.
    """
    target_size = (IMAGE_SIZE, IMAGE_SIZE)
    # Channel statistics commonly used with ImageNet-pretrained weights.
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)

    # Aggressive augmentation: microscopy images have no canonical
    # orientation and vary in staining intensity.
    steps = [transforms.Resize(target_size)]
    steps += [
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(180),
    ]
    steps.append(
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2)
    )
    steps.append(
        transforms.RandomAffine(
            degrees=0, translate=(0.1, 0.1), scale=(0.85, 1.15)
        )
    )
    steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)
| # --------------------------------------------------------------------------- | |
| # Training helpers | |
| # --------------------------------------------------------------------------- | |
def run_epoch(model, loader, criterion, optimizer, device):
    """Train ``model`` for one full pass over ``loader``.

    Args:
        model: Network to optimise; it is switched to training mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)`` over every sample seen in the epoch.
        Accuracy is measured on the same forward passes used for training.

    Raises:
        ValueError: If ``loader`` yields no samples, instead of the
            ``ZeroDivisionError`` the averaging would otherwise produce.
    """
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        # Cap the global gradient norm at 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        # Weight the batch loss by its size so a smaller final batch does
        # not skew the epoch average (assumes the criterion returns a
        # per-batch mean).
        total_loss += loss.item() * images.size(0)
        correct += (outputs.argmax(1) == labels).sum().item()
        total += labels.size(0)
    if total == 0:
        raise ValueError("run_epoch received an empty loader: no samples to train on")
    return total_loss / total, correct / total
| # --------------------------------------------------------------------------- | |
| # Main | |
| # --------------------------------------------------------------------------- | |
def _cpu_state_dict(model):
    """Return a detached copy of ``model.state_dict()`` with all tensors on the CPU."""
    return {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}


def main():
    """Fine-tune ResNet18 on the local dataset and export it as a feature extractor.

    Steps:
        1. Warm-up: only the new FC head is trained, all other weights are
           frozen.
        2. Fine-tune: every layer is trained with AdamW and a cosine
           learning-rate schedule; the weights of the epoch with the highest
           training accuracy are kept.
        3. Export: the kept weights, minus the FC head, are saved to
           ``OUTPUT_DIR`` together with a JSON metadata file.

    Raises:
        SystemExit: If ``DATA_DIR`` is missing or contains no usable images.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--warmup-epochs", type=int, default=WARMUP_EPOCHS)
    # "--epochs" is kept as a shorter alias for "--finetune-epochs".
    parser.add_argument(
        "--finetune-epochs", "--epochs",
        dest="finetune_epochs", type=int, default=FINETUNE_EPOCHS,
    )
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--warmup-lr", type=float, default=WARMUP_LR)
    parser.add_argument("--finetune-lr", type=float, default=FINETUNE_LR)
    args = parser.parse_args()

    torch.manual_seed(SEED)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device : {device}")

    if not DATA_DIR.is_dir():
        raise SystemExit(f"Data directory not found: {DATA_DIR}")
    dataset = CharcoalDataset(DATA_DIR, transform=make_transform())
    num_classes = len(dataset.classes)
    print(f"Classes : {num_classes} | Images : {len(dataset)}")
    print(f" {', '.join(dataset.classes)}\n")
    # Fail early with a readable message rather than deep inside training.
    if num_classes == 0 or len(dataset) == 0:
        raise SystemExit(
            f"No images found under {DATA_DIR} (expected one sub-folder per class)"
        )

    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=0,  # 0 = safe on Windows
        pin_memory=(device.type == "cuda"),
    )

    # -----------------------------------------------------------------------
    # Build model
    # -----------------------------------------------------------------------
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    # Width of the features feeding the FC head; recorded in the metadata.
    feature_dim = model.fc.in_features
    model.fc = nn.Linear(feature_dim, num_classes)
    model.to(device)

    # Label smoothing helps regularise with tiny datasets
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    # -----------------------------------------------------------------------
    # Phase 1 — warm-up: freeze backbone, train FC only
    # -----------------------------------------------------------------------
    print(f"=== Phase 1 : warm-up ({args.warmup_epochs} epochs, backbone frozen) ===")
    for p in model.parameters():
        p.requires_grad = False
    for p in model.fc.parameters():
        p.requires_grad = True
    # NOTE: run_epoch puts the whole model in train mode, so normalisation
    # layers keep updating their running statistics during warm-up even
    # though their weights are frozen.
    optimizer = optim.AdamW(model.fc.parameters(), lr=args.warmup_lr, weight_decay=WEIGHT_DECAY)
    for epoch in range(1, args.warmup_epochs + 1):
        loss, acc = run_epoch(model, loader, criterion, optimizer, device)
        print(f" [{epoch:>3}/{args.warmup_epochs}] loss={loss:.4f} acc={acc:.4f}")

    # -----------------------------------------------------------------------
    # Phase 2 — full fine-tune: unfreeze all layers
    # -----------------------------------------------------------------------
    print(f"\n=== Phase 2 : fine-tune ({args.finetune_epochs} epochs, all layers) ===")
    for p in model.parameters():
        p.requires_grad = True
    optimizer = optim.AdamW(
        model.parameters(), lr=args.finetune_lr, weight_decay=WEIGHT_DECAY
    )
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=max(1, args.finetune_epochs),  # keep the schedule period valid for 0 epochs
        eta_min=args.finetune_lr * 0.05,
    )

    best_acc = 0.0
    # Start from the post-warm-up weights so there is always a state to
    # export, even with zero fine-tune epochs or if accuracy never rises
    # above zero.
    best_state = _cpu_state_dict(model)
    t0 = time.time()
    for epoch in range(1, args.finetune_epochs + 1):
        loss, acc = run_epoch(model, loader, criterion, optimizer, device)
        # Read the rate before stepping the scheduler so the log shows the
        # value actually used for this epoch.
        lr = optimizer.param_groups[0]["lr"]
        scheduler.step()
        print(f" [{epoch:>3}/{args.finetune_epochs}] loss={loss:.4f} acc={acc:.4f} lr={lr:.2e}")
        if acc > best_acc:
            best_acc = acc
            best_state = _cpu_state_dict(model)
    elapsed = time.time() - t0
    print(f"\nTemps phase 2 : {elapsed:.0f}s | Meilleure accuracy entraînement : {best_acc:.4f}")

    # -----------------------------------------------------------------------
    # Save backbone (FC replaced by Identity — outputs the feature vector)
    # -----------------------------------------------------------------------
    backbone = models.resnet18()
    backbone.fc = nn.Identity()
    # Transfer all weights except fc (which is now Identity with no parameters).
    backbone_state = {k: v for k, v in best_state.items() if not k.startswith("fc.")}
    # Strict loading: with the fc.* entries removed the key sets must match
    # exactly, so any mismatch is reported instead of silently ignored.
    backbone.load_state_dict(backbone_state)
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    backbone_path = OUTPUT_DIR / "resnet18_charcoal_backbone.pt"
    torch.save(backbone.state_dict(), backbone_path)
    print(f"Backbone sauvegardé : {backbone_path}")

    # -----------------------------------------------------------------------
    # Save metadata
    # -----------------------------------------------------------------------
    meta = {
        "classes": dataset.classes,
        "num_classes": num_classes,
        "image_size": IMAGE_SIZE,
        "feature_dim": feature_dim,
        "warmup_epochs": args.warmup_epochs,
        "finetune_epochs": args.finetune_epochs,
        "best_train_acc": round(float(best_acc), 4),
        "device": str(device),
    }
    meta_path = OUTPUT_DIR / "backbone_meta.json"
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2, ensure_ascii=False)
    print(f"Métadonnées sauvegardées : {meta_path}")
# Run the fine-tuning pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()