#!/usr/bin/env python3 """ Script to train Donut model on synthetic certificate data """ import argparse import sys from pathlib import Path import json from PIL import Image # Add app to path sys.path.append(str(Path(__file__).parent.parent)) # def create_donut_dataset(synthetic_dir: str, output_dir: str): # """Create dataset in Donut format from synthetic data""" # synthetic_path = Path(synthetic_dir) # output_path = Path(output_dir) # output_path.mkdir(parents=True, exist_ok=True) # images_dir = synthetic_path / "images" # labels_dir = synthetic_path / "labels" # dataset = [] # # Get all image files # image_files = list(images_dir.glob("*.png")) # print(f"Found {len(image_files)} images") # for i, img_path in enumerate(image_files[:10000]): # Limit for training # # Get corresponding label # label_path = labels_dir / f"{img_path.stem}.json" # if not label_path.exists(): # print(f"Warning: No label for {img_path.name}") # continue # # Load image # image = Image.open(img_path).convert("RGB") # # Load label # import gzip # try: # # Try reading as gzip # with gzip.open(label_path, 'rt', encoding='utf-8') as f: # label_data = json.load(f) # except OSError: # # Fallback: plain JSON # with open(label_path, 'r', encoding='utf-8') as f: # label_data = json.load(f) # # Create Donut format sample # sample = { # "image": image, # "label": { # "name": label_data.get("name", ""), # "student_id": label_data.get("student_id", label_data.get("certificate_id", "")), # "university": label_data.get("university", label_data.get("organization", "")), # "course": label_data.get("course", ""), # "gpa": str(label_data.get("gpa", "")), # "issue_date": label_data.get("issue_date", ""), # "language": label_data.get("language", "english") # } # } # dataset.append(sample) # if (i + 1) % 100 == 0: # print(f"Processed {i + 1}/{len(image_files[:10000])} samples") # return dataset def create_donut_dataset(synthetic_dir: str, output_dir: str): """Create dataset with lazy loading (store paths, not images)""" synthetic_path = Path(synthetic_dir) output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) images_dir = synthetic_path / "images" labels_dir = synthetic_path / "labels" dataset = [] image_files = list(images_dir.glob("*.png")) print(f"Found {len(image_files)} images") for i, img_path in enumerate(image_files): # remove the 10000 limit label_path = labels_dir / f"{img_path.stem}.json" if not label_path.exists(): print(f"Warning: No label for {img_path.name}") continue # Store only paths and label paths dataset.append({ "image_path": str(img_path), "label_path": str(label_path) }) if (i + 1) % 100 == 0: print(f"Processed {i + 1}/{len(image_files)} samples") return dataset def split_dataset(dataset, train_ratio=0.8): """Split dataset into train and validation""" split_idx = int(len(dataset) * train_ratio) train_dataset = dataset[:split_idx] val_dataset = dataset[split_idx:] return train_dataset, val_dataset def main(): parser = argparse.ArgumentParser(description="Train Donut model on synthetic data") parser.add_argument("--synthetic-dir", type=str, default="data/training/synthetic", help="Directory with synthetic data") parser.add_argument("--output-dir", type=str, default="models/donut_certificate", help="Output directory for trained model") parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs") parser.add_argument("--batch-size", type=int, default=2, help="Batch size for training") args = parser.parse_args() print("Creating Donut dataset...") #dataset = create_donut_dataset(args.synthetic_dir, args.output_dir) dataset = create_donut_dataset(args.synthetic_dir, args.output_dir) for i, item in enumerate(dataset): if "image_path" not in item or "label_path" not in item: print(f"Malformed dataset item at index {i}: {item}") sys.exit(1) if len(dataset) < 100: print(f"Warning: Only {len(dataset)} samples available. Need at least 100 for training.") sys.exit(1) print(f"Created dataset with {len(dataset)} samples") # Split dataset train_dataset, val_dataset = split_dataset(dataset) print(f"Train: {len(train_dataset)} samples, Validation: {len(val_dataset)} samples") # Initialize Donut model from app.analyzers.ml_models.donut_model import DonutCertificateParser print("Initializing Donut model...") donut_parser = DonutCertificateParser() # Train model print(f"Starting training for {args.epochs} epochs...") try: train_result = donut_parser.fine_tune( train_dataset=train_dataset, val_dataset=val_dataset, output_dir=args.output_dir ) print(f"\n✅ Training completed successfully!") print(f" Model saved to: {args.output_dir}") except Exception as e: print(f"\n❌ Training failed: {e}") sys.exit(1) if __name__ == "__main__": main()