"""
DINOv3 fine-tuning for script classification
============================================

Progressive fine-tuning with a page-level train/val/test split.

Runs on three preprocessed variants:
    - whole_page/
    - patches_color/
    - patches_clahe/

Usage:
    # Exp1: whole page
    python finetune_dinov3.py --data_dir ./Data/output/whole_page --experiment whole_page

    # Exp2: patches_color
    python finetune_dinov3.py --data_dir ./Data/output/patches_color --experiment patches_color

    # Exp3: CLAHE patches
    python finetune_dinov3.py --data_dir ./Data/output/patches_clahe --experiment patches_clahe

Outputs (under <output_dir>/<experiment>/):
    best_<stage_slug>.pt                — best val macro-F1 checkpoint per stage
    history_stage_{a,b,c}.json          — per-epoch metrics per stage
    training_history_stage_{a,b,c}.png  — curves per stage
    final_model.pt                      — weights chosen by best val across stages
                                          + test_metrics metadata
    results.json, confusion_matrix.*, training_history.png (full run)

Requirements:
    pip install torch torchvision transformers scikit-learn matplotlib seaborn
    # DINOv3 requires transformers >= 4.56.0
    # If not available:
    pip install --upgrade git+https://github.com/huggingface/transformers.git
"""

import os
import re
import json
import argparse
import random
from pathlib import Path
from collections import Counter, defaultdict
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms
from PIL import Image
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    f1_score,
    accuracy_score,
)

try:
    from transformers import AutoImageProcessor, AutoModel
except ImportError as err:
    # Re-raise with install instructions, keeping the original error as cause.
    raise ImportError(
        "transformers >= 4.56.0 required for DINOv3.\n"
        "Install: pip install --upgrade git+https://github.com/huggingface/transformers.git"
    ) from err


# =====================
# CONFIG
# =====================
DINOV3_MODEL_ID = "facebook/dinov3-vits16-pretrain-lvd1689m"
# Nominal embedding width; the classifier itself reads the width from the
# loaded backbone's config rather than from this constant.
EMBEDDING_DIM = 384
# Lower-cased file suffixes accepted as images when scanning class folders.
VALID_EXT = {'.jpg', '.jpeg', '.png', '.tif', '.tiff', '.bmp', '.webp'}
SEED = 42


