Image Classification
Transformers
Tibetan
dinov3
tibetan
script-classification
paleography
fine-tuned
document-analysis
Eval Results (legacy)
Instructions to use openpecha/tibetan-script-classifier with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use openpecha/tibetan-script-classifier with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("image-classification", model="openpecha/tibetan-script-classifier")
pipe("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/hub/parrots.png")

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("openpecha/tibetan-script-classifier", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| """ | |
| DINOv3 fine-tuning for script classification | |
| ============================================== | |
| Progressive finetuning with page-level train/val/test split | |
| Runs on three preprocessed variants: | |
| - whole_page/ | |
| - patches_color/ | |
| - patches_clahe/ | |
| Usage: | |
| #Exp1: whole page | |
| python finetune_dinov3.py --data_dir ./Data/output/whole_page --experiment whole_page | |
| #Exp2: patches_color | |
| python finetune_dinov3.py --data_dir ./Data/output/patches_color --experiment patches_color | |
| #Exp3: CLAHE_patches | |
| python finetune_dinov3.py --data_dir ./Data/output/patches_clahe --experiment patches_clahe | |
| Outputs (under --output_dir/<experiment>/): | |
| best_<stage_slug>.pt — best val macro-F1 per stage | |
| history_stage_{a,b,c}.json — per-epoch metrics per stage | |
| training_history_stage_{a,b,c}.png — curves per stage | |
| final_model.pt — weights chosen by best val across stages + test_metrics metadata | |
| results.json, confusion_matrix.*, training_history.png (full run) | |
| Requirements: | |
| pip install torch torchvision transformers scikit-learn matplotlib seaborn | |
| # DINOv3 requires transformers >= 4.56.0 | |
| # If not available: pip install --upgrade git+https://github.com/huggingface/transformers.git | |
| """ | |
| import os | |
| import re | |
| import json | |
| import argparse | |
| import random | |
| from pathlib import Path | |
| from collections import Counter, defaultdict | |
| from datetime import datetime | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from torch.cuda.amp import GradScaler | |
| from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler | |
| from torchvision import transforms | |
| from PIL import Image | |
| from sklearn.metrics import (classification_report, confusion_matrix, f1_score, accuracy_score ) | |
| try: | |
| from transformers import AutoImageProcessor, AutoModel | |
| except ImportError: | |
| raise ImportError("transformers >= 4.56.0 required for DINOv3.\n" | |
| "Install: pip install --upgrade git+https://github.com/huggingface/transformers.git" | |
| ) | |
# =====================
# CONFIG
# =====================
# Hugging Face Hub id of the pretrained backbone; also used to load the
# matching image processor in main().
DINOV3_MODEL_ID = "facebook/dinov3-vits16-pretrain-lvd1689m"
# NOTE(review): not referenced elsewhere in the visible file — the classifier
# reads the hidden size from the backbone config instead.
EMBEDDING_DIM = 384
# File suffixes (compared lower-cased) that create_page_level() accepts as images.
VALID_EXT = {'.jpg', '.jpeg', '.png', '.tif', '.tiff', '.bmp', '.webp'}
# Default RNG seed used by main() and as the default of create_page_level().
SEED = 42
def set_seed(seed):
    """
    Seed the Python, NumPy and PyTorch random number generators with `seed`.
    When CUDA is available, all CUDA generators are seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
| # ====================== | |
| # Page-level splitting | |
| # ====================== | |
def get_page_name(filepath):
    """
    Return the page identifier a file belongs to.

    The identifier is the filename stem with a trailing '_p<digits>' patch
    suffix removed, so every patch cut from one page maps to the same name:
        'manuscript001_p3.png' → 'manuscript001'
        'manuscript001.png'    → 'manuscript001'
    """
    base = Path(filepath).stem
    # Only a suffix at the very end of the stem is stripped.
    return re.sub(r'_p\d+$', '', base)
def normalize_label_key(label: str) -> str:
    """
    Turn a class name into a lookup key: lower-case it, replace every run of
    characters outside [a-z0-9] with a single underscore, and drop leading and
    trailing underscores.
    """
    lowered = label.lower()
    underscored = re.sub(r'[^a-z0-9]+', '_', lowered)
    return underscored.strip('_')
def load_exclusion_manifest(manifest_path: str):
    """
    Read a JSON object mapping class names to lists of page ids to exclude.

    Returns:
        dict: normalized class label -> set of non-empty, stripped page-id
        strings. Empty when no path is given or the file does not exist.
        Entries whose value is not a list are ignored.
    Raises:
        ValueError: if the JSON top level is not an object.
    """
    if not manifest_path:
        return {}
    path = Path(manifest_path)
    if not path.is_file():
        print(f" Exclusion manifest not found, skipping exclusions: {path}")
        return {}
    with open(path, "r", encoding="utf-8") as f:
        raw = json.load(f)
    if not isinstance(raw, dict):
        raise ValueError(f"Exclusion manifest must be a JSON object: {path}")
    # Later entries that normalize to the same key replace earlier ones.
    return {
        normalize_label_key(str(label)): {str(x).strip() for x in ids if str(x).strip()}
        for label, ids in raw.items()
        if isinstance(ids, list)
    }
def create_page_level(data_dir, val_ratio=0.15, test_ratio=0.15, seed=SEED, excluded_pages_by_label=None):
    """
    Split at the PAGE level, not the image/patch level.
    All patches from one page go into the same split.

    Args:
        data_dir: directory containing one sub-directory per class; hidden
            directories (name starting with '.') are ignored.
        val_ratio, test_ratio: fraction of each class's pages assigned to
            val / test; each gets at least one page.
        seed: RNG seed applied via set_seed() before shuffling.
        excluded_pages_by_label: optional dict of normalized class label ->
            set of page names to leave out entirely.
    Returns:
        splits: dict with 'train', 'val', 'test' keys,
            each value is a list of (filepath, label) tuples
        label_to_idx: dict mapping label strings to integers
        idx_to_label: inverse of label_to_idx
        skipped_by_label: dict of label -> number of files skipped by exclusion
    """
    set_seed(seed)
    data_dir = Path(data_dir)
    # label -> page name -> list of file paths
    class_pages = defaultdict(lambda: defaultdict(list))
    skipped_by_label = Counter()
    for cls_dir in sorted(data_dir.iterdir()):
        if not cls_dir.is_dir() or cls_dir.name.startswith('.'):
            continue
        label = cls_dir.name
        excluded_pages = set()
        if excluded_pages_by_label:
            excluded_pages = excluded_pages_by_label.get(normalize_label_key(label), set())
        for img_path in sorted(cls_dir.iterdir()):
            if img_path.suffix.lower() not in VALID_EXT:
                continue
            page = get_page_name(str(img_path))
            if page in excluded_pages:
                skipped_by_label[label] += 1
                continue
            class_pages[label][page].append(str(img_path))
    # Create label mapping
    labels = sorted(class_pages.keys())
    label_to_idx = {label: idx for idx, label in enumerate(labels)}
    idx_to_label = {idx: label for label, idx in label_to_idx.items()}
    # Split pages per class (stratified)
    splits = {'train': [], 'val': [], 'test': []}
    for label in labels:
        pages = list(class_pages[label].keys())
        random.shuffle(pages)
        n_pages = len(pages)
        n_test = max(1, int(n_pages * test_ratio))
        n_val = max(1, int(n_pages * val_ratio))
        pages_by_split = {
            'test': pages[:n_test],
            'val': pages[n_test:n_test + n_val],
            'train': pages[n_test + n_val:],
        }
        # Test and val each claim at least one page, so a class with very few
        # pages can end up with nothing to train on. The split is left as-is,
        # but the situation is reported instead of passing silently.
        if not pages_by_split['train']:
            print(f" Warning: class '{label}' has only {n_pages} page(s); "
                  f"no pages left for training")
        for split_name in ('train', 'val', 'test'):
            splits[split_name].extend(
                (fpath, label)
                for page in pages_by_split[split_name]
                for fpath in class_pages[label][page]
            )
    return splits, label_to_idx, idx_to_label, dict(skipped_by_label)
class ScriptDataset(Dataset):
    """
    Dataset of (filepath, label) samples that yields (pixel_values, label_idx).

    Each image is loaded as RGB, optionally augmented, and then passed through
    the supplied image processor.

    Args:
        samples: list of (filepath, label_str) tuples.
        label_to_idx: dict mapping label strings to integer class indices.
        processor: callable image processor; it is called as
            processor(images=img, return_tensors="pt") and must return a
            mapping with a 'pixel_values' entry.
        augment: when True, apply the document-aware augmentation pipeline.
    """

    def __init__(self, samples, label_to_idx, processor, augment = False):
        self.samples = samples
        self.label_to_idx = label_to_idx
        self.processor = processor
        self.augment = augment
        # Document-aware augmentation
        if augment:
            self.aug_transform = transforms.Compose([
                # The pipeline runs on the float tensor produced by ToTensor,
                # whose values lie in [0, 1], so "white" is 1.0. A fill of 255
                # would put far out-of-range values in the corners and distort
                # the image mean used by the contrast jitter.
                transforms.RandomRotation(degrees=5, fill=1.0),
                transforms.ColorJitter(brightness=0.2, contrast=0.2),
                transforms.RandomResizedCrop(224, scale=(0.7, 1.0), ratio=(0.9, 1.1)),
                transforms.RandomErasing(p=0.1, scale=(0.02, 0.08)),
            ])
            # Build the converters once instead of once per sample.
            self._to_tensor = transforms.ToTensor()
            self._to_pil = transforms.ToPILImage()
        else:
            self.aug_transform = None

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        file_path, label_str = self.samples[idx]
        # Load the image; the context manager releases the file handle as soon
        # as the RGB copy has been made.
        with Image.open(file_path) as src:
            img = src.convert('RGB')
        if self.aug_transform is not None and self.augment:
            img = self._to_tensor(img)
            img = self.aug_transform(img)
            img = self._to_pil(img)
        # Process with DINOv3 processor (resize, normalize)
        inputs = self.processor(images=img, return_tensors="pt")
        pixel_values = inputs['pixel_values'].squeeze(0)
        label_idx = self.label_to_idx[label_str]
        return pixel_values, label_idx
class DINOv3Classifier(nn.Module):
    """
    Pretrained backbone (loaded with AutoModel) + MLP classification head.

    Only the first token of the backbone's last_hidden_state (the CLS token)
    is fed to the head; the remaining tokens are not used.
    Classification head: hidden_size → 128 → num_classes
    (hidden_size is read from the backbone config).

    Args:
        model_id: identifier passed to AutoModel.from_pretrained.
        num_classes: number of output classes.
        dropout: dropout probability used twice inside the head.
    The backbone starts fully frozen; only the head is trainable.
    """

    def __init__(self, model_id, num_classes, dropout=0.1):
        super().__init__()
        # Load pretrained backbone
        self.backbone = AutoModel.from_pretrained(model_id)
        # Get embedding dim
        hidden_size = self.backbone.config.hidden_size
        # Classification head
        self.head = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        )
        self.freeze_backbone()

    def freeze_backbone(self):
        """Freeze all the backbone parameters."""
        for params in self.backbone.parameters():
            params.requires_grad = False

    def _transformer_blocks(self):
        """Return the backbone's sequence of transformer blocks.

        Raises:
            AttributeError: if no known attribute layout is found.
        """
        backbone = self.backbone
        # Nested layouts first: backbone.model.layer, then backbone.encoder.layer.
        for owner_name in ("model", "encoder"):
            owner = getattr(backbone, owner_name, None)
            if owner is not None and hasattr(owner, "layer"):
                return owner.layer
        # NOTE(review): the attribute holding the block list appears to differ
        # between transformers releases; a block list stored directly on the
        # backbone is accepted as a fallback — confirm against the installed version.
        if hasattr(backbone, "layer"):
            return backbone.layer
        raise AttributeError(
            "Backbone has no recognizable transformer blocks "
            "(expected .model.layer or .layer for DINOv3, or .encoder.layer for ViT/BERT)."
        )

    def unfreeze_last_n_blocks(self, n):
        """
        Freeze the whole backbone, then make the last `n` transformer blocks
        and the backbone's final norm (if present) trainable.

        If `n` exceeds the number of blocks, all blocks are unfrozen.
        """
        # First freeze everything
        self.freeze_backbone()
        layers = self._transformer_blocks()
        total_layers = len(layers)
        for i in range(max(0, total_layers - n), total_layers):
            for param in layers[i].parameters():
                param.requires_grad = True
        # The final norm is stored as .norm or .layernorm depending on the backbone.
        if hasattr(self.backbone, "norm"):
            for param in self.backbone.norm.parameters():
                param.requires_grad = True
        elif hasattr(self.backbone, "layernorm"):
            for param in self.backbone.layernorm.parameters():
                param.requires_grad = True

    def forward(self, pixel_values):
        """Return class logits for a batch of processed images."""
        # Get backbone outputs
        outputs = self.backbone(pixel_values=pixel_values)
        # Use CLS token (first token)
        cls_embedding = outputs.last_hidden_state[:, 0, :]
        # Classify
        logits = self.head(cls_embedding)
        return logits
| # ==================================== | |
| # Training | |
| # ==================================== | |
def get_class_weights(samples, label_to_idx, device):
    """
    Build a per-class weight tensor, indexed by class index, using
    total_samples / (num_classes * class_count).

    A class that does not occur in `samples` is treated as having one sample.
    """
    frequency = Counter(lbl for _, lbl in samples)
    n_samples = sum(frequency.values())
    n_classes = len(label_to_idx)
    weights = torch.zeros(n_classes, device=device)
    for name, position in label_to_idx.items():
        occurrences = max(frequency.get(name, 1), 1)
        weights[position] = n_samples / (n_classes * occurrences)
    return weights
def get_weighted_sampler(samples, label_to_idx):
    """
    Build a WeightedRandomSampler (with replacement, len(samples) draws) in
    which each sample is weighted by total_samples / count_of_its_class.

    `label_to_idx` is accepted but not used.
    """
    sample_labels = [lbl for _, lbl in samples]
    frequency = Counter(sample_labels)
    n_samples = sum(frequency.values())
    weight_of = {lbl: n_samples / seen for lbl, seen in frequency.items()}
    per_sample = [weight_of[lbl] for lbl in sample_labels]
    return WeightedRandomSampler(per_sample, len(samples), replacement=True)
def train_one_epoch(model, loader, criterion, optimizer, device, scaler=None):
    """
    Train for one epoch with optional mixed precision.

    Args:
        model: module mapping an image batch to logits.
        loader: iterable of (images, labels) batches.
        criterion: loss taking (logits, labels).
        optimizer: optimizer stepped once per batch.
        device: device the batches are moved to.
        scaler: optional gradient scaler; when given, the forward pass runs
            under CUDA float16 autocast and the scaler drives the step.
    Returns:
        (mean loss per sample, accuracy) over the epoch; (0.0, 0.0) when the
        loader yields no samples.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    for batch_idx, (images, labels) in enumerate(loader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        if scaler:
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                logits = model(images)
                loss = criterion(logits, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
        # Weight the batch loss by batch size so the epoch mean is per sample.
        total_loss += loss.item() * images.size(0)
        _, predicted = logits.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)
        if (batch_idx + 1) % 50 == 0:
            print(f" batch {batch_idx+1}/{len(loader)} | "
                  f"loss: {loss.item():.4f} | acc: {correct/total:.3f}")
    # An empty loader would otherwise divide by zero; evaluate() guards the
    # same case.
    if total == 0:
        return 0.0, 0.0
    return total_loss / total, correct / total
def _stage_checkpoint_slug(stage_name: str) -> str:
    """
    Convert a stage title into a filename-safe fragment: lower-case, with
    every run of characters outside [a-z0-9] replaced by one underscore and
    no leading or trailing underscore.
    """
    lowered = stage_name.lower()
    underscored = re.sub(r"[^a-z0-9]+", "_", lowered)
    squeezed = re.sub(r"_+", "_", underscored)
    return squeezed.strip("_")
def evaluate(model, loader, criterion, device, idx_to_label=None):
    """
    Return validation/test metrics and per-sample preds, labels, probs.

    Args:
        model: module mapping an image batch to logits.
        loader: iterable of (images, labels) batches.
        criterion: loss taking (logits, labels).
        device: device the batches are moved to.
        idx_to_label: accepted but not used by this function.
    Returns:
        (metrics, preds, labels, probs): metrics holds 'loss', 'accuracy',
        'macro_f1' and 'weighted_f1' as floats; the other three are plain
        Python lists in loader order.
    """
    model.eval()
    total_loss = 0.0
    total = 0
    all_preds = []
    all_labels = []
    all_probs = []
    # Evaluation needs no gradients. Without no_grad the logits depend on
    # trainable parameters, so the softmax output tracks grad and cannot be
    # converted with .numpy(); an autograd graph would also be built per batch.
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            logits = model(images)
            loss = criterion(logits, labels)
            bs = images.size(0)
            total_loss += loss.item() * bs
            total += bs
            probs = torch.softmax(logits, dim=1)
            pred = logits.argmax(dim=1)
            all_preds.extend(pred.cpu().numpy().tolist())
            all_labels.extend(labels.cpu().numpy().tolist())
            all_probs.extend(probs.cpu().numpy().tolist())
    avg_loss = total_loss / max(total, 1)
    acc = accuracy_score(all_labels, all_preds)
    macro_f1 = f1_score(all_labels, all_preds, average="macro", zero_division=0)
    weighted_f1 = f1_score(all_labels, all_preds, average="weighted", zero_division=0)
    metrics = {
        "loss": float(avg_loss),
        "accuracy": float(acc),
        "macro_f1": float(macro_f1),
        "weighted_f1": float(weighted_f1),
    }
    return metrics, all_preds, all_labels, all_probs
def evaluate_page_level(samples, probs, label_to_idx, idx_to_label):
    """
    Aggregate patch-level probabilities to page-level predictions.

    The probability vectors of all samples belonging to one page are averaged
    and the arg-max of the average is the page prediction.

    Args:
        samples: list of (filepath, label_str) for the evaluated split.
        probs: list of per-sample probability vectors (same order as samples).
        label_to_idx: dict mapping label strings to integer class indices.
        idx_to_label: accepted but not used by this function.
    Returns:
        dict with 'metrics' (accuracy, macro_f1, weighted_f1, num_pages,
        num_samples), 'pages' (sorted page keys), 'page_true', 'page_pred'
        (class indices aligned with 'pages') and 'page_avg_probs'
        (page key -> averaged probability list).
        A page key is the page name; when the same page name occurs under
        more than one label it becomes '<label>/<page name>'.
    Raises:
        ValueError: if samples and probs differ in length.
    """
    if len(samples) != len(probs):
        raise ValueError(
            f"samples/probs length mismatch: {len(samples)} != {len(probs)}"
        )
    # Group by (label, page name). create_page_level() tracks pages per class,
    # so the same file stem may appear under two classes; grouping on the stem
    # alone would merge those pages and overwrite one true label.
    grouped = defaultdict(list)
    for (filepath, label_str), p in zip(samples, probs):
        grouped[(label_str, get_page_name(filepath))].append(
            np.asarray(p, dtype=np.float32)
        )
    # Keep the plain page name as key whenever it is unambiguous, so the
    # output is unchanged for data without clashing names.
    name_usage = Counter(page for _, page in grouped)
    key_to_group = {}
    for label_str, page in grouped:
        key = page if name_usage[page] == 1 else f"{label_str}/{page}"
        key_to_group[key] = (label_str, page)
    pages_sorted = sorted(key_to_group)
    all_page_true = []
    all_page_pred = []
    page_avg_probs = {}
    for key in pages_sorted:
        label_str, page = key_to_group[key]
        avg_probs = np.mean(grouped[(label_str, page)], axis=0)
        all_page_true.append(int(label_to_idx[label_str]))
        all_page_pred.append(int(np.argmax(avg_probs)))
        page_avg_probs[key] = avg_probs.tolist()
    acc = accuracy_score(all_page_true, all_page_pred)
    macro_f1 = f1_score(all_page_true, all_page_pred, average="macro", zero_division=0)
    weighted_f1 = f1_score(all_page_true, all_page_pred, average="weighted", zero_division=0)
    metrics = {
        "accuracy": float(acc),
        "macro_f1": float(macro_f1),
        "weighted_f1": float(weighted_f1),
        "num_pages": int(len(pages_sorted)),
        "num_samples": int(len(samples)),
    }
    return {
        "metrics": metrics,
        "pages": pages_sorted,
        "page_true": all_page_true,
        "page_pred": all_page_pred,
        "page_avg_probs": page_avg_probs,
    }
| #============================ | |
| # Progressive fine-tuning | |
| #============================ | |
def run_stage(model, train_loader, val_loader, criterion, device, stage_name, lr_backbone, lr_head, epochs, output_dir, idx_to_label, use_amp=True):
    """
    Run one stage of progressive fine-tuning.

    Trains for `epochs` epochs with AdamW (separate learning rates for the
    trainable backbone parameters and the head) under a cosine schedule,
    validates after every epoch and checkpoints the epoch with the best
    validation macro-F1 to output_dir / 'best_<stage slug>.pt'.

    Args:
        model: classifier exposing `.backbone` and `.head`.
        train_loader, val_loader: batch iterables for training / validation.
        criterion: loss taking (logits, labels).
        device: torch.device used for the batches and for the AMP decision.
        stage_name: human-readable title; also determines the checkpoint name.
        lr_backbone, lr_head: learning rates of the two parameter groups.
        epochs: number of epochs to run.
        output_dir: Path-like directory receiving the checkpoint.
        idx_to_label: accepted but not used by this function.
        use_amp: enable mixed precision (only effective on CUDA).
    Returns:
        (history, best_val_f1): per-epoch metric dicts and the best validation
        macro-F1 (0.0 when no epoch ran).
    """
    print(f"\n{'='*60}")
    print(f" {stage_name}")
    print(f"{'='*60}")
    # Set up optimizer with different LRs for backbone and head
    param_groups = []
    backbone_params = [p for p in model.backbone.parameters() if p.requires_grad]
    head_params = list(model.head.parameters())
    if backbone_params:
        param_groups.append({'params': backbone_params, 'lr': lr_backbone})
        print(f" Backbone params (trainable): {sum(p.numel() for p in backbone_params):,}")
    param_groups.append({'params': head_params, 'lr': lr_head})
    print(f" Head params: {sum(p.numel() for p in head_params):,}")
    print(f" LR backbone: {lr_backbone}, LR head: {lr_head}")
    print(f" Epochs: {epochs}")
    optimizer = torch.optim.AdamW(param_groups, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    scaler = torch.amp.GradScaler() if use_amp and device.type == 'cuda' else None
    slug = _stage_checkpoint_slug(stage_name)
    checkpoint_path = output_dir / f'best_{slug}.pt'
    # Start below any reachable score so the first epoch always writes a
    # checkpoint. With a starting value of 0 and a strict '>', a stage whose
    # macro-F1 stays at 0 would leave no file (or a stale one from an earlier
    # run) for the caller to load.
    best_val_f1 = float('-inf')
    best_epoch = 0
    history = []
    for epoch in range(epochs):
        print(f"\n Epoch {epoch+1}/{epochs}")
        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device, scaler
        )
        val_metrics, _, _, _ = evaluate(model, val_loader, criterion, device)
        scheduler.step()
        print(f" Train loss: {train_loss:.4f} | acc: {train_acc:.3f}")
        print(f" Val loss: {val_metrics['loss']:.4f} | "
              f"acc: {val_metrics['accuracy']:.3f} | "
              f"macro-F1: {val_metrics['macro_f1']:.3f}")
        history.append({
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_macro_f1': val_metrics['macro_f1'],
            'val_loss': val_metrics['loss'],
            'val_accuracy': val_metrics['accuracy'],
        })
        # Save best model (always use slug path so load paths in main() match)
        if val_metrics['macro_f1'] > best_val_f1:
            best_val_f1 = val_metrics['macro_f1']
            best_epoch = epoch + 1
            torch.save({
                'model_state_dict': model.state_dict(),
                'epoch': epoch + 1,
                'val_macro_f1': best_val_f1,
                'val_accuracy': val_metrics['accuracy'],
                'stage_name': stage_name,
                'stage_slug': slug,
            }, checkpoint_path)
            print(f" * New best! Saved to {checkpoint_path}")
    if best_epoch == 0:
        # No epoch ran: report 0 rather than the internal sentinel.
        best_val_f1 = 0.0
    print(f"\n {stage_name} complete. Best: epoch {best_epoch}, macro-F1: {best_val_f1:.3f}")
    return history, best_val_f1
| # ========================== | |
| # MAIN | |
| # ========================== | |
def _torch_load(path):
    """
    Load a checkpoint with torch.load, requesting full (non weights-only)
    unpickling; if that call raises TypeError, retry with no extra arguments.

    NOTE(review): weights_only=False unpickles arbitrary objects — only use
    this on checkpoint files you trust.
    """
    load_options = {"weights_only": False}
    try:
        checkpoint = torch.load(path, **load_options)
    except TypeError:
        checkpoint = torch.load(path)
    return checkpoint
def _save_stage_history_json(output_dir: Path, stage_key: str, history: list) -> None:
    """
    Dump a stage's per-epoch history to output_dir / 'history_<stage_key>.json'
    (2-space indent; values JSON cannot encode are written via str()).
    """
    target = output_dir / f'history_{stage_key}.json'
    with open(target, 'w') as handle:
        handle.write(json.dumps(history, indent=2, default=str))
    print(f" Stage history saved: {target}")
def _plot_stage_history(output_dir: Path, stage_key: str, history: list, experiment: str) -> None:
    """
    Write 'training_history_<stage_key>.png' with two panels for one stage:
    train loss and validation macro-F1 against epoch.

    Does nothing for an empty history. Any failure (including a missing
    matplotlib) is reported with a message and otherwise ignored.
    """
    if not history:
        return
    try:
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
        xs = [rec['epoch'] for rec in history]
        # Collect both series before any figure is created.
        series = {
            field: [rec[field] for rec in history]
            for field in ('train_loss', 'val_macro_f1')
        }
        panels = (
            ('train_loss', 'b-', 'Train loss', f'{stage_key} — train loss'),
            ('val_macro_f1', 'g-', 'Val macro-F1', f'{stage_key} — validation'),
        )
        fig, axes = plt.subplots(1, 2, figsize=(12, 4))
        for ax, (field, line_style, y_label, title) in zip(axes, panels):
            ax.plot(xs, series[field], line_style)
            ax.set_xlabel('Epoch')
            ax.set_ylabel(y_label)
            ax.set_title(title)
        fig.suptitle(f'{experiment} / {stage_key}')
        plt.tight_layout()
        out_path = output_dir / f'training_history_{stage_key}.png'
        plt.savefig(out_path, dpi=150)
        plt.close()
        print(f" Stage plot saved: {out_path}")
    except Exception as e:
        print(f" (Skipping stage plot for {stage_key}: {e})")
def _save_stage_artifacts(output_dir: Path, stage_key: str, history: list, experiment: str) -> None:
    """Write both per-stage artifacts: the history JSON, then the curves plot."""
    _save_stage_history_json(output_dir=output_dir, stage_key=stage_key, history=history)
    _plot_stage_history(
        output_dir=output_dir,
        stage_key=stage_key,
        history=history,
        experiment=experiment,
    )
def main():
    """
    Command-line entry point.

    Builds the page-level train/val/test split, runs the progressive
    fine-tuning stages (A: head only, B: last 2 blocks, optional C: last 4
    blocks), evaluates the checkpoint with the best validation macro-F1 on the
    test split (patch level and page level) and writes checkpoints, metrics,
    confusion matrices and plots under <output_dir>/<experiment>/.

    Raises:
        RuntimeError: if none of this run's stages produced a checkpoint.
    """
    parser = argparse.ArgumentParser(description="Fine-tune DINO ViT-S")
    parser.add_argument(
        "--data_dir", type=str, required=True,
        help="Path to processed data (e.g., ./Data/output/whole_page)",
    )
    parser.add_argument(
        "--experiment", type=str, required=True,
        choices=["whole_page", "patches_color", "patches_clahe"],
        help="Which experiment variant",
    )
    parser.add_argument("--output_dir", type=str, default="./results",
                        help="Where to save checkpoints and results")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Batch size (reduce if OOM)")
    parser.add_argument("--epochs_a", type=int, default=20,
                        help="Epochs for Stage A (head only)")
    parser.add_argument("--epochs_b", type=int, default=10,
                        help="Epochs for Stage B (last 2 blocks)")
    parser.add_argument("--epochs_c", type=int, default=10,
                        help="Epochs for Stage C (last 4 blocks)")
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--no_amp", action="store_true",
                        help="Disable mixed precision")
    parser.add_argument("--skip_stage_c", action="store_true",
                        help="Skip Stage C (last 4 blocks)")
    parser.add_argument(
        "--exclude_manifest",
        type=str,
        default="./benchmark_page_ids.json",
        help="Optional class->page_ids JSON; excluded pages are skipped during split creation",
    )
    args = parser.parse_args()
    stage_a_name = "Stage A: Head only"
    stage_b_name = "Stage B: Last 2 blocks"
    stage_c_name = "Stage C: Last 4 blocks"
    set_seed(SEED)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    output_dir = Path(args.output_dir) / args.experiment
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"\n{'='*60}")
    print(f" DINOv3 ViT-S Fine-Tuning")
    print(f" {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"{'='*60}")
    print(f" Experiment: {args.experiment}")
    print(f" Data dir: {args.data_dir}")
    print(f" Device: {device}")
    print(f" Batch size: {args.batch_size}")
    print(f" AMP: {not args.no_amp}")
    print(f" Exclusions: {args.exclude_manifest}")

    # Page level split
    print(f"\n Creating page level split")
    excluded_pages_by_label = load_exclusion_manifest(args.exclude_manifest)
    excluded_label_count = len(excluded_pages_by_label)
    excluded_id_count = sum(len(v) for v in excluded_pages_by_label.values())
    if excluded_label_count:
        print(f" Loaded exclusions: {excluded_label_count} labels, {excluded_id_count} page IDs")
    splits, label_to_idx, idx_to_label, skipped_by_label = create_page_level(
        args.data_dir,
        excluded_pages_by_label=excluded_pages_by_label,
    )
    num_classes = len(label_to_idx)
    print(f" Classes: {num_classes}")
    print(f" Train: {len(splits['train'])} | Val: {len(splits['val'])} | Test: {len(splits['test'])}")
    if skipped_by_label:
        print("\n Skipped excluded files by class:")
        for label, count in sorted(skipped_by_label.items()):
            print(f" {label:<20s} {count:>6d}")
    # Print per-class split counts
    for split_name in ['train', 'val', 'test']:
        counts = Counter(label for _, label in splits[split_name])
        print(f"\n {split_name}:")
        for label in sorted(counts.keys()):
            print(f" {label:<20s} {counts[label]:>6d}")

    # Save splits for reproducibility. The per-file assignment is written
    # under 'splits' so the exact split can be recovered later, next to the
    # summary counts.
    splits_info = {
        split_name: [[fp, label] for fp, label in samples]
        for split_name, samples in splits.items()
    }
    with open(output_dir / 'splits.json', 'w') as f:
        json.dump({
            'label_to_idx': label_to_idx,
            'idx_to_label': {str(k): v for k, v in idx_to_label.items()},
            'split_counts': {
                name: dict(Counter(l for _, l in samples))
                for name, samples in splits.items()
            },
            'exclude_manifest': str(args.exclude_manifest),
            'excluded_label_count': excluded_label_count,
            'excluded_page_id_count': excluded_id_count,
            'skipped_excluded_files_by_class': dict(skipped_by_label),
            'splits': splits_info,
        }, f, indent=2)

    print(f"Loading DINOv3 processor: {DINOV3_MODEL_ID}")
    processor = AutoImageProcessor.from_pretrained(DINOV3_MODEL_ID)
    train_dataset = ScriptDataset(splits['train'], label_to_idx, processor, augment=True)
    val_dataset = ScriptDataset(splits['val'], label_to_idx, processor, augment=False)
    test_dataset = ScriptDataset(splits['test'], label_to_idx, processor, augment=False)
    pin = device.type == 'cuda'
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=pin,
    )
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=pin,
    )
    test_loader = DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=pin,
    )

    print(f"\n Building DINOv3 classifier ({num_classes} classes)...")
    model = DINOv3Classifier(DINOV3_MODEL_ID, num_classes, dropout=0.1)
    model = model.to(device)
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f" Total params: {total_params:,}")
    print(f" Trainable params: {trainable_params:,} (head only)")
    class_weights = get_class_weights(splits['train'], label_to_idx, device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    use_amp = not args.no_amp and device.type == 'cuda'
    all_history = {}
    # Checkpoints belonging to the stages executed in THIS run. Final model
    # selection is restricted to these so that leftovers from earlier runs in
    # the same output directory (e.g. a Stage C file when Stage C is skipped
    # now) can never be picked.
    stage_checkpoints = []

    # Stage A: Head only (backbone frozen)
    model.freeze_backbone()
    history_a, best_f1_a = run_stage(
        model, train_loader, val_loader, criterion, device,
        stage_name=stage_a_name,
        lr_backbone=0, lr_head=1e-3,
        epochs=args.epochs_a, output_dir=output_dir,
        idx_to_label=idx_to_label, use_amp=use_amp,
    )
    all_history['stage_a'] = history_a
    _save_stage_artifacts(output_dir, 'stage_a', history_a, args.experiment)
    ckpt_a = output_dir / f"best_{_stage_checkpoint_slug(stage_a_name)}.pt"
    stage_checkpoints.append(ckpt_a)

    # Stage B: continue from the best Stage A weights, last 2 blocks trainable
    best_a = _torch_load(ckpt_a)
    model.load_state_dict(best_a['model_state_dict'])
    model.unfreeze_last_n_blocks(2)
    history_b, best_f1_b = run_stage(
        model, train_loader, val_loader, criterion, device,
        stage_name=stage_b_name,
        lr_backbone=1e-5, lr_head=1e-3,
        epochs=args.epochs_b, output_dir=output_dir,
        idx_to_label=idx_to_label, use_amp=use_amp,
    )
    all_history['stage_b'] = history_b
    _save_stage_artifacts(output_dir, 'stage_b', history_b, args.experiment)
    ckpt_b = output_dir / f"best_{_stage_checkpoint_slug(stage_b_name)}.pt"
    stage_checkpoints.append(ckpt_b)

    # Stage C (optional): continue from the best Stage B weights, last 4 blocks
    if not args.skip_stage_c:
        best_b = _torch_load(ckpt_b)
        model.load_state_dict(best_b['model_state_dict'])
        model.unfreeze_last_n_blocks(4)
        history_c, best_f1_c = run_stage(
            model, train_loader, val_loader, criterion, device,
            stage_name=stage_c_name,
            lr_backbone=5e-6, lr_head=5e-4,
            epochs=args.epochs_c, output_dir=output_dir,
            idx_to_label=idx_to_label, use_amp=use_amp,
        )
        all_history['stage_c'] = history_c
        _save_stage_artifacts(output_dir, 'stage_c', history_c, args.experiment)
        stage_checkpoints.append(
            output_dir / f"best_{_stage_checkpoint_slug(stage_c_name)}.pt"
        )

    # Final evaluation on test set
    print(f"\n{'='*60}")
    print(f" FINAL TEST EVALUATION")
    print(f"{'='*60}")
    # Start below any reachable score so a checkpoint whose val macro-F1 is
    # exactly 0 can still be selected; ties keep the earlier stage.
    best_f1 = float('-inf')
    best_ckpt = None
    for ckpt_path in stage_checkpoints:
        if not ckpt_path.is_file():
            continue
        score = _torch_load(ckpt_path).get('val_macro_f1', 0)
        if score > best_f1:
            best_f1 = score
            best_ckpt = ckpt_path
    if best_ckpt is None:
        raise RuntimeError("No checkpoint found under output_dir; cannot run test evaluation.")
    print(f" Loading best checkpoint: {best_ckpt} (val F1: {best_f1:.3f})")
    model.load_state_dict(_torch_load(best_ckpt)['model_state_dict'])
    test_metrics, test_preds, test_labels, test_probs = evaluate(
        model, test_loader, criterion, device, idx_to_label
    )
    page_eval = evaluate_page_level(
        splits['test'],
        test_probs,
        label_to_idx=label_to_idx,
        idx_to_label=idx_to_label,
    )
    page_metrics = page_eval["metrics"]

    # Canonical weights for this experiment (same as loaded best val checkpoint, after test eval)
    final_model_path = output_dir / 'final_model.pt'
    torch.save(
        {
            'model_state_dict': model.state_dict(),
            'experiment': args.experiment,
            'model_id': DINOV3_MODEL_ID,
            'num_classes': num_classes,
            'label_to_idx': label_to_idx,
            'source_val_checkpoint': str(best_ckpt),
            'val_macro_f1_at_selection': float(best_f1),
            'test_metrics': test_metrics,
            'page_test_metrics': page_metrics,
        },
        final_model_path,
    )
    print(f"\n Final model (for deployment / comparison) saved: {final_model_path}")
    print(f"\n Test accuracy: {test_metrics['accuracy']:.3f}")
    print(f" Test macro-F1: {test_metrics['macro_f1']:.3f}")
    print(f" Test weighted-F1: {test_metrics['weighted_f1']:.3f}")
    print(f" Page accuracy: {page_metrics['accuracy']:.3f} "
          f"| Page macro-F1: {page_metrics['macro_f1']:.3f} "
          f"| Pages: {page_metrics['num_pages']}")

    # Classification report
    target_names = [idx_to_label[i] for i in range(num_classes)]
    report = classification_report(
        test_labels, test_preds, target_names=target_names, zero_division=0
    )
    print(f"\n{report}")
    # Confusion matrix
    cm = confusion_matrix(test_labels, test_preds)
    page_cm = confusion_matrix(page_eval["page_true"], page_eval["page_pred"])

    # Save everything
    results = {
        'experiment': args.experiment,
        'model': DINOV3_MODEL_ID,
        'num_classes': num_classes,
        'best_val_checkpoint': str(best_ckpt),
        'val_macro_f1_at_selection': float(best_f1),
        'final_model_path': str(final_model_path),
        'test_metrics': test_metrics,
        'page_test_metrics': page_metrics,
        'history': all_history,
        'confusion_matrix': cm.tolist(),
        'page_confusion_matrix': page_cm.tolist(),
        'label_to_idx': label_to_idx,
        'classification_report': report,
        'page_classification_report': classification_report(
            page_eval["page_true"], page_eval["page_pred"], target_names=target_names, zero_division=0
        ),
    }
    with open(output_dir / 'results.json', 'w') as f:
        json.dump(results, f, indent=2, default=str)

    # Save confusion matrices as CSV. pandas is optional here: a missing
    # install should not abort the run after training has already finished.
    try:
        import pandas as pd
    except ImportError:
        print(" (pandas not available, skipping confusion matrix CSVs)")
    else:
        cm_df = pd.DataFrame(cm, index=target_names, columns=target_names)
        cm_df.to_csv(output_dir / 'confusion_matrix.csv')
        page_cm_df = pd.DataFrame(page_cm, index=target_names, columns=target_names)
        page_cm_df.to_csv(output_dir / 'page_confusion_matrix.csv')

    # Plotting is best-effort: import once, then report (rather than silently
    # drop) any failure of an individual figure.
    plt = None
    sns = None
    try:
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
        import seaborn as sns
    except ImportError:
        print(" (matplotlib/seaborn not available, skipping plot)")

    # Plot patch-level and page-level confusion matrices
    if plt is not None and sns is not None:
        heatmaps = [
            (
                cm, 'Blues',
                f'Confusion Matrix — {args.experiment} (macro-F1: {test_metrics["macro_f1"]:.3f})',
                'confusion_matrix.png',
                "\n Confusion matrix saved: ",
            ),
            (
                page_cm, 'Greens',
                f'Page Confusion Matrix — {args.experiment} '
                f'(macro-F1: {page_metrics["macro_f1"]:.3f})',
                'page_confusion_matrix.png',
                " Page confusion matrix saved: ",
            ),
        ]
        for matrix, cmap, title, filename, saved_prefix in heatmaps:
            try:
                fig, ax = plt.subplots(figsize=(14, 12))
                sns.heatmap(matrix, annot=True, fmt='d', cmap=cmap,
                            xticklabels=target_names, yticklabels=target_names, ax=ax)
                ax.set_xlabel('Predicted label')
                ax.set_ylabel('True label')
                ax.set_title(title)
                plt.tight_layout()
                plt.savefig(output_dir / filename, dpi=150)
                plt.close()
                print(f"{saved_prefix}{output_dir / filename}")
            except Exception as e:
                print(f" (Skipping {filename}: {e})")

    # Plot training history (needs matplotlib only)
    if plt is not None:
        try:
            fig, axes = plt.subplots(1, 2, figsize=(14, 5))
            all_epochs = []
            all_train_loss = []
            all_val_f1 = []
            # Epoch numbers restart at 1 in every stage; shift them so the
            # stages line up one after another on a single axis.
            offset = 0
            for stage_name, stage_history in all_history.items():
                for entry in stage_history:
                    all_epochs.append(entry['epoch'] + offset)
                    all_train_loss.append(entry['train_loss'])
                    all_val_f1.append(entry['val_macro_f1'])
                offset += len(stage_history)
            axes[0].plot(all_epochs, all_train_loss, 'b-')
            axes[0].set_xlabel('Epoch')
            axes[0].set_ylabel('Train Loss')
            axes[0].set_title('Training Loss')
            axes[1].plot(all_epochs, all_val_f1, 'g-')
            axes[1].set_xlabel('Epoch')
            axes[1].set_ylabel('Macro F1')
            axes[1].set_title('Validation Macro-F1')
            plt.suptitle(f'{args.experiment} — Progressive Fine-Tuning')
            plt.tight_layout()
            plt.savefig(output_dir / 'training_history.png', dpi=150)
            plt.close()
            print(f" Training history saved: {output_dir / 'training_history.png'}")
        except Exception as e:
            print(f" (Skipping training history plot: {e})")

    print(f"\n{'='*60}")
    print(f" All results saved to: {output_dir}")
    print(f"{'='*60}\n")
# Run the training pipeline only when the file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()