# tibetan-script-classifier / finetune_dinov3.py
# Uploaded by karma689 using huggingface_hub (commit a1d2f62, verified)
"""
Dino V3 finetunning for script classification
==============================================
Progressive finetuning with page-level train/val/test split
Runs on three preprocessed variants:
- whole page /
- patches color /
- patches_clahe/
Usage:
#Exp1: whole page
python finetune_dinov3.py --data_dir ./Data/output/whole_page --experiment whole_page
#Exp2: patches_color
python finetune_dinov3.py --data_dir ./Data/output/patches_color --experiment patches_color
#Exp3: CLAHE_patches
python finetune_dinov3.py --data_dir ./Data/output/patches_clahe --experiment patches_clahe
Outputs (under --output_dir/<experiment>/):
best_<stage_slug>.pt — best val macro-F1 per stage
history_stage_{a,b,c}.json — per-epoch metrics per stage
training_history_stage_{a,b,c}.png — curves per stage
final_model.pt — weights chosen by best val across stages + test_metrics metadata
results.json, confusion_matrix.*, training_history.png (full run)
Requirements:
pip install torch torchvision transformers scikit-learn matplotlib seaborn
# DINOv3 requires transformers >= 4.56.0
# If not available: pip install --upgrade git+https://github.com/huggingface/transformers.git
"""
import os
import re
import json
import argparse
import random
from pathlib import Path
from collections import Counter, defaultdict
from datetime import datetime
import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms
from PIL import Image
from sklearn.metrics import (classification_report, confusion_matrix, f1_score, accuracy_score )
try:
    from transformers import AutoImageProcessor, AutoModel
except ImportError as err:
    # Fail fast with an actionable message; keep the original import error
    # attached as the explicit cause so the real reason stays visible.
    raise ImportError(
        "transformers >= 4.56.0 required for DINOv3.\n"
        "Install: pip install --upgrade git+https://github.com/huggingface/transformers.git"
    ) from err
# =====================
# CONFIG
# =====================
# Hugging Face hub id of the pretrained DINOv3 ViT-S/16 backbone.
DINOV3_MODEL_ID = "facebook/dinov3-vits16-pretrain-lvd1689m"
# Embedding width of ViT-S. Informational only: the classifier reads
# hidden_size from the loaded backbone config instead of this constant.
EMBEDDING_DIM = 384
# Lower-case file suffixes treated as images when scanning class folders.
VALID_EXT = {
    '.bmp', '.jpeg', '.jpg', '.png',
    '.tif', '.tiff', '.webp',
}
# Seed for the Python / NumPy / PyTorch RNGs and the default split seed.
SEED = 42
def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU, plus CUDA when present)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # CUDA generators are only touched when a GPU is actually available.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
# ======================
# Page level splitting
# ======================
# One trailing "_p<digits>" patch index at the end of a file stem.
_PATCH_SUFFIX_RE = re.compile(r'_p\d+$')


def get_page_name(filepath):
    """
    Return the page identifier of an image or patch file.

    The file stem is taken and a trailing ``_p<digits>`` patch suffix is
    removed, so every patch cut from one page maps to the same key:
        'manuscript001_p3.png' → 'manuscript001'
        'manuscript001.png'    → 'manuscript001'
    Grouping by this key keeps all patches of a page in the same split.
    """
    file_stem = Path(filepath).stem
    return _PATCH_SUFFIX_RE.sub('', file_stem)
def normalize_label_key(label: str) -> str:
    """
    Canonicalise a class name for manifest lookup.

    Lower-cases *label*, collapses each run of characters outside
    ``[a-z0-9]`` into a single underscore, and trims underscores at both ends.
    """
    lowered = label.lower()
    collapsed = re.sub(r'[^a-z0-9]+', '_', lowered)
    return collapsed.strip('_')
def load_exclusion_manifest(manifest_path: str):
    """
    Load a ``{class_label: [page_id, ...]}`` exclusion manifest from JSON.

    Args:
        manifest_path: path to the JSON file; a falsy value disables exclusions.

    Returns:
        dict mapping the normalized class label (``normalize_label_key``) to a
        set of page ids (stringified, stripped, empty ids dropped). Returns {}
        when the path is falsy or the file does not exist. Raw labels that
        normalize to the same key have their ids merged, not overwritten.

    Raises:
        ValueError: if the JSON top level is not an object.
    """
    if not manifest_path:
        return {}
    path = Path(manifest_path)
    if not path.is_file():
        print(f" Exclusion manifest not found, skipping exclusions: {path}")
        return {}
    with open(path, "r", encoding="utf-8") as f:
        raw = json.load(f)
    if not isinstance(raw, dict):
        raise ValueError(f"Exclusion manifest must be a JSON object: {path}")
    manifest = {}
    for label, ids in raw.items():
        if not isinstance(ids, list):
            # Still best-effort (entry skipped), but no longer silent.
            print(f" Ignoring exclusion entry {label!r}: expected a list of page ids")
            continue
        page_ids = {s for s in (str(x).strip() for x in ids) if s}
        # Merge: e.g. "Uchen" and "uchen" both normalize to "uchen".
        manifest.setdefault(normalize_label_key(str(label)), set()).update(page_ids)
    return manifest