def set_seed(seed):
    """Seed Python's `random`, NumPy and PyTorch (plus all CUDA devices if present)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# ======================
# Page-level splitting
# ======================
def get_page_name(filepath):
    """
    Extract the original page name from a patch filename.

    The file stem is taken and a trailing ``_p<digits>`` patch suffix, if
    present, is removed:
        'manuscript001_p3.png' -> 'manuscript001'
        'manuscript001.png'    -> 'manuscript001'

    This is what lets all patches of one page be kept in the same split.
    """
    stem = Path(filepath).stem
    return re.sub(r'_p\d+$', '', stem)


def normalize_label_key(label: str) -> str:
    """Normalize a class name for manifest lookup.

    Lower-cases the label, collapses every run of characters outside
    ``[a-z0-9]`` into a single underscore and strips leading/trailing
    underscores.
    """
    return re.sub(r'[^a-z0-9]+', '_', label.lower()).strip('_')


def load_exclusion_manifest(manifest_path: str):
    """
    Load class -> page-id exclusions from a JSON file.

    Args:
        manifest_path: path to a JSON object mapping class label to a list
            of page ids. An empty/None path or a missing file yields ``{}``
            (a missing file is reported on stdout, not raised).

    Returns:
        dict mapping the *normalized* class label (see `normalize_label_key`)
        to a set of non-empty, stripped page-id strings. Entries whose value
        is not a list are ignored.

    Raises:
        ValueError: if the JSON top-level value is not an object.
    """
    if not manifest_path:
        return {}

    path = Path(manifest_path)
    if not path.is_file():
        print(f" Exclusion manifest not found, skipping exclusions: {path}")
        return {}

    with open(path, "r", encoding="utf-8") as f:
        raw = json.load(f)
    if not isinstance(raw, dict):
        raise ValueError(f"Exclusion manifest must be a JSON object: {path}")

    manifest = {}
    for label, ids in raw.items():
        if not isinstance(ids, list):
            continue
        cleaned = {str(x).strip() for x in ids}
        cleaned.discard('')
        manifest[normalize_label_key(str(label))] = cleaned
    return manifest


def _held_out_page_counts(n_pages, val_ratio, test_ratio):
    """Return ``(n_val, n_test)`` page counts that leave >= 1 page for training."""
    n_test = max(1, int(n_pages * test_ratio))
    n_val = max(1, int(n_pages * val_ratio))
    # For tiny classes (or very large ratios) the two minimum-of-one rules
    # would consume every page; hand pages back to train, shrinking the
    # larger held-out split first (test before val on ties).
    while n_val + n_test >= n_pages and (n_val or n_test):
        if n_test >= n_val:
            n_test -= 1
        else:
            n_val -= 1
    return n_val, n_test


def create_page_level(data_dir, val_ratio=0.15, test_ratio=0.15, seed=SEED,
                      excluded_pages_by_label=None):
    """
    Split at the PAGE level, not the image/patch level.

    Every immediate, non-hidden sub-directory of `data_dir` is one class; the
    files inside with a suffix in `VALID_EXT` are grouped by `get_page_name`,
    and whole pages are assigned to train/val/test per class so that all
    patches of a page share a split.

    Args:
        data_dir: root folder containing one sub-folder per class.
        val_ratio: fraction of each class's pages used for validation.
        test_ratio: fraction of each class's pages used for testing.
        seed: passed to `set_seed` before shuffling. NOTE: this re-seeds the
            global Python/NumPy/PyTorch generators as a side effect.
        excluded_pages_by_label: optional dict (normalized label -> set of
            page names) of pages to drop before splitting.

    Returns:
        A 4-tuple ``(splits, label_to_idx, idx_to_label, skipped_by_label)``:
        - splits: dict with 'train', 'val', 'test' keys, each a list of
          ``(filepath, label)`` tuples;
        - label_to_idx / idx_to_label: mappings between sorted class names
          and integer ids;
        - skipped_by_label: dict label -> number of files dropped through
          the exclusion list.

    Val and test each get at least one page per class when the class has at
    least three pages; smaller classes keep one page for training and give
    up test (then val) instead, and a warning is printed.
    """
    set_seed(seed)
    data_dir = Path(data_dir)

    # label -> page name -> list of file paths
    class_pages = defaultdict(lambda: defaultdict(list))
    skipped_by_label = Counter()

    for cls_dir in sorted(data_dir.iterdir()):
        if not cls_dir.is_dir() or cls_dir.name.startswith('.'):
            continue
        label = cls_dir.name
        excluded_pages = set()
        if excluded_pages_by_label:
            excluded_pages = excluded_pages_by_label.get(normalize_label_key(label), set())

        for img_path in sorted(cls_dir.iterdir()):
            if img_path.suffix.lower() not in VALID_EXT:
                continue
            page = get_page_name(str(img_path))
            if page in excluded_pages:
                skipped_by_label[label] += 1
                continue
            class_pages[label][page].append(str(img_path))

    # Create label mapping
    labels = sorted(class_pages.keys())
    label_to_idx = {label: idx for idx, label in enumerate(labels)}
    idx_to_label = {idx: label for label, idx in label_to_idx.items()}

    # Split pages per class (stratified)
    splits = {'train': [], 'val': [], 'test': []}
    for label in labels:
        pages = list(class_pages[label].keys())
        random.shuffle(pages)

        n_pages = len(pages)
        n_val, n_test = _held_out_page_counts(n_pages, val_ratio, test_ratio)
        if n_val == 0 or n_test == 0:
            print(f" Warning: class '{label}' has only {n_pages} page(s); "
                  f"val pages: {n_val}, test pages: {n_test}")

        pages_by_split = {
            'test': pages[:n_test],
            'val': pages[n_test:n_test + n_val],
            'train': pages[n_test + n_val:],
        }
        for split_name in ('train', 'val', 'test'):
            for page in pages_by_split[split_name]:
                splits[split_name].extend(
                    (fpath, label) for fpath in class_pages[label][page]
                )

    return splits, label_to_idx, idx_to_label, dict(skipped_by_label)


class ScriptDataset(Dataset):
    """Dataset of ``(filepath, label)`` samples preprocessed for the backbone.

    Each item is loaded as RGB, optionally augmented, then passed through the
    Hugging Face image `processor`; the item returned is
    ``(pixel_values, label_idx)``.
    """

    def __init__(self, samples, label_to_idx, processor, augment = False):
        self.samples = samples
        self.label_to_idx = label_to_idx
        self.processor = processor
        self.augment = augment
        # Built once instead of per item; both are stateless converters.
        self._to_tensor = transforms.ToTensor()
        self._to_pil = transforms.ToPILImage()

        # Document-aware augmentation
        if augment:
            self.aug_transform = transforms.Compose([
                # The pipeline runs on float tensors in [0, 1] (see
                # __getitem__), so "white" is 1.0. A fill of 255 would write
                # 255.0 into the corners, which corrupts the mean used by the
                # contrast jitter and overflows when converted back to uint8.
                transforms.RandomRotation(degrees=5, fill=1.0),
                transforms.ColorJitter(brightness=0.2, contrast=0.2),
                transforms.RandomResizedCrop(224, scale=(0.7, 1.0), ratio=(0.9, 1.1)),
                transforms.RandomErasing(p=0.1, scale=(0.02, 0.08)),
            ])
        else:
            self.aug_transform = None

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        file_path, label_str = self.samples[idx]

        # Load image; the context manager closes the underlying file once the
        # RGB copy has been made.
        with Image.open(file_path) as handle:
            img = handle.convert('RGB')

        if self.aug_transform is not None and self.augment:
            # Augment in tensor space (RandomErasing needs a tensor), then go
            # back to PIL so the processor sees the same input type either way.
            img = self._to_tensor(img)
            img = self.aug_transform(img)
            img = self._to_pil(img)

        # Process with DINOv3 processor (resize, normalize)
        inputs = self.processor(images=img, return_tensors="pt")
        pixel_values = inputs['pixel_values'].squeeze(0)  # drop the batch dim of 1

        label_idx = self.label_to_idx[label_str]
        return pixel_values, label_idx


class DINOv3Classifier(nn.Module):
    """
    DINOv3 ViT-S backbone + MLP classification head.

    Only the first token of ``last_hidden_state`` (the CLS token) is used for
    classification. Per the original author's notes the backbone also emits
    196 patch tokens and 4 register tokens (each 384-dim), which are unused.

    Classification head:
        LayerNorm -> Dropout -> Linear(hidden_size, 128) -> GELU
        -> Dropout -> Linear(128, num_classes)

    The backbone is frozen on construction; use `unfreeze_last_n_blocks` to
    make part of it trainable.
    """

    def __init__(self, model_id, num_classes, dropout=0.1):
        super().__init__()
        # Load pretrained backbone
        self.backbone = AutoModel.from_pretrained(model_id)
        # Get embedding dim from the loaded model rather than a constant
        hidden_size = self.backbone.config.hidden_size

        # Classification head
        self.head = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        )
        self.freeze_backbone()

    def freeze_backbone(self):
        """Freeze all the backbone parameters."""
        for params in self.backbone.parameters():
            params.requires_grad = False

    def _transformer_blocks(self):
        """Return the backbone's sequence of transformer blocks.

        Probes, in order, ``backbone.model.layer``, ``backbone.encoder.layer``
        and ``backbone.layer``. The first two are the layouts the original
        code handled; the last is a fallback for backbones that expose the
        block list directly (NOTE(review): confirm against the installed
        transformers version).
        """
        for attr_path in (("model", "layer"), ("encoder", "layer"), ("layer",)):
            node = self.backbone
            for attr in attr_path:
                if not hasattr(node, attr):
                    break
                node = getattr(node, attr)
            else:
                return node
        raise AttributeError(
            "Backbone has no recognizable transformer blocks "
            "(expected .model.layer or .layer for DINOv3, or .encoder.layer for ViT/BERT)."
        )

    def unfreeze_last_n_blocks(self, n):
        """
        Unfreeze the last N transformer blocks (and the final norm, if any).

        The whole backbone is frozen first, so calling this with a smaller
        `n` than before re-freezes the difference. `n` larger than the number
        of blocks unfreezes all of them; ``n <= 0`` unfreezes no block.

        Raises:
            AttributeError: if no block list can be located on the backbone.
        """
        # First freeze everything
        self.freeze_backbone()

        layers = self._transformer_blocks()
        total_layers = len(layers)
        for i in range(max(0, total_layers - n), total_layers):
            for param in layers[i].parameters():
                param.requires_grad = True

        # The final normalization layer is trained together with the blocks.
        if hasattr(self.backbone, "norm"):
            final_norm = self.backbone.norm
        elif hasattr(self.backbone, "layernorm"):
            final_norm = self.backbone.layernorm
        else:
            final_norm = None
        if final_norm is not None:
            for param in final_norm.parameters():
                param.requires_grad = True

    def forward(self, pixel_values):
        # Get backbone outputs
        outputs = self.backbone(pixel_values=pixel_values)
        # Use CLS token (first token)
        cls_embedding = outputs.last_hidden_state[:, 0, :]
        # Classify
        logits = self.head(cls_embedding)
        return logits


# ====================================
# Training
# ====================================
def get_class_weights(samples, label_to_idx, device):
    """Compute inverse-frequency class weights for balanced training.

    weight[c] = total / (num_classes * count[c]); a class with no samples is
    treated as having a count of 1. Returns a 1-D tensor on `device`, indexed
    by class id.
    """
    counts = Counter(label for _, label in samples)
    total = sum(counts.values())
    num_classes = len(label_to_idx)

    weights = torch.zeros(num_classes, device=device)
    for label, idx in label_to_idx.items():
        cnt = max(counts.get(label, 1), 1)
        weights[idx] = total / (num_classes * cnt)
    return weights


def get_weighted_sampler(samples, label_to_idx):
    """WeightedRandomSampler for balanced batches.

    Each sample is weighted by total / count of its class; sampling is with
    replacement and draws ``len(samples)`` items per epoch. `label_to_idx` is
    accepted for signature symmetry and not used.
    """
    counts = Counter(label for _, label in samples)
    total = sum(counts.values())
    class_weights = {label: total / count for label, count in counts.items()}
    sample_weights = [class_weights[label] for _, label in samples]
    return WeightedRandomSampler(sample_weights, len(samples), replacement=True)


def train_one_epoch(model, loader, criterion, optimizer, device, scaler=None):
    """Train for one epoch with optional mixed precision.

    Args:
        scaler: a gradient scaler; when given, the forward pass runs under
            CUDA float16 autocast and the scaler drives backward/step.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen. Both are 0 for an
        empty loader instead of raising ZeroDivisionError.
    """
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for batch_idx, (images, labels) in enumerate(loader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()

        if scaler:
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                logits = model(images)
                loss = criterion(logits, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

        # Weight the batch loss by batch size so the epoch mean is per-sample.
        total_loss += loss.item() * images.size(0)
        _, predicted = logits.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)

        if (batch_idx + 1) % 50 == 0:
            print(f" batch {batch_idx+1}/{len(loader)} | "
                  f"loss: {loss.item():.4f} | acc: {correct/total:.3f}")

    seen = max(total, 1)  # same empty-loader guard as evaluate()
    return total_loss / seen, correct / seen


def _stage_checkpoint_slug(stage_name: str) -> str:
    """Stable filename fragment (no spaces/colons) for checkpoint paths."""
    s = re.sub(r"[^a-z0-9]+", "_", stage_name.lower())
    return re.sub(r"_+", "_", s).strip("_")


@torch.no_grad()
def evaluate(model, loader, criterion, device, idx_to_label=None):
    """Return validation/test metrics and per-sample preds, labels, probs.

    Returns:
        ``(metrics, all_preds, all_labels, all_probs)`` where `metrics` has
        the keys 'loss', 'accuracy', 'macro_f1' and 'weighted_f1', and the
        three lists are in loader order (`all_probs` holds softmax vectors).
        `idx_to_label` is accepted but not used.
    """
    model.eval()
    total_loss = 0.0
    total = 0
    all_preds = []
    all_labels = []
    all_probs = []

    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        logits = model(images)
        loss = criterion(logits, labels)

        bs = images.size(0)
        total_loss += loss.item() * bs
        total += bs

        probs = torch.softmax(logits, dim=1)
        pred = logits.argmax(dim=1)
        all_preds.extend(pred.cpu().numpy().tolist())
        all_labels.extend(labels.cpu().numpy().tolist())
        all_probs.extend(probs.cpu().numpy().tolist())

    avg_loss = total_loss / max(total, 1)
    acc = accuracy_score(all_labels, all_preds)
    macro_f1 = f1_score(all_labels, all_preds, average="macro", zero_division=0)
    weighted_f1 = f1_score(all_labels, all_preds, average="weighted", zero_division=0)

    metrics = {
        "loss": float(avg_loss),
        "accuracy": float(acc),
        "macro_f1": float(macro_f1),
        "weighted_f1": float(weighted_f1),
    }
    return metrics, all_preds, all_labels, all_probs


def evaluate_page_level(samples, probs, label_to_idx, idx_to_label):
    """
    Aggregate patch-level probabilities to page-level predictions.

    Probability vectors of all samples belonging to one page are averaged and
    the argmax of the mean is the page prediction.

    Args:
        samples: list of (filepath, label_str) for the evaluated split.
        probs: list of per-sample probability vectors (same order as samples).
        label_to_idx: mapping from label string to class id.
        idx_to_label: accepted for signature symmetry; not used.

    Returns:
        dict with 'metrics' (accuracy, macro_f1, weighted_f1, num_pages,
        num_samples), 'pages' (sorted page ids), 'page_true', 'page_pred'
        (aligned with 'pages') and 'page_avg_probs' (page id -> mean vector).

        A page id is the page name from `get_page_name`. Pages are grouped
        per (label, page name), so two classes that happen to contain a page
        with the same name are kept apart; only in that case the ids are
        disambiguated as ``"<label>/<page name>"``.

    Raises:
        ValueError: if `samples` and `probs` differ in length.
    """
    if len(samples) != len(probs):
        raise ValueError(
            f"samples/probs length mismatch: {len(samples)} != {len(probs)}"
        )

    # Group probabilities by (label, page) so same-named pages from different
    # class folders are never merged into one page with a single label.
    grouped = defaultdict(list)
    labels_per_page = defaultdict(set)
    for (filepath, label_str), p in zip(samples, probs):
        page = get_page_name(filepath)
        grouped[(label_str, page)].append(np.asarray(p, dtype=np.float32))
        labels_per_page[page].add(label_str)

    page_preds = {}
    page_labels = {}
    for (label_str, page), vectors in grouped.items():
        is_unique = len(labels_per_page[page]) == 1
        page_id = page if is_unique else f"{label_str}/{page}"
        page_preds[page_id] = vectors
        page_labels[page_id] = label_to_idx[label_str]

    pages_sorted = sorted(page_preds.keys())
    all_page_true = []
    all_page_pred = []
    page_avg_probs = {}
    for page in pages_sorted:
        avg_probs = np.mean(page_preds[page], axis=0)
        all_page_true.append(int(page_labels[page]))
        all_page_pred.append(int(np.argmax(avg_probs)))
        page_avg_probs[page] = avg_probs.tolist()

    acc = accuracy_score(all_page_true, all_page_pred)
    macro_f1 = f1_score(all_page_true, all_page_pred, average="macro", zero_division=0)
    weighted_f1 = f1_score(all_page_true, all_page_pred, average="weighted", zero_division=0)

    metrics = {
        "accuracy": float(acc),
        "macro_f1": float(macro_f1),
        "weighted_f1": float(weighted_f1),
        "num_pages": int(len(pages_sorted)),
        "num_samples": int(len(samples)),
    }
    return {
        "metrics": metrics,
        "pages": pages_sorted,
        "page_true": all_page_true,
        "page_pred": all_page_pred,
        "page_avg_probs": page_avg_probs,
    }


# ============================
# Progressive fine-tuning
# ============================
def run_stage(model, train_loader, val_loader, criterion, device, stage_name,
              lr_backbone, lr_head, epochs, output_dir, idx_to_label, use_amp=True):
    """Run one stage of progressive fine-tuning.

    Trains with AdamW (separate learning rates for the trainable backbone
    parameters and the head) under a cosine schedule, evaluates on the
    validation loader after every epoch and checkpoints the epoch with the
    best validation macro-F1 to ``<output_dir>/best_<slug>.pt``.

    Args:
        lr_backbone: learning rate for backbone parameters that currently
            have ``requires_grad``; ignored when there are none.
        lr_head: learning rate for ``model.head``.
        epochs: number of epochs; 0 trains nothing and writes no checkpoint.
        idx_to_label: accepted for call-site symmetry; not used here.
        use_amp: enable the gradient scaler (only takes effect on CUDA).

    Returns:
        ``(history, best_val_f1)`` — `history` is a list with one dict per
        epoch; `best_val_f1` is 0.0 when no epoch ran. Whenever `history` is
        non-empty a checkpoint has been written for this stage.
    """
    print(f"\n{'='*60}")
    print(f" {stage_name}")
    print(f"{'='*60}")

    # Set up optimizer with different LRs for backbone and head
    backbone_params = [p for p in model.backbone.parameters() if p.requires_grad]
    head_params = list(model.head.parameters())
    param_groups = []
    if backbone_params:
        param_groups.append({'params': backbone_params, 'lr': lr_backbone})
        print(f" Backbone params (trainable): {sum(p.numel() for p in backbone_params):,}")
    param_groups.append({'params': head_params, 'lr': lr_head})
    print(f" Head params: {sum(p.numel() for p in head_params):,}")
    print(f" LR backbone: {lr_backbone}, LR head: {lr_head}")
    print(f" Epochs: {epochs}")

    optimizer = torch.optim.AdamW(param_groups, weight_decay=0.01)
    # max(..., 1) keeps the scheduler well-defined for a zero-epoch stage.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(epochs, 1))
    scaler = torch.amp.GradScaler() if use_amp and device.type == 'cuda' else None

    slug = _stage_checkpoint_slug(stage_name)
    checkpoint_path = Path(output_dir) / f'best_{slug}.pt'

    # Start below any reachable F1 so the first epoch is always checkpointed,
    # even if validation macro-F1 is exactly 0; callers load this file next.
    best_val_f1 = float('-inf')
    best_epoch = 0
    history = []

    for epoch in range(epochs):
        print(f"\n Epoch {epoch+1}/{epochs}")
        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device, scaler
        )
        val_metrics, _, _, _ = evaluate(model, val_loader, criterion, device)
        scheduler.step()

        print(f" Train loss: {train_loss:.4f} | acc: {train_acc:.3f}")
        print(f" Val loss: {val_metrics['loss']:.4f} | "
              f"acc: {val_metrics['accuracy']:.3f} | "
              f"macro-F1: {val_metrics['macro_f1']:.3f}")

        history.append({
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_macro_f1': val_metrics['macro_f1'],
            'val_loss': val_metrics['loss'],
            'val_accuracy': val_metrics['accuracy'],
        })

        # Save best model (always use slug path so load paths in main() match)
        if val_metrics['macro_f1'] > best_val_f1:
            best_val_f1 = val_metrics['macro_f1']
            best_epoch = epoch + 1
            torch.save({
                'model_state_dict': model.state_dict(),
                'epoch': epoch + 1,
                'val_macro_f1': best_val_f1,
                'val_accuracy': val_metrics['accuracy'],
                'stage_name': stage_name,
                'stage_slug': slug,
            }, checkpoint_path)
            print(f" * New best! Saved to {checkpoint_path}")

    if not history:
        best_val_f1 = 0.0  # nothing ran; keep the historical return value
    print(f"\n {stage_name} complete. Best: epoch {best_epoch}, macro-F1: {best_val_f1:.3f}")
    return history, best_val_f1


# ==========================
# MAIN
# ==========================
def _torch_load(path, map_location="cpu"):
    """Load a checkpoint written by this script.

    Tensors are mapped to `map_location` (CPU by default) so a checkpoint
    saved on a GPU can be read on any machine; ``load_state_dict`` then
    copies them onto the model's device.

    SECURITY NOTE: ``weights_only=False`` unpickles arbitrary objects. Only
    use this on checkpoints you produced yourself.
    """
    try:
        return torch.load(path, map_location=map_location, weights_only=False)
    except TypeError:
        # torch versions whose torch.load has no `weights_only` argument.
        return torch.load(path, map_location=map_location)


def _save_stage_history_json(output_dir: Path, stage_key: str, history: list) -> None:
    """Write one JSON file per training stage (loss / val metrics per epoch)."""
    path = output_dir / f'history_{stage_key}.json'
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(history, f, indent=2, default=str)
    print(f" Stage history saved: {path}")


def _plot_stage_history(output_dir: Path, stage_key: str, history: list,
                        experiment: str) -> None:
    """Save train loss + val macro-F1 curves for a single stage.

    Best-effort: an empty history is a no-op, and any failure (including a
    missing matplotlib) is reported on stdout instead of raised.
    """
    if not history:
        return
    try:
        import matplotlib
        matplotlib.use('Agg')  # file output only; no display required
        import matplotlib.pyplot as plt

        epochs = [h['epoch'] for h in history]
        train_loss = [h['train_loss'] for h in history]
        val_f1 = [h['val_macro_f1'] for h in history]

        fig, axes = plt.subplots(1, 2, figsize=(12, 4))
        axes[0].plot(epochs, train_loss, 'b-')
        axes[0].set_xlabel('Epoch')
        axes[0].set_ylabel('Train loss')
        axes[0].set_title(f'{stage_key} — train loss')
        axes[1].plot(epochs, val_f1, 'g-')
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('Val macro-F1')
        axes[1].set_title(f'{stage_key} — validation')
        fig.suptitle(f'{experiment} / {stage_key}')
        fig.tight_layout()

        out_path = output_dir / f'training_history_{stage_key}.png'
        fig.savefig(out_path, dpi=150)
        plt.close(fig)
        print(f" Stage plot saved: {out_path}")
    except Exception as e:
        print(f" (Skipping stage plot for {stage_key}: {e})")


def _save_stage_artifacts(output_dir: Path, stage_key: str, history: list,
                          experiment: str) -> None:
    """Persist the JSON history and the curve plot of one stage."""
    _save_stage_history_json(output_dir, stage_key, history)
    _plot_stage_history(output_dir, stage_key, history, experiment)


def _parse_args():
    """Build the command-line parser and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description="Fine-tune DINO ViT-S")
    parser.add_argument(
        "--data_dir", type=str, required=True,
        help="Path to processed data (e.g., ./Data/output/whole_page)",
    )
    parser.add_argument(
        "--experiment", type=str, required=True,
        choices=["whole_page", "patches_color", "patches_clahe"],
        help="Which experiment variant",
    )
    parser.add_argument("--output_dir", type=str, default="./results",
                        help="Where to save checkpoints and results")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Batch size (reduce if OOM)")
    parser.add_argument("--epochs_a", type=int, default=20,
                        help="Epochs for Stage A (head only)")
    parser.add_argument("--epochs_b", type=int, default=10,
                        help="Epochs for Stage B (last 2 blocks)")
    parser.add_argument("--epochs_c", type=int, default=10,
                        help="Epochs for Stage C (last 4 blocks)")
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--no_amp", action="store_true",
                        help="Disable mixed precision")
    parser.add_argument("--skip_stage_c", action="store_true",
                        help="Skip Stage C (last 4 blocks)")
    parser.add_argument(
        "--exclude_manifest", type=str, default="./benchmark_page_ids.json",
        help="Optional class->page_ids JSON; excluded pages are skipped during split creation",
    )
    return parser.parse_args()


