""" DeepSeek OCR Fine-tuning for Sanskrit - Simplified Version Works with transformers 4.45.0, peft, accelerate """ import os import csv import torch import torchvision.transforms as T from glob import glob from pathlib import Path from PIL import Image, ImageOps from io import BytesIO from dataclasses import dataclass from typing import Any, Dict, List from peft import LoraConfig, get_peft_model from transformers import AutoModel, AutoProcessor, Trainer, TrainingArguments from datasets import Dataset, DatasetDict import argparse os.environ["TOKENIZERS_PARALLELISM"] = "false" def load_dataset_local(dataset_path, train_size=0.8, val_size=0.1, max_samples=None): """Load dataset from local path""" print(f"Loading dataset from: {dataset_path}") labels_csv = os.path.join(dataset_path, "LABELS", "labels.csv") labels_dict = {} with open(labels_csv, 'r', encoding='utf-8') as f: reader = csv.reader(f) header = next(reader) for row in reader: if row: labels_dict[row[0]] = row[1] print(f"Loaded {len(labels_dict)} labels") image_paths = sorted(glob(os.path.join(dataset_path, "IMAGES", "*.jpg"))) print(f"Found {len(image_paths)} images") data = [] for img_path in image_paths: img_name = Path(img_path).name if img_name in labels_dict: text = labels_dict[img_name].strip() if text: data.append({"image_path": img_path, "text": text}) print(f"Paired {len(data)} samples") if max_samples and max_samples < len(data): data = data[:max_samples] dataset = Dataset.from_list(data) # Split train_test = dataset.train_test_split(test_size=(1 - train_size), seed=42) val_test_ratio = val_size / (1 - train_size) val_test = train_test['test'].train_test_split(test_size=(1 - val_test_ratio), seed=42) return DatasetDict({ 'train': train_test['train'], 'validation': val_test['train'], 'test': val_test['test'] }) class ImageTransform: """Image transform for normalization.""" def __init__(self, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)): self.mean = mean self.std = std self.transform = T.Compose([ T.ToTensor(), T.Normalize(mean=mean, std=std) ]) def __call__(self, image): return self.transform(image).float() @dataclass class DeepSeekOCRDataCollator: """Custom data collator for DeepSeek-OCR training""" tokenizer: Any image_size: int = 640 base_size: int = 1024 prompt: str = "\nFree OCR. " def __post_init__(self): self.image_transform = ImageTransform() self.image_token_id = 128815 self.patch_size = 16 self.downsample_ratio = 4 def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: from torch.nn.utils.rnn import pad_sequence import math batch_input_ids = [] batch_labels = [] batch_images = [] batch_images_seq_mask = [] batch_images_spatial_crop = [] for feature in features: image_path = feature["image_path"] text = feature["text"] # Load and process image image = Image.open(image_path).convert("RGB") # Create global view global_view = ImageOps.pad( image, (self.base_size, self.base_size), color=(128, 128, 128) ) image_tensor = self.image_transform(global_view) # Create empty patches tensor (no local crops for simplicity) empty_patches = torch.zeros(1, 3, self.image_size, self.image_size) # Build prompt full_text = f"<|User|>{self.prompt}<|Assistant|>{text}" # Tokenize tokens = self.tokenizer.encode(full_text, add_special_tokens=False) # Calculate image token positions num_queries = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio) num_image_tokens = (num_queries + 1) * num_queries + 1 # Build input_ids with image tokens input_ids = [0] # BOS images_seq_mask = [False] # Add image tokens input_ids.extend([self.image_token_id] * num_image_tokens) images_seq_mask.extend([True] * num_image_tokens) # Add text tokens input_ids.extend(tokens) images_seq_mask.extend([False] * len(tokens)) # Add EOS input_ids.append(1) images_seq_mask.append(False) batch_input_ids.append(torch.tensor(input_ids, dtype=torch.long)) batch_labels.append(torch.tensor(input_ids, dtype=torch.long)) # Model expects (patches, original) tuple batch_images.append((empty_patches, image_tensor.unsqueeze(0))) batch_images_seq_mask.append(torch.tensor(images_seq_mask, dtype=torch.bool)) # Spatial crop shape: (height_crops, width_crops) batch_images_spatial_crop.append(torch.tensor([1, 1], dtype=torch.long)) # Pad sequences input_ids = pad_sequence(batch_input_ids, batch_first=True, padding_value=0) labels = pad_sequence(batch_labels, batch_first=True, padding_value=-100) attention_mask = (input_ids != 0).long() images_seq_mask = pad_sequence(batch_images_seq_mask, batch_first=True, padding_value=False) images_spatial_crop = torch.stack(batch_images_spatial_crop) return { "input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "images": batch_images, "images_seq_mask": images_seq_mask, "images_spatial_crop": images_spatial_crop, } def main(): parser = argparse.ArgumentParser() parser.add_argument("--model_dir", type=str, default="deepseek_ocr") parser.add_argument("--dataset_path", type=str, required=True) parser.add_argument("--output_dir", type=str, default="./results") parser.add_argument("--lora_output", type=str, default="./lora_model_v2") parser.add_argument("--epochs", type=int, default=2) parser.add_argument("--batch_size", type=int, default=2) parser.add_argument("--gradient_accumulation", type=int, default=4) parser.add_argument("--learning_rate", type=float, default=2e-4) parser.add_argument("--max_samples", type=int, default=None) args = parser.parse_args() # Load dataset dataset = load_dataset_local( args.dataset_path, max_samples=args.max_samples ) print(f"Train: {len(dataset['train'])}, Val: {len(dataset['validation'])}") # Load model print("Loading model...") model = AutoModel.from_pretrained( args.model_dir, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto", ) processor = AutoProcessor.from_pretrained( args.model_dir, trust_remote_code=True ) # Setup LoRA print("Setting up LoRA...") lora_config = LoraConfig( r=16, lora_alpha=16, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], lora_dropout=0, bias="none", ) model = get_peft_model(model, lora_config) model.print_trainable_parameters() # Ensure model is in training mode model.train() # Enable gradients for base model for param in model.parameters(): param.requires_grad = False for name, param in model.named_parameters(): if 'lora' in name.lower(): param.requires_grad = True # Training args training_args = TrainingArguments( output_dir=args.output_dir, per_device_train_batch_size=args.batch_size, gradient_accumulation_steps=args.gradient_accumulation, num_train_epochs=args.epochs, learning_rate=args.learning_rate, bf16=True, logging_steps=10, save_strategy="epoch", eval_strategy="epoch", warmup_steps=50, weight_decay=0.01, lr_scheduler_type="cosine", remove_unused_columns=False, dataloader_num_workers=0, # Avoid multiprocessing issues gradient_checkpointing=False, # Disable - causes issues with this model ) # Data collator - processor is the tokenizer for DeepSeek-OCR collator = DeepSeekOCRDataCollator(processor) # Trainer trainer = Trainer( model=model, args=training_args, train_dataset=dataset['train'], eval_dataset=dataset['validation'], data_collator=collator, ) # Train print("Starting training...") trainer.train() # Save print(f"Saving to {args.lora_output}...") model.save_pretrained(args.lora_output) processor.save_pretrained(args.lora_output) print("Done!") if __name__ == "__main__": main()