def _split_sizes(n_pages, val_ratio, test_ratio):
    """Return (n_test, n_val) for a class with *n_pages* pages, keeping >= 1 train page when n_pages >= 2."""
    n_test = max(1, int(n_pages * test_ratio))
    n_val = max(1, int(n_pages * val_ratio))
    if n_pages >= 2:
        # Leave at least one page for training: shrink val first, then test
        # (test keeps at least one page).
        excess = n_test + n_val - (n_pages - 1)
        if excess > 0:
            drop_val = min(excess, n_val)
            n_val -= drop_val
            n_test -= min(excess - drop_val, n_test - 1)
    return n_test, n_val


def create_page_level(data_dir, val_ratio=0.15, test_ratio=0.15, seed=SEED, excluded_pages_by_label=None):
    """
    Split at the PAGE level, not the image/patch level, stratified per class.

    Every non-hidden sub-directory of *data_dir* is a class. Files whose
    suffix is in VALID_EXT are grouped by ``get_page_name``, so all patches
    from one page go into the same split. Pages are shuffled per class after
    ``set_seed(seed)``. Each class gets ``max(1, int(n * test_ratio))`` test
    pages and ``max(1, int(n * val_ratio))`` val pages, reduced when needed so
    that a class with 2 or more pages keeps at least one training page. A
    single-page class goes to test only, and a warning is printed for classes
    with fewer than 3 pages.

    Args:
        data_dir: root folder containing one sub-folder per class.
        val_ratio, test_ratio: fraction of each class's pages for val / test.
        seed: seed passed to ``set_seed`` before shuffling.
        excluded_pages_by_label: optional {normalized_label: set(page ids)};
            files of those pages are skipped and counted.

    Returns:
        splits: dict with 'train', 'val', 'test' keys; each value is a list
            of (filepath, label) tuples
        label_to_idx: dict mapping label strings (sorted) to integers
        idx_to_label: inverse of label_to_idx
        skipped: dict label -> number of excluded files
    """
    set_seed(seed)
    data_dir = Path(data_dir)
    class_pages = defaultdict(lambda: defaultdict(list))
    skipped_by_label = Counter()
    excluded_pages_by_label = excluded_pages_by_label or {}

    for cls_dir in sorted(data_dir.iterdir()):
        if not cls_dir.is_dir() or cls_dir.name.startswith('.'):
            continue
        label = cls_dir.name
        excluded_pages = excluded_pages_by_label.get(normalize_label_key(label), set())
        for img_path in sorted(cls_dir.iterdir()):
            if img_path.suffix.lower() not in VALID_EXT:
                continue
            page = get_page_name(str(img_path))
            if page in excluded_pages:
                skipped_by_label[label] += 1
                continue
            class_pages[label][page].append(str(img_path))

    # Label mapping over classes that still have at least one file.
    labels = sorted(class_pages)
    label_to_idx = {label: idx for idx, label in enumerate(labels)}
    idx_to_label = {idx: label for label, idx in label_to_idx.items()}

    # Split pages per class (stratified).
    splits = {'train': [], 'val': [], 'test': []}
    for label in labels:
        pages = list(class_pages[label])
        random.shuffle(pages)
        n_pages = len(pages)
        n_test, n_val = _split_sizes(n_pages, val_ratio, test_ratio)
        test_pages = pages[:n_test]
        val_pages = pages[n_test:n_test + n_val]
        train_pages = pages[n_test + n_val:]
        if n_pages < 3:
            print(f"  Warning: class '{label}' has only {n_pages} page(s) -> "
                  f"train/val/test pages = {len(train_pages)}/{len(val_pages)}/{len(test_pages)}")
        for split_name, split_pages in (('train', train_pages), ('val', val_pages), ('test', test_pages)):
            for page in split_pages:
                splits[split_name].extend((fpath, label) for fpath in class_pages[label][page])
    return splits, label_to_idx, idx_to_label, dict(skipped_by_label)
class ScriptDataset(Dataset):
    """
    Map-style dataset of (image path, class label) samples for DINOv3.

    Each item is loaded as RGB, optionally augmented, then run through the
    Hugging Face image processor. ``__getitem__`` returns
    ``(pixel_values, label_index)`` with the processor's batch dim squeezed.
    """

    def __init__(self, samples, label_to_idx, processor, augment=False):
        self.samples = samples
        self.label_to_idx = label_to_idx
        self.processor = processor
        self.augment = augment
        self._to_tensor = transforms.ToTensor()
        self._to_pil = transforms.ToPILImage()
        # Document-aware augmentation. It runs on the float tensor produced by
        # ToTensor (values in [0, 1]), so white padding for rotated corners is
        # 1.0. A fill of 255 would inject out-of-range pixels that skew
        # ColorJitter's contrast mean.
        if augment:
            self.aug_transform = transforms.Compose([
                transforms.RandomRotation(degrees=5, fill=1.0),
                transforms.ColorJitter(brightness=0.2, contrast=0.2),
                transforms.RandomResizedCrop(224, scale=(0.7, 1.0), ratio=(0.9, 1.1)),
                transforms.RandomErasing(p=0.1, scale=(0.02, 0.08)),
            ])
        else:
            self.aug_transform = None

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        file_path, label_str = self.samples[idx]
        # Context manager closes the underlying file deterministically;
        # convert() returns an independent in-memory image.
        with Image.open(file_path) as src:
            img = src.convert('RGB')
        if self.augment and self.aug_transform is not None:
            img = self._to_pil(self.aug_transform(self._to_tensor(img)))
        # Process with the DINOv3 processor (resize, normalize).
        inputs = self.processor(images=img, return_tensors="pt")
        pixel_values = inputs['pixel_values'].squeeze(0)
        return pixel_values, self.label_to_idx[label_str]