def _write_confusion_csv(path, matrix, class_names):
    """Write a square confusion matrix as CSV with class names on both axes.

    Layout: an empty top-left cell followed by the class names, then one row
    per true class (name first, counts after). Uses only the standard
    library, so no optional dependency is needed for this artifact.
    """
    import csv

    rows = np.asarray(matrix).tolist()
    with open(path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f, lineterminator='\n')
        writer.writerow([''] + list(class_names))
        for name, row in zip(class_names, rows):
            writer.writerow([name] + row)


def _plot_final_figures(output_dir, experiment, cm, page_cm, target_names,
                        test_metrics, page_metrics, all_history):
    """Save both confusion-matrix heatmaps and the full-run training curves.

    Best-effort: if matplotlib/seaborn cannot be imported nothing is plotted,
    and a failure in one figure is reported and does not stop the others.
    """
    try:
        import matplotlib
        matplotlib.use('Agg')  # file output only; no display required
        import matplotlib.pyplot as plt
        import seaborn as sns
    except ImportError:
        print(" (matplotlib/seaborn not available, skipping plot)")
        return

    def save_heatmap(matrix, cmap, title, filename):
        fig, ax = plt.subplots(figsize=(14, 12))
        try:
            sns.heatmap(matrix, annot=True, fmt='d', cmap=cmap,
                        xticklabels=target_names, yticklabels=target_names, ax=ax)
            ax.set_xlabel('Predicted label')
            ax.set_ylabel('True label')
            ax.set_title(title)
            fig.tight_layout()
            fig.savefig(output_dir / filename, dpi=150)
        finally:
            plt.close(fig)  # release the figure even if drawing failed

    # Plot confusion matrix
    try:
        save_heatmap(
            cm, 'Blues',
            f'Confusion Matrix — {experiment} (macro-F1: {test_metrics["macro_f1"]:.3f})',
            'confusion_matrix.png',
        )
        print(f"\n Confusion matrix saved: {output_dir / 'confusion_matrix.png'}")
    except Exception as exc:
        print(f" (Skipping confusion matrix plot: {exc})")

    # Plot page-level confusion matrix
    try:
        save_heatmap(
            page_cm, 'Greens',
            f'Page Confusion Matrix — {experiment} '
            f'(macro-F1: {page_metrics["macro_f1"]:.3f})',
            'page_confusion_matrix.png',
        )
        print(f" Page confusion matrix saved: {output_dir / 'page_confusion_matrix.png'}")
    except Exception as exc:
        print(f" (Skipping page confusion matrix plot: {exc})")

    # Plot training history: stages are laid end to end on one epoch axis
    try:
        all_epochs = []
        all_train_loss = []
        all_val_f1 = []
        offset = 0
        for stage_history in all_history.values():
            for entry in stage_history:
                all_epochs.append(entry['epoch'] + offset)
                all_train_loss.append(entry['train_loss'])
                all_val_f1.append(entry['val_macro_f1'])
            offset += len(stage_history)

        fig, axes = plt.subplots(1, 2, figsize=(14, 5))
        try:
            axes[0].plot(all_epochs, all_train_loss, 'b-')
            axes[0].set_xlabel('Epoch')
            axes[0].set_ylabel('Train Loss')
            axes[0].set_title('Training Loss')
            axes[1].plot(all_epochs, all_val_f1, 'g-')
            axes[1].set_xlabel('Epoch')
            axes[1].set_ylabel('Macro F1')
            axes[1].set_title('Validation Macro-F1')
            fig.suptitle(f'{experiment} — Progressive Fine-Tuning')
            fig.tight_layout()
            fig.savefig(output_dir / 'training_history.png', dpi=150)
        finally:
            plt.close(fig)
        print(f" Training history saved: {output_dir / 'training_history.png'}")
    except Exception as exc:
        print(f" (Skipping training history plot: {exc})")


