Spaces:
Running
Running
File size: 5,824 Bytes
#!/usr/bin/env python3
"""
Script to train Donut model on synthetic certificate data
"""
import argparse
import sys
from pathlib import Path
import json
from PIL import Image
# Append the directory two levels above this file to sys.path so that the
# project-local `app` package (imported inside main()) can be resolved when
# this script is run directly. Presumably that directory is the project
# root -- confirm against the repository layout.
sys.path.append(str(Path(__file__).parent.parent))
# def create_donut_dataset(synthetic_dir: str, output_dir: str):
# """Create dataset in Donut format from synthetic data"""
# synthetic_path = Path(synthetic_dir)
# output_path = Path(output_dir)
# output_path.mkdir(parents=True, exist_ok=True)
# images_dir = synthetic_path / "images"
# labels_dir = synthetic_path / "labels"
# dataset = []
# # Get all image files
# image_files = list(images_dir.glob("*.png"))
# print(f"Found {len(image_files)} images")
# for i, img_path in enumerate(image_files[:10000]): # Limit for training
# # Get corresponding label
# label_path = labels_dir / f"{img_path.stem}.json"
# if not label_path.exists():
# print(f"Warning: No label for {img_path.name}")
# continue
# # Load image
# image = Image.open(img_path).convert("RGB")
# # Load label
# import gzip
# try:
# # Try reading as gzip
# with gzip.open(label_path, 'rt', encoding='utf-8') as f:
# label_data = json.load(f)
# except OSError:
# # Fallback: plain JSON
# with open(label_path, 'r', encoding='utf-8') as f:
# label_data = json.load(f)
# # Create Donut format sample
# sample = {
# "image": image,
# "label": {
# "name": label_data.get("name", ""),
# "student_id": label_data.get("student_id", label_data.get("certificate_id", "")),
# "university": label_data.get("university", label_data.get("organization", "")),
# "course": label_data.get("course", ""),
# "gpa": str(label_data.get("gpa", "")),
# "issue_date": label_data.get("issue_date", ""),
# "language": label_data.get("language", "english")
# }
# }
# dataset.append(sample)
# if (i + 1) % 100 == 0:
# print(f"Processed {i + 1}/{len(image_files[:10000])} samples")
# return dataset
def create_donut_dataset(synthetic_dir: str, output_dir: str):
    """Build a lazy-loading dataset index from a synthetic data directory.

    Only file paths are recorded; no image or label file is opened here.

    Args:
        synthetic_dir: Directory expected to contain an ``images/``
            subdirectory of ``*.png`` files and a ``labels/`` subdirectory
            holding one ``<image stem>.json`` file per image.
        output_dir: Directory that is created (with parents) if missing.
            Nothing is written into it by this function.

    Returns:
        A list of ``{"image_path": str, "label_path": str}`` dicts, one per
        image that has a matching label file, ordered by image path.
    """
    synthetic_path = Path(synthetic_dir)
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    images_dir = synthetic_path / "images"
    labels_dir = synthetic_path / "labels"

    # glob() yields entries in arbitrary, filesystem-dependent order. Sort so
    # the index -- and any positional train/validation split made from it --
    # is reproducible from run to run and machine to machine.
    image_files = sorted(images_dir.glob("*.png"))
    total = len(image_files)
    print(f"Found {total} images")

    dataset = []
    for i, img_path in enumerate(image_files):
        label_path = labels_dir / f"{img_path.stem}.json"
        if not label_path.exists():
            print(f"Warning: No label for {img_path.name}")
            continue

        # Store only the paths; loading is left to the consumer.
        dataset.append({
            "image_path": str(img_path),
            "label_path": str(label_path),
        })

        if (i + 1) % 100 == 0:
            print(f"Processed {i + 1}/{total} samples")

    return dataset
def split_dataset(dataset, train_ratio=0.8):
    """Partition ``dataset`` positionally into (train, validation) parts.

    The first ``int(len(dataset) * train_ratio)`` items become the training
    part and the remainder the validation part. Items keep their order; no
    shuffling is performed.
    """
    cut = int(len(dataset) * train_ratio)
    return dataset[:cut], dataset[cut:]
def main():
    """Command-line entry point: index the synthetic data and fine-tune Donut.

    Parses ``--synthetic-dir``, ``--output-dir``, ``--epochs`` and
    ``--batch-size``, builds the path-based dataset index, splits it into
    train and validation parts, and calls
    ``DonutCertificateParser.fine_tune`` on them.

    Exits with status 1 if a dataset item lacks a required key, if fewer
    than 100 samples are available, or if training raises an exception.
    """
    parser = argparse.ArgumentParser(description="Train Donut model on synthetic data")
    parser.add_argument("--synthetic-dir", type=str, default="data/training/synthetic",
                        help="Directory with synthetic data")
    parser.add_argument("--output-dir", type=str, default="models/donut_certificate",
                        help="Output directory for trained model")
    parser.add_argument("--epochs", type=int, default=10,
                        help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, default=2,
                        help="Batch size for training")
    args = parser.parse_args()

    print("Creating Donut dataset...")
    dataset = create_donut_dataset(args.synthetic_dir, args.output_dir)

    # Sanity-check the index before doing any expensive work.
    for i, item in enumerate(dataset):
        if "image_path" not in item or "label_path" not in item:
            print(f"Malformed dataset item at index {i}: {item}")
            sys.exit(1)

    if len(dataset) < 100:
        print(f"Warning: Only {len(dataset)} samples available. Need at least 100 for training.")
        sys.exit(1)

    print(f"Created dataset with {len(dataset)} samples")

    # Split dataset
    train_dataset, val_dataset = split_dataset(dataset)
    print(f"Train: {len(train_dataset)} samples, Validation: {len(val_dataset)} samples")

    # Initialize Donut model. The import is deferred to this point so the
    # dataset checks above run before the model module is loaded.
    from app.analyzers.ml_models.donut_model import DonutCertificateParser

    print("Initializing Donut model...")
    donut_parser = DonutCertificateParser()

    # Train model
    # NOTE(review): --epochs and --batch-size are parsed (and epochs is
    # printed) but neither is passed to fine_tune() below. Forward them once
    # fine_tune()'s signature has been confirmed to accept them.
    print(f"Starting training for {args.epochs} epochs...")
    try:
        donut_parser.fine_tune(
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            output_dir=args.output_dir,
        )
        print("\n✅ Training completed successfully!")
        print(f" Model saved to: {args.output_dir}")
    except Exception as e:
        # Top-level boundary: report the failure and signal it via exit code.
        print(f"\n❌ Training failed: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()
|