class DINOv3Classifier(nn.Module):
    """
    DINOv3 backbone + MLP classification head on the CLS token.

    The backbone is ``AutoModel.from_pretrained(model_id)`` and is fully
    frozen on construction. ``forward`` feeds token 0 of ``last_hidden_state``
    (the CLS token) to the head and ignores all other output tokens.
    Classification head: LayerNorm → Dropout → Linear(hidden_size, 128)
    → GELU → Dropout → Linear(128, num_classes).
    """

    # Attribute paths to the transformer block list, tried in order. The
    # first two match the original lookup. ("layer",) covers backbones that
    # hold their blocks directly on the model.
    # NOTE(review): confirm which layout the installed transformers uses for
    # DINOv3ViTModel.
    _BLOCK_PATHS = (("model", "layer"), ("encoder", "layer"), ("layer",))

    def __init__(self, model_id, num_classes, dropout=0.1):
        super().__init__()
        # Load pretrained backbone; embedding width comes from its config.
        self.backbone = AutoModel.from_pretrained(model_id)
        hidden_size = self.backbone.config.hidden_size
        self.head = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        )
        self.freeze_backbone()

    def freeze_backbone(self):
        """Freeze all backbone parameters (requires_grad = False)."""
        for param in self.backbone.parameters():
            param.requires_grad = False

    def _transformer_blocks(self):
        """Return the backbone's transformer block list, or raise AttributeError."""
        for path in self._BLOCK_PATHS:
            node = self.backbone
            for attr in path:
                node = getattr(node, attr, None)
                if node is None:
                    break
            if node is not None:
                return node
        raise AttributeError(
            "Backbone has no recognizable transformer blocks "
            "(expected .model.layer or .layer for DINOv3, or .encoder.layer for ViT/BERT)."
        )

    def unfreeze_last_n_blocks(self, n):
        """
        Freeze the whole backbone, then unfreeze its last *n* transformer
        blocks and its final norm (``norm``, else ``layernorm``, if present).

        Raises:
            AttributeError: if no transformer block list can be located.
        """
        self.freeze_backbone()
        layers = self._transformer_blocks()
        total_layers = len(layers)
        for i in range(max(0, total_layers - n), total_layers):
            for param in layers[i].parameters():
                param.requires_grad = True
        if hasattr(self.backbone, "norm"):
            for param in self.backbone.norm.parameters():
                param.requires_grad = True
        elif hasattr(self.backbone, "layernorm"):
            for param in self.backbone.layernorm.parameters():
                param.requires_grad = True

    def forward(self, pixel_values):
        outputs = self.backbone(pixel_values=pixel_values)
        # CLS token is the first token of the final hidden states.
        cls_embedding = outputs.last_hidden_state[:, 0, :]
        return self.head(cls_embedding)
# ====================================
# Training
# ====================================
def get_class_weights(samples, label_to_idx, device):
    """
    Inverse-frequency class weights for a weighted cross-entropy loss.

    weight[c] = N / (K * count[c]), with N the number of samples and K the
    number of classes; a class with no samples is counted as one.
    Returns a float tensor of length K on *device*, indexed by label_to_idx.
    """
    freq = Counter(lbl for _, lbl in samples)
    n_total = sum(freq.values())
    n_classes = len(label_to_idx)
    weights = torch.zeros(n_classes, device=device)
    for name, position in label_to_idx.items():
        weights[position] = n_total / (n_classes * max(freq.get(name, 1), 1))
    return weights
def get_weighted_sampler(samples, label_to_idx):
    """
    Build a WeightedRandomSampler that draws classes with roughly equal frequency.

    Each sample is weighted by N / count[its class]. Sampling is with
    replacement and yields len(samples) indices. label_to_idx is currently
    unused.
    """
    per_class = Counter(lbl for _, lbl in samples)
    n_total = sum(per_class.values())
    weights = [n_total / per_class[lbl] for _, lbl in samples]
    return WeightedRandomSampler(weights, len(samples), replacement=True)
