#!/usr/bin/env python3
"""
Train KYC Document Rotation Classifier - CPU Only
=================================================

This script trains a lightweight MobileNetV3-Small classifier to detect
document rotation: 0°, 90°, 180°, 270°.

Requirements:
    pip install torch torchvision pillow numpy huggingface_hub tqdm

Usage:
    python train_rotation_classifier.py

Dataset: Jwalit/moire-docs (will download automatically)
"""

import os
import json
import random
import warnings
from pathlib import Path

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.transforms import functional as TF
from huggingface_hub import hf_hub_download, HfApi, create_repo
from tqdm import tqdm

warnings.filterwarnings("ignore")

# ── Configuration ─────────────────────────
DATASET_REPO = "Jwalit/moire-docs"
LOCAL_DIR = Path("./moire-docs")
BATCH_SIZE = 16
EPOCHS = 15
LR = 1e-4
IMG_SIZE = 224
DEVICE = torch.device("cpu")
MAX_IMAGES = 1500
SEED = 42

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# File extensions (lower-case) that are treated as training images.
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')


def _is_image_path(name):
    """Return True if *name* looks like an image and is not notebook residue.

    The extension check is case-insensitive so that ``scan.JPG`` is treated
    the same way in the remote listing and on the local disk.
    """
    return name.lower().endswith(_IMAGE_EXTENSIONS) and '.ipynb' not in name


# ── Download Dataset ──────────────────────
def download_dataset():
    """Download images from Jwalit/moire-docs.

    Lists the image files of ``DATASET_REPO``, downloads a random subset of
    at most ``MAX_IMAGES`` of them into ``LOCAL_DIR`` and returns every image
    found under ``LOCAL_DIR`` afterwards (including files left over from
    earlier runs).

    Individual download failures are tolerated (best effort) but are counted
    and reported instead of being silently discarded.

    Returns:
        A sorted list of ``Path`` objects. The ordering is deterministic so
        that the seeded shuffle performed by the caller is reproducible.
    """
    LOCAL_DIR.mkdir(parents=True, exist_ok=True)

    api = HfApi()
    files = api.list_repo_files(DATASET_REPO, repo_type="dataset")
    image_files = [f for f in files if _is_image_path(f)]
    random.shuffle(image_files)
    image_files = image_files[:MAX_IMAGES]

    print(f"Downloading {len(image_files)} images...")
    failures = []
    for rel_path in tqdm(image_files, desc="Download"):
        try:
            hf_hub_download(
                repo_id=DATASET_REPO,
                filename=rel_path,
                repo_type="dataset",
                local_dir=LOCAL_DIR,
            )
        except Exception as exc:
            # Best effort: one bad file must not abort the whole run, but the
            # failure is remembered so the user learns about it.
            failures.append((rel_path, exc))

    if failures:
        print(f"Warning: {len(failures)} of {len(image_files)} downloads failed")
        for rel_path, exc in failures[:5]:
            print(f"  {rel_path}: {exc}")

    # Collect downloaded images with the same case-insensitive filter that
    # was applied to the remote listing.
    return sorted(
        p for p in LOCAL_DIR.rglob('*')
        if p.is_file() and _is_image_path(str(p))
    )


# ── Dataset ───────────────────────────────
class RotationDataset(Dataset):
    """Self-supervised rotation dataset. Each image × 4 rotations.

    Sample ``idx`` is image ``idx // 4`` rotated by ``ANGLES[idx % 4]``
    degrees; the label is the index into ``ANGLES``. ``TF.rotate`` rotates
    counter-clockwise.
    """

    ANGLES = [0, 90, 180, 270]

    def __init__(self, paths, img_size=IMG_SIZE):
        """Store the image paths and build the preprocessing pipeline.

        Args:
            paths: Sequence of image file paths.
            img_size: Side length of the square the rotated image is
                resized to.
        """
        self.paths = paths
        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            # ImageNet statistics, matching the pretrained backbone weights.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.paths) * len(self.ANGLES)

    def __getitem__(self, idx):
        """Return ``(image_tensor, angle_index)`` for sample ``idx``."""
        path_idx, angle_idx = divmod(idx, len(self.ANGLES))
        path = self.paths[path_idx]

        # The context manager closes the underlying file right away instead
        # of leaving it to the garbage collector.
        with Image.open(path) as src:
            img = src.convert('RGB')

        # expand=True keeps the whole document when a non-square image is
        # turned by 90°/270°. Without it the content is cropped and the
        # corners are filled with black, which gives the classifier a
        # trivial cue unrelated to the document itself.
        img = TF.rotate(img, self.ANGLES[angle_idx], expand=True)
        return self.transform(img), angle_idx