def main():
    """Run the full experiment: split, train stages A/B/(C), test, save artifacts.

    Stage A trains the head only, Stage B additionally the last 2 backbone
    blocks, Stage C (optional) the last 4. Each stage starts from the best
    checkpoint of the previous one. The checkpoint with the highest
    validation macro-F1 among the stages run in THIS invocation is evaluated
    on the test split (patch level and page level).

    Raises:
        RuntimeError: if no stage produced a checkpoint (e.g. all epoch
            counts are 0).
    """
    args = _parse_args()

    stage_a_name = "Stage A: Head only"
    stage_b_name = "Stage B: Last 2 blocks"
    stage_c_name = "Stage C: Last 4 blocks"

    set_seed(SEED)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    output_dir = Path(args.output_dir) / args.experiment
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n{'='*60}")
    print(" DINOv3 ViT-S Fine-Tuning")
    print(f" {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"{'='*60}")
    print(f" Experiment: {args.experiment}")
    print(f" Data dir: {args.data_dir}")
    print(f" Device: {device}")
    print(f" Batch size: {args.batch_size}")
    print(f" AMP: {not args.no_amp}")
    print(f" Exclusions: {args.exclude_manifest}")

    # Page level split
    print("\n Creating page level split")
    excluded_pages_by_label = load_exclusion_manifest(args.exclude_manifest)
    excluded_label_count = len(excluded_pages_by_label)
    excluded_id_count = sum(len(v) for v in excluded_pages_by_label.values())
    if excluded_label_count:
        print(f" Loaded exclusions: {excluded_label_count} labels, {excluded_id_count} page IDs")

    splits, label_to_idx, idx_to_label, skipped_by_label = create_page_level(
        args.data_dir,
        excluded_pages_by_label=excluded_pages_by_label,
    )
    num_classes = len(label_to_idx)
    print(f" Classes: {num_classes}")
    print(f" Train: {len(splits['train'])} | Val: {len(splits['val'])} | Test: {len(splits['test'])}")

    if skipped_by_label:
        print("\n Skipped excluded files by class:")
        for label, count in sorted(skipped_by_label.items()):
            print(f" {label:<20s} {count:>6d}")

    # Print per-class split counts
    for split_name in ['train', 'val', 'test']:
        counts = Counter(label for _, label in splits[split_name])
        print(f"\n {split_name}:")
        for label in sorted(counts.keys()):
            print(f" {label:<20s} {counts[label]:>6d}")

    # Save splits for reproducibility (the file lists themselves, not only
    # the per-class counts, so a run can be reconstructed exactly).
    splits_info = {
        split_name: [(fp, label) for fp, label in samples]
        for split_name, samples in splits.items()
    }
    with open(output_dir / 'splits.json', 'w', encoding='utf-8') as f:
        json.dump({
            'label_to_idx': label_to_idx,
            'idx_to_label': {str(k): v for k, v in idx_to_label.items()},
            'split_counts': {
                name: dict(Counter(l for _, l in samples))
                for name, samples in splits.items()
            },
            'splits': splits_info,
            'exclude_manifest': str(args.exclude_manifest),
            'excluded_label_count': excluded_label_count,
            'excluded_page_id_count': excluded_id_count,
            'skipped_excluded_files_by_class': dict(skipped_by_label),
        }, f, indent=2)

    print(f"Loading DINOv3 processor: {DINOV3_MODEL_ID}")
    processor = AutoImageProcessor.from_pretrained(DINOV3_MODEL_ID)

    train_dataset = ScriptDataset(splits['train'], label_to_idx, processor, augment=True)
    val_dataset = ScriptDataset(splits['val'], label_to_idx, processor, augment=False)
    test_dataset = ScriptDataset(splits['test'], label_to_idx, processor, augment=False)

    loader_kwargs = dict(
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=(device.type == 'cuda'),
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    # Evaluation loaders must not shuffle: page-level evaluation pairs the
    # returned probabilities with splits['test'] by position.
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    print(f"\n Building DINOv3 classifier ({num_classes} classes)...")
    model = DINOv3Classifier(DINOV3_MODEL_ID, num_classes, dropout=0.1)
    model = model.to(device)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f" Total params: {total_params:,}")
    print(f" Trainable params: {trainable_params:,} (head only)")

    class_weights = get_class_weights(splits['train'], label_to_idx, device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    use_amp = not args.no_amp and device.type == 'cuda'

    all_history = {}
    # (checkpoint_path, best_val_f1) of every stage that wrote a checkpoint in
    # this run. Selecting from this list (instead of globbing best_*.pt) keeps
    # stale checkpoints of earlier runs in the same folder out of the result.
    completed = []

    def run_and_record(stage_key, stage_name, lr_backbone, lr_head, epochs):
        """Run one stage, save its artifacts, return its checkpoint path or None."""
        history, best_f1 = run_stage(
            model, train_loader, val_loader, criterion, device,
            stage_name=stage_name,
            lr_backbone=lr_backbone,
            lr_head=lr_head,
            epochs=epochs,
            output_dir=output_dir,
            idx_to_label=idx_to_label,
            use_amp=use_amp,
        )
        all_history[stage_key] = history
        _save_stage_artifacts(output_dir, stage_key, history, args.experiment)
        ckpt_path = output_dir / f"best_{_stage_checkpoint_slug(stage_name)}.pt"
        if history and ckpt_path.is_file():
            completed.append((ckpt_path, best_f1))
            return ckpt_path
        return None

    def reload_best(ckpt_path):
        """Restore the stage's best weights; keep current weights if it wrote none."""
        if ckpt_path is not None:
            model.load_state_dict(_torch_load(ckpt_path)['model_state_dict'])

    # Stage A: Head only (backbone frozen)
    model.freeze_backbone()
    ckpt_a = run_and_record('stage_a', stage_a_name, 0, 1e-3, args.epochs_a)

    # Stage B: last 2 blocks, starting from the best Stage A weights
    reload_best(ckpt_a)
    model.unfreeze_last_n_blocks(2)
    ckpt_b = run_and_record('stage_b', stage_b_name, 1e-5, 1e-3, args.epochs_b)

    # Stage C: last 4 blocks, starting from the best Stage B weights
    if not args.skip_stage_c:
        reload_best(ckpt_b)
        model.unfreeze_last_n_blocks(4)
        run_and_record('stage_c', stage_c_name, 5e-6, 5e-4, args.epochs_c)

    # Final evaluation on test set
    print(f"\n{'='*60}")
    print(" FINAL TEST EVALUATION")
    print(f"{'='*60}")

    if not completed:
        raise RuntimeError("No checkpoint found under output_dir; cannot run test evaluation.")
    # max() returns the first maximum, so the earliest stage wins ties.
    best_ckpt, best_f1 = max(completed, key=lambda item: item[1])

    print(f" Loading best checkpoint: {best_ckpt} (val F1: {best_f1:.3f})")
    model.load_state_dict(_torch_load(best_ckpt)['model_state_dict'])

    test_metrics, test_preds, test_labels, test_probs = evaluate(
        model, test_loader, criterion, device, idx_to_label
    )
    page_eval = evaluate_page_level(
        splits['test'],
        test_probs,
        label_to_idx=label_to_idx,
        idx_to_label=idx_to_label,
    )
    page_metrics = page_eval["metrics"]

    # Canonical weights for this experiment (same as loaded best val checkpoint, after test eval)
    final_model_path = output_dir / 'final_model.pt'
    torch.save(
        {
            'model_state_dict': model.state_dict(),
            'experiment': args.experiment,
            'model_id': DINOV3_MODEL_ID,
            'num_classes': num_classes,
            'label_to_idx': label_to_idx,
            'source_val_checkpoint': str(best_ckpt),
            'val_macro_f1_at_selection': float(best_f1),
            'test_metrics': test_metrics,
            'page_test_metrics': page_metrics,
        },
        final_model_path,
    )
    print(f"\n Final model (for deployment / comparison) saved: {final_model_path}")

    print(f"\n Test accuracy: {test_metrics['accuracy']:.3f}")
    print(f" Test macro-F1: {test_metrics['macro_f1']:.3f}")
    print(f" Test weighted-F1: {test_metrics['weighted_f1']:.3f}")
    print(f" Page accuracy: {page_metrics['accuracy']:.3f} "
          f"| Page macro-F1: {page_metrics['macro_f1']:.3f} "
          f"| Pages: {page_metrics['num_pages']}")

    # Classification report. Passing `labels` explicitly keeps reports and
    # matrices num_classes wide even if a class is absent from the test split
    # (otherwise scikit-learn infers a smaller label set and target_names no
    # longer lines up).
    class_indices = list(range(num_classes))
    target_names = [idx_to_label[i] for i in class_indices]
    report = classification_report(
        test_labels, test_preds,
        labels=class_indices, target_names=target_names, zero_division=0
    )
    print(f"\n{report}")
    page_report = classification_report(
        page_eval["page_true"], page_eval["page_pred"],
        labels=class_indices, target_names=target_names, zero_division=0
    )

    # Confusion matrix
    cm = confusion_matrix(test_labels, test_preds, labels=class_indices)
    page_cm = confusion_matrix(
        page_eval["page_true"], page_eval["page_pred"], labels=class_indices
    )

    # Save everything
    results = {
        'experiment': args.experiment,
        'model': DINOV3_MODEL_ID,
        'num_classes': num_classes,
        'best_val_checkpoint': str(best_ckpt),
        'val_macro_f1_at_selection': float(best_f1),
        'final_model_path': str(final_model_path),
        'test_metrics': test_metrics,
        'page_test_metrics': page_metrics,
        'history': all_history,
        'confusion_matrix': cm.tolist(),
        'page_confusion_matrix': page_cm.tolist(),
        'label_to_idx': label_to_idx,
        'classification_report': report,
        'page_classification_report': page_report,
    }
    with open(output_dir / 'results.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, default=str)

    # Save confusion matrix as CSV
    _write_confusion_csv(output_dir / 'confusion_matrix.csv', cm, target_names)
    _write_confusion_csv(output_dir / 'page_confusion_matrix.csv', page_cm, target_names)

    # Plots (best-effort)
    _plot_final_figures(
        output_dir, args.experiment, cm, page_cm, target_names,
        test_metrics, page_metrics, all_history,
    )

    print(f"\n{'='*60}")
    print(f" All results saved to: {output_dir}")
    print(f"{'='*60}\n")


if __name__ == "__main__":
    main()