def train_one_epoch(model, loader, criterion, optimizer, device, scaler=None):
    """
    Train *model* for one pass over *loader*.

    Args:
        model, criterion, optimizer: standard PyTorch training objects.
        loader: iterable of (images, labels) batches (len() is used for logging).
        device: device each batch is moved to.
        scaler: optional GradScaler. When given, the forward pass and loss run
            under CUDA float16 autocast and the scaled loss is back-propagated.

    Returns:
        (mean loss per sample, accuracy). Both are 0.0 for an empty loader.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    for batch_idx, (images, labels) in enumerate(loader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        if scaler is not None:
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                logits = model(images)
                loss = criterion(logits, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
        total_loss += loss.item() * images.size(0)
        _, predicted = logits.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)
        if (batch_idx + 1) % 50 == 0:
            print(f" batch {batch_idx+1}/{len(loader)} | "
                  f"loss: {loss.item():.4f} | acc: {correct/total:.3f}")
    # Empty loader: avoid ZeroDivisionError (evaluate() guards the same way).
    if total == 0:
        return 0.0, 0.0
    return total_loss / total, correct / total
def _stage_checkpoint_slug(stage_name: str) -> str:
"""Stable filename fragment (no spaces/colons) for checkpoint paths."""
s = re.sub(r"[^a-z0-9]+", "_", stage_name.lower())
return re.sub(r"_+", "_", s).strip("_")
@torch.no_grad()
def evaluate(model, loader, criterion, device, idx_to_label=None):
    """
    Evaluate *model* on *loader* without gradient tracking.

    Returns:
        metrics: dict with 'loss' (mean per sample), 'accuracy', 'macro_f1'
            and 'weighted_f1' (floats; F1 uses zero_division=0)
        preds: list of predicted class indices, in loader order
        labels: list of true class indices, in loader order
        probs: list of per-sample softmax probability lists
    idx_to_label is accepted but not used.
    """
    model.eval()
    loss_sum = 0.0
    n_seen = 0
    preds, targets, probabilities = [], [], []
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        out = model(batch_x)
        n = batch_x.size(0)
        loss_sum += criterion(out, batch_y).item() * n
        n_seen += n
        preds += out.argmax(dim=1).cpu().numpy().tolist()
        targets += batch_y.cpu().numpy().tolist()
        probabilities += torch.softmax(out, dim=1).cpu().numpy().tolist()
    metrics = {
        "loss": float(loss_sum / max(n_seen, 1)),
        "accuracy": float(accuracy_score(targets, preds)),
        "macro_f1": float(f1_score(targets, preds, average="macro", zero_division=0)),
        "weighted_f1": float(f1_score(targets, preds, average="weighted", zero_division=0)),
    }
    return metrics, preds, targets, probabilities
def evaluate_page_level(samples, probs, label_to_idx, idx_to_label):
    """
    Aggregate patch-level probabilities into page-level predictions.

    A page is identified by its class label together with
    ``get_page_name(filepath)``, as create_page_level groups pages per class.
    Two classes that contain identically named pages are therefore not merged
    into one page. A page's prediction is the argmax of the mean probability
    vector over its samples.

    Args:
        samples: list of (filepath, label_str) for the evaluated split.
        probs: list of per-sample probability vectors (same order as samples).
        label_to_idx: label string -> class index.
        idx_to_label: unused; kept for interface compatibility.

    Returns:
        dict with
            'metrics' (accuracy, macro_f1, weighted_f1, num_pages, num_samples),
            'pages' (sorted keys "label/page"),
            'page_true' and 'page_pred' (aligned with 'pages'),
            'page_avg_probs' (key -> mean probability list).

    Raises:
        ValueError: if samples and probs differ in length.
    """
    if len(samples) != len(probs):
        raise ValueError(
            f"samples/probs length mismatch: {len(samples)} != {len(probs)}"
        )
    page_probs = defaultdict(list)
    page_labels = {}
    for (filepath, label_str), p in zip(samples, probs):
        key = f"{label_str}/{get_page_name(filepath)}"
        page_labels[key] = label_to_idx[label_str]
        page_probs[key].append(np.asarray(p, dtype=np.float32))

    pages_sorted = sorted(page_probs)
    all_page_true = []
    all_page_pred = []
    page_avg_probs = {}
    for page in pages_sorted:
        avg_probs = np.mean(page_probs[page], axis=0)
        all_page_true.append(int(page_labels[page]))
        all_page_pred.append(int(np.argmax(avg_probs)))
        page_avg_probs[page] = avg_probs.tolist()

    metrics = {
        "accuracy": float(accuracy_score(all_page_true, all_page_pred)),
        "macro_f1": float(f1_score(all_page_true, all_page_pred, average="macro", zero_division=0)),
        "weighted_f1": float(f1_score(all_page_true, all_page_pred, average="weighted", zero_division=0)),
        "num_pages": int(len(pages_sorted)),
        "num_samples": int(len(samples)),
    }
    return {
        "metrics": metrics,
        "pages": pages_sorted,
        "page_true": all_page_true,
        "page_pred": all_page_pred,
        "page_avg_probs": page_avg_probs,
    }
#============================
# Progressive fine-tuning
#============================
def run_stage(model, train_loader, val_loader, criterion, device, stage_name, lr_backbone, lr_head, epochs, output_dir, idx_to_label, use_amp=True):
    """
    Run one stage of progressive fine-tuning and checkpoint its best epoch.

    Builds AdamW (weight_decay=0.01) over the currently trainable backbone
    parameters (lr_backbone) and all head parameters (lr_head), with cosine
    annealing over *epochs*. After each epoch the model is evaluated on
    *val_loader*. Whenever val macro-F1 improves, the state dict and metadata
    are saved to ``output_dir / f"best_{slug}.pt"``, where slug comes from
    _stage_checkpoint_slug. The best score starts at -inf, so the first epoch
    always writes a checkpoint even if its macro-F1 is 0.

    Args:
        output_dir: pathlib.Path directory for the checkpoint.
        use_amp: enable GradScaler / float16 autocast (CUDA devices only).
        idx_to_label: currently unused.

    Returns:
        (history, best_val_f1): per-epoch metric dicts and the best val
        macro-F1 (-inf if epochs == 0).
    """
    print(f"\n{'='*60}")
    print(f" {stage_name}")
    print(f"{'='*60}")
    # Separate learning rates for unfrozen backbone blocks and the head.
    param_groups = []
    backbone_params = [p for p in model.backbone.parameters() if p.requires_grad]
    head_params = list(model.head.parameters())
    if backbone_params:
        param_groups.append({'params': backbone_params, 'lr': lr_backbone})
        print(f" Backbone params (trainable): {sum(p.numel() for p in backbone_params):,}")
    param_groups.append({'params': head_params, 'lr': lr_head})
    print(f" Head params: {sum(p.numel() for p in head_params):,}")
    print(f" LR backbone: {lr_backbone}, LR head: {lr_head}")
    print(f" Epochs: {epochs}")
    optimizer = torch.optim.AdamW(param_groups, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    scaler = torch.amp.GradScaler() if use_amp and device.type == 'cuda' else None
    slug = _stage_checkpoint_slug(stage_name)
    checkpoint_path = output_dir / f'best_{slug}.pt'
    # Start below any achievable score so the first epoch is always saved and
    # a stale file from an earlier run is never silently left in place.
    best_val_f1 = float('-inf')
    best_epoch = 0
    history = []
    for epoch in range(epochs):
        print(f"\n Epoch {epoch+1}/{epochs}")
        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device, scaler
        )
        val_metrics, _, _, _ = evaluate(model, val_loader, criterion, device)
        scheduler.step()
        print(f" Train loss: {train_loss:.4f} | acc: {train_acc:.3f}")
        print(f" Val loss: {val_metrics['loss']:.4f} | "
              f"acc: {val_metrics['accuracy']:.3f} | "
              f"macro-F1: {val_metrics['macro_f1']:.3f}")
        history.append({
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_macro_f1': val_metrics['macro_f1'],
            'val_loss': val_metrics['loss'],
            'val_accuracy': val_metrics['accuracy'],
        })
        # Save best model (always use the slug path so load paths in main() match).
        if val_metrics['macro_f1'] > best_val_f1:
            best_val_f1 = val_metrics['macro_f1']
            best_epoch = epoch + 1
            torch.save({
                'model_state_dict': model.state_dict(),
                'epoch': epoch + 1,
                'val_macro_f1': best_val_f1,
                'val_accuracy': val_metrics['accuracy'],
                'stage_name': stage_name,
                'stage_slug': slug,
            }, checkpoint_path)
            print(f" * New best! Saved to {checkpoint_path}")
    print(f"\n {stage_name} complete. Best: epoch {best_epoch}, macro-F1: {best_val_f1:.3f}")
    return history, best_val_f1
# ==========================
# MAIN
# ==========================
def _torch_load(path):
try:
return torch.load(path, weights_only=False)
except TypeError:
return torch.load(path)
def _save_stage_history_json(output_dir: Path, stage_key: str, history: list) -> None:
"""Write one JSON file per training stage (loss / val metrics per epoch)."""
path = output_dir / f'history_{stage_key}.json'
with open(path, 'w') as f:
json.dump(history, f, indent=2, default=str)
print(f" Stage history saved: {path}")
def _plot_stage_history(output_dir: Path, stage_key: str, history: list, experiment: str) -> None:
"""Save train loss + val macro-F1 curves for a single stage."""
if not history:
return
try:
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
epochs = [h['epoch'] for h in history]
train_loss = [h['train_loss'] for h in history]
val_f1 = [h['val_macro_f1'] for h in history]
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].plot(epochs, train_loss, 'b-')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Train loss')
axes[0].set_title(f'{stage_key} — train loss')
axes[1].plot(epochs, val_f1, 'g-')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Val macro-F1')
axes[1].set_title(f'{stage_key} — validation')
fig.suptitle(f'{experiment} / {stage_key}')
plt.tight_layout()
out_path = output_dir / f'training_history_{stage_key}.png'
plt.savefig(out_path, dpi=150)
plt.close()
print(f" Stage plot saved: {out_path}")
except Exception as e:
print(f" (Skipping stage plot for {stage_key}: {e})")
def _save_stage_artifacts(output_dir: Path, stage_key: str, history: list, experiment: str) -> None:
    """Persist a finished stage: its history JSON first, then its curve plot."""
    _save_stage_history_json(output_dir=output_dir, stage_key=stage_key, history=history)
    _plot_stage_history(
        output_dir=output_dir,
        stage_key=stage_key,
        history=history,
        experiment=experiment,
    )
def _parse_args():
    """Parse the command-line options for one fine-tuning experiment."""
    parser = argparse.ArgumentParser(description="Fine-tune DINO ViT-S")
    parser.add_argument(
        "--data_dir", type=str, required=True,
        help="Path to processed data (e.g., ./Data/output/whole_page)",
    )
    parser.add_argument(
        "--experiment", type=str, required=True,
        choices=["whole_page", "patches_color", "patches_clahe"],
        help="Which experiment variant",
    )
    parser.add_argument("--output_dir", type=str, default="./results",
                        help="Where to save checkpoints and results")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Batch size (reduce if OOM)")
    parser.add_argument("--epochs_a", type=int, default=20,
                        help="Epochs for Stage A (head only)")
    parser.add_argument("--epochs_b", type=int, default=10,
                        help="Epochs for Stage B (last 2 blocks)")
    parser.add_argument("--epochs_c", type=int, default=10,
                        help="Epochs for Stage C (last 4 blocks)")
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--no_amp", action="store_true",
                        help="Disable mixed precision")
    parser.add_argument("--skip_stage_c", action="store_true",
                        help="Skip Stage C (last 4 blocks)")
    parser.add_argument(
        "--exclude_manifest",
        type=str,
        default="./benchmark_page_ids.json",
        help="Optional class->page_ids JSON; excluded pages are skipped during split creation",
    )
    return parser.parse_args()


def _load_stage_weights(model, ckpt_path):
    """Load a finished stage's best weights into *model*, or keep current weights if no checkpoint exists."""
    if ckpt_path.is_file():
        model.load_state_dict(_torch_load(ckpt_path)['model_state_dict'])
    else:
        print(f" No checkpoint at {ckpt_path}; continuing with current weights")