# ── Model ─────────────────────────────────
class RotationModel(nn.Module):
    """MobileNetV3-Small for 4-class rotation classification."""

    def __init__(self):
        super().__init__()
        self.backbone = models.mobilenet_v3_small(
            weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1)
        # Swap the last classifier layer for a 4-way rotation head.
        in_features = self.backbone.classifier[3].in_features
        self.backbone.classifier[3] = nn.Linear(in_features, 4)

    def forward(self, x):
        """Return the 4 rotation logits for the image batch ``x``."""
        return self.backbone(x)


# ── Training ──────────────────────────────
def train(model, train_loader, val_loader):
    """Train ``model`` for ``EPOCHS`` epochs and restore the best weights.

    Uses AdamW with a cosine-annealed learning rate and cross-entropy loss.
    After every epoch the model is evaluated on ``val_loader``; the weights
    of the epoch with the highest validation accuracy are kept.

    Args:
        model: The network to train (moved to ``DEVICE``).
        train_loader: DataLoader yielding ``(images, labels)`` batches.
        val_loader: DataLoader yielding ``(images, labels)`` batches.

    Returns:
        ``(model, best_acc)`` where ``best_acc`` is the best validation
        accuracy in ``[0, 1]``.
    """
    model.to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
    criterion = nn.CrossEntropyLoss()

    best_acc = 0.0
    best_state = None

    for epoch in range(EPOCHS):
        # Train
        model.train()
        train_loss = 0.0
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1} train", leave=False):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        scheduler.step()

        # Validate
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                predicted = outputs.argmax(dim=1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

        # max(..., 1) avoids a ZeroDivisionError when a loader is empty.
        train_loss /= max(len(train_loader), 1)
        val_loss /= max(len(val_loader), 1)
        val_acc = correct / total if total > 0 else 0.0

        print(f"Epoch {epoch+1}/{EPOCHS}: "
              f"train_loss={train_loss:.4f} val_loss={val_loss:.4f} val_acc={val_acc:.4f}")

        if val_acc > best_acc:
            best_acc = val_acc
            # Clone the tensors: state_dict() returns references that later
            # optimizer steps would otherwise overwrite.
            best_state = {k: v.clone() for k, v in model.state_dict().items()}

    if best_state is not None:
        model.load_state_dict(best_state)
    return model, best_acc


# ── Push to Hub ───────────────────────────
def push_model(model, accuracy):
    """Save the model and its config locally, then upload them to the Hub.

    Writes ``rotation_model.bin`` and ``config.json`` to ``./outputs``. The
    upload is best effort: if it fails, the error is printed and the local
    files remain available.

    Args:
        model: Trained model whose ``state_dict`` is saved.
        accuracy: Validation accuracy recorded in ``config.json``.
    """
    output_dir = Path("./outputs")
    output_dir.mkdir(exist_ok=True)

    torch.save(model.state_dict(), output_dir / "rotation_model.bin")

    with open(output_dir / "config.json", "w", encoding="utf-8") as f:
        json.dump({
            "task": "rotation_classification",
            "backbone": "mobilenet_v3_small",
            "num_classes": 4,
            "classes": ["0", "90", "180", "270"],
            "epochs": EPOCHS,
            "accuracy": accuracy,
        }, f, indent=2)

    repo_name = "Jwalit/kyc-document-rotation-classifier"
    try:
        create_repo(repo_name, repo_type="model", exist_ok=True)
        api = HfApi()
        api.upload_folder(folder_path=str(output_dir), repo_id=repo_name, repo_type="model")
        print(f"\nPushed to https://huggingface.co/{repo_name}")
    except Exception as e:
        print(f"\nPush error: {e}")
        print(f"Model saved locally to: {output_dir}/rotation_model.bin")


# ── Main ──────────────────────────────────
def main():
    """Run the pipeline: download, split, train, and push the model."""
    print("=" * 60)
    print("KYC Document Rotation Classifier - CPU Training")
    print("=" * 60)

    # Download dataset
    print("\n[1/4] Downloading dataset...")
    images = download_dataset()
    print(f"Total images: {len(images)}")

    if len(images) < 20:
        print("ERROR: Not enough images downloaded!")
        return

    # Split
    random.shuffle(images)
    n_train = int(0.85 * len(images))
    train_images = images[:n_train]
    val_images = images[n_train:]
    print(f"Train: {len(train_images)}, Val: {len(val_images)}")

    # DataLoaders
    print("\n[2/4] Creating datasets...")
    train_dataset = RotationDataset(train_images)
    val_dataset = RotationDataset(val_images)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
    print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")

    # Train
    print("\n[3/4] Training model...")
    model = RotationModel()
    model, best_acc = train(model, train_loader, val_loader)
    print(f"\nBest validation accuracy: {best_acc:.2%}")

    # Push
    print("\n[4/4] Pushing to Hugging Face Hub...")
    push_model(model, best_acc)

    print("\n" + "=" * 60)
    print("Training complete!")
    print("=" * 60)


if __name__ == "__main__":
    main()