def _select_best_checkpoint(checkpoint_paths):
    """
    Return (path, val_macro_f1) of the best existing checkpoint in *checkpoint_paths*.

    Ties keep the earlier path. Returns (None, 0.0) if none of the paths exist.
    """
    best_path, best_f1 = None, float('-inf')
    for ckpt_path in checkpoint_paths:
        if not ckpt_path.is_file():
            continue
        f1 = _torch_load(ckpt_path).get('val_macro_f1', 0)
        if f1 > best_f1:
            best_path, best_f1 = ckpt_path, f1
    if best_path is None:
        return None, 0.0
    return best_path, best_f1


def _save_confusion_outputs(output_dir, experiment, target_names, cm, page_cm, test_metrics, page_metrics):
    """Write patch- and page-level confusion matrices as CSV and, best-effort, as PNG heatmaps."""
    import pandas as pd

    pd.DataFrame(cm, index=target_names, columns=target_names).to_csv(
        output_dir / 'confusion_matrix.csv')
    pd.DataFrame(page_cm, index=target_names, columns=target_names).to_csv(
        output_dir / 'page_confusion_matrix.csv')
    try:
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
        import seaborn as sns
    except ImportError:
        print(" (matplotlib/seaborn not available, skipping plot)")
        return
    heatmaps = (
        (cm, 'Blues', 'confusion_matrix.png',
         f'Confusion Matrix — {experiment} (macro-F1: {test_metrics["macro_f1"]:.3f})',
         'Confusion matrix saved'),
        (page_cm, 'Greens', 'page_confusion_matrix.png',
         f'Page Confusion Matrix — {experiment} (macro-F1: {page_metrics["macro_f1"]:.3f})',
         'Page confusion matrix saved'),
    )
    for matrix, cmap, filename, title, message in heatmaps:
        try:
            fig, ax = plt.subplots(figsize=(14, 12))
            sns.heatmap(matrix, annot=True, fmt='d', cmap=cmap,
                        xticklabels=target_names, yticklabels=target_names, ax=ax)
            ax.set_xlabel('Predicted label')
            ax.set_ylabel('True label')
            ax.set_title(title)
            plt.tight_layout()
            plt.savefig(output_dir / filename, dpi=150)
            print(f" {message}: {output_dir / filename}")
        except Exception as e:
            print(f" (Skipping {filename}: {e})")
        finally:
            plt.close('all')


def _plot_full_history(output_dir, all_history, experiment):
    """
    Save train loss and val macro-F1 across all stages as ``training_history.png``.

    Later stages' epochs are offset by the length of earlier stages, so the
    x-axis is continuous. Plotting is best-effort; failures are printed.
    """
    try:
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt

        all_epochs, all_train_loss, all_val_f1 = [], [], []
        offset = 0
        for stage_history in all_history.values():
            for entry in stage_history:
                all_epochs.append(entry['epoch'] + offset)
                all_train_loss.append(entry['train_loss'])
                all_val_f1.append(entry['val_macro_f1'])
            offset += len(stage_history)

        fig, axes = plt.subplots(1, 2, figsize=(14, 5))
        axes[0].plot(all_epochs, all_train_loss, 'b-')
        axes[0].set_xlabel('Epoch')
        axes[0].set_ylabel('Train Loss')
        axes[0].set_title('Training Loss')
        axes[1].plot(all_epochs, all_val_f1, 'g-')
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('Macro F1')
        axes[1].set_title('Validation Macro-F1')
        plt.suptitle(f'{experiment} — Progressive Fine-Tuning')
        plt.tight_layout()
        out_path = output_dir / 'training_history.png'
        plt.savefig(out_path, dpi=150)
        plt.close(fig)
        print(f" Training history saved: {out_path}")
    except Exception as e:
        print(f" (Skipping training history plot: {e})")


def main():
    """
    Command-line entry point for one experiment.

    1. Create the page-level split.
    2. Run progressive fine-tuning: Stage A (head only), Stage B (last 2
       blocks), then optionally Stage C (last 4 blocks).
    3. Evaluate the best-validation checkpoint of THIS run on the test split,
       at patch level and page level.
    4. Write all artifacts under ``--output_dir/<experiment>/``.
    """
    args = _parse_args()
    stage_a_name = "Stage A: Head only"
    stage_b_name = "Stage B: Last 2 blocks"
    stage_c_name = "Stage C: Last 4 blocks"
    set_seed(SEED)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    output_dir = Path(args.output_dir) / args.experiment
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n{'='*60}")
    print(f" DINOv3 ViT-S Fine-Tuning")
    print(f" {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"{'='*60}")
    print(f" Experiment: {args.experiment}")
    print(f" Data dir: {args.data_dir}")
    print(f" Device: {device}")
    print(f" Batch size: {args.batch_size}")
    print(f" AMP: {not args.no_amp}")
    print(f" Exclusions: {args.exclude_manifest}")

    # Page level split
    print(f"\n Creating page level split")
    excluded_pages_by_label = load_exclusion_manifest(args.exclude_manifest)
    excluded_label_count = len(excluded_pages_by_label)
    excluded_id_count = sum(len(v) for v in excluded_pages_by_label.values())
    if excluded_label_count:
        print(f" Loaded exclusions: {excluded_label_count} labels, {excluded_id_count} page IDs")
    splits, label_to_idx, idx_to_label, skipped_by_label = create_page_level(
        args.data_dir,
        excluded_pages_by_label=excluded_pages_by_label,
    )
    num_classes = len(label_to_idx)
    print(f" Classes: {num_classes}")
    print(f" Train: {len(splits['train'])} | Val: {len(splits['val'])} | Test: {len(splits['test'])}")
    if skipped_by_label:
        print("\n Skipped excluded files by class:")
        for label, count in sorted(skipped_by_label.items()):
            print(f" {label:<20s} {count:>6d}")
    # Per-class split counts
    for split_name in ['train', 'val', 'test']:
        counts = Counter(label for _, label in splits[split_name])
        print(f"\n {split_name}:")
        for label in sorted(counts):
            print(f" {label:<20s} {counts[label]:>6d}")

    # Save label mapping, counts and the actual file lists for reproducibility.
    with open(output_dir / 'splits.json', 'w') as f:
        json.dump({
            'label_to_idx': label_to_idx,
            'idx_to_label': {str(k): v for k, v in idx_to_label.items()},
            'split_counts': {
                name: dict(Counter(lbl for _, lbl in samples))
                for name, samples in splits.items()
            },
            'exclude_manifest': str(args.exclude_manifest),
            'excluded_label_count': excluded_label_count,
            'excluded_page_id_count': excluded_id_count,
            'skipped_excluded_files_by_class': dict(skipped_by_label),
            'splits': {
                name: [[fp, lbl] for fp, lbl in samples]
                for name, samples in splits.items()
            },
        }, f, indent=2)

    print(f"Loading DINOv3 processor: {DINOV3_MODEL_ID}")
    processor = AutoImageProcessor.from_pretrained(DINOV3_MODEL_ID)
    pin_memory = device.type == 'cuda'
    train_loader = DataLoader(
        ScriptDataset(splits['train'], label_to_idx, processor, augment=True),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        ScriptDataset(splits['val'], label_to_idx, processor, augment=False),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=pin_memory,
    )
    test_loader = DataLoader(
        ScriptDataset(splits['test'], label_to_idx, processor, augment=False),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=pin_memory,
    )

    print(f"\n Building DINOv3 classifier ({num_classes} classes)...")
    model = DINOv3Classifier(DINOV3_MODEL_ID, num_classes, dropout=0.1).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f" Total params: {total_params:,}")
    print(f" Trainable params: {trainable_params:,} (head only)")

    class_weights = get_class_weights(splits['train'], label_to_idx, device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    use_amp = not args.no_amp and device.type == 'cuda'

    # (artifact key, stage name, blocks to unfreeze, backbone LR, head LR, epochs)
    stages = [
        ('stage_a', stage_a_name, 0, 0, 1e-3, args.epochs_a),
        ('stage_b', stage_b_name, 2, 1e-5, 1e-3, args.epochs_b),
    ]
    if not args.skip_stage_c:
        stages.append(('stage_c', stage_c_name, 4, 5e-6, 5e-4, args.epochs_c))

    all_history = {}
    stage_checkpoints = []
    for stage_key, stage_name, n_blocks, lr_backbone, lr_head, epochs in stages:
        if stage_checkpoints:
            # Continue from the best weights of the previous stage.
            _load_stage_weights(model, stage_checkpoints[-1])
            model.unfreeze_last_n_blocks(n_blocks)
        else:
            model.freeze_backbone()
        ckpt_path = output_dir / f"best_{_stage_checkpoint_slug(stage_name)}.pt"
        # Remove a checkpoint left by a previous run, so the selection below
        # only ever sees weights trained in this run.
        if ckpt_path.exists():
            ckpt_path.unlink()
        history, _ = run_stage(
            model, train_loader, val_loader, criterion, device,
            stage_name=stage_name,
            lr_backbone=lr_backbone, lr_head=lr_head,
            epochs=epochs, output_dir=output_dir,
            idx_to_label=idx_to_label, use_amp=use_amp,
        )
        all_history[stage_key] = history
        _save_stage_artifacts(output_dir, stage_key, history, args.experiment)
        stage_checkpoints.append(ckpt_path)

    # Final evaluation on test set
    print(f"\n{'='*60}")
    print(f" FINAL TEST EVALUATION")
    print(f"{'='*60}")
    best_ckpt, best_f1 = _select_best_checkpoint(stage_checkpoints)
    if best_ckpt is None:
        raise RuntimeError("No checkpoint found under output_dir; cannot run test evaluation.")
    print(f" Loading best checkpoint: {best_ckpt} (val F1: {best_f1:.3f})")
    model.load_state_dict(_torch_load(best_ckpt)['model_state_dict'])
    test_metrics, test_preds, test_labels, test_probs = evaluate(
        model, test_loader, criterion, device, idx_to_label
    )
    page_eval = evaluate_page_level(
        splits['test'],
        test_probs,
        label_to_idx=label_to_idx,
        idx_to_label=idx_to_label,
    )
    page_metrics = page_eval["metrics"]

    # Canonical weights for this experiment (the loaded best-val checkpoint).
    final_model_path = output_dir / 'final_model.pt'
    torch.save(
        {
            'model_state_dict': model.state_dict(),
            'experiment': args.experiment,
            'model_id': DINOV3_MODEL_ID,
            'num_classes': num_classes,
            'label_to_idx': label_to_idx,
            'source_val_checkpoint': str(best_ckpt),
            'val_macro_f1_at_selection': float(best_f1),
            'test_metrics': test_metrics,
            'page_test_metrics': page_metrics,
        },
        final_model_path,
    )
    print(f"\n Final model (for deployment / comparison) saved: {final_model_path}")
    print(f"\n Test accuracy: {test_metrics['accuracy']:.3f}")
    print(f" Test macro-F1: {test_metrics['macro_f1']:.3f}")
    print(f" Test weighted-F1: {test_metrics['weighted_f1']:.3f}")
    print(f" Page accuracy: {page_metrics['accuracy']:.3f} "
          f"| Page macro-F1: {page_metrics['macro_f1']:.3f} "
          f"| Pages: {page_metrics['num_pages']}")

    # Explicit labels keep one row/column per class in every report and
    # matrix, even if a class is absent from the test split.
    target_names = [idx_to_label[i] for i in range(num_classes)]
    class_indices = list(range(num_classes))
    report = classification_report(
        test_labels, test_preds, labels=class_indices,
        target_names=target_names, zero_division=0,
    )
    print(f"\n{report}")
    page_report = classification_report(
        page_eval["page_true"], page_eval["page_pred"], labels=class_indices,
        target_names=target_names, zero_division=0,
    )
    cm = confusion_matrix(test_labels, test_preds, labels=class_indices)
    page_cm = confusion_matrix(page_eval["page_true"], page_eval["page_pred"], labels=class_indices)

    results = {
        'experiment': args.experiment,
        'model': DINOV3_MODEL_ID,
        'num_classes': num_classes,
        'best_val_checkpoint': str(best_ckpt),
        'val_macro_f1_at_selection': float(best_f1),
        'final_model_path': str(final_model_path),
        'test_metrics': test_metrics,
        'page_test_metrics': page_metrics,
        'history': all_history,
        'confusion_matrix': cm.tolist(),
        'page_confusion_matrix': page_cm.tolist(),
        'label_to_idx': label_to_idx,
        'classification_report': report,
        'page_classification_report': page_report,
    }
    with open(output_dir / 'results.json', 'w') as f:
        json.dump(results, f, indent=2, default=str)

    _save_confusion_outputs(output_dir, args.experiment, target_names, cm, page_cm,
                            test_metrics, page_metrics)
    _plot_full_history(output_dir, all_history, args.experiment)

    print(f"\n{'='*60}")
    print(f" All results saved to: {output_dir}")
    print(f"{'='*60}\n")
if __name__ == "__main__":
    # Runs only when executed as a script, not when imported as a module.
    main()