"""
LoRA fine-tuning of DeepSeek-OCR for Sanskrit (simplified version).
Tested with transformers 4.45.0, peft, and accelerate.
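
Usage (dataset laid out as IMAGES/*.jpg plus LABELS/labels.csv, which is what
load_dataset_local expects):

    python train.py --dataset_path /path/to/dataset --model_dir deepseek_ocr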
"""
import os
import csv
import torch
import torchvision.transforms as T
from glob import glob
from pathlib import Path
from PIL import Image, ImageOps
from io import BytesIO
from dataclasses import dataclass
from typing import Any, Dict, List
from peft import LoraConfig, get_peft_model
from transformers import AutoModel, AutoProcessor, Trainer, TrainingArguments
from datasets import Dataset, DatasetDict
import argparse
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def load_dataset_local(dataset_path, train_size=0.8, val_size=0.1, max_samples=None):
"""Load dataset from local path"""
print(f"Loading dataset from: {dataset_path}")
labels_csv = os.path.join(dataset_path, "LABELS", "labels.csv")
labels_dict = {}
with open(labels_csv, 'r', encoding='utf-8') as f:
reader = csv.reader(f)
        next(reader)  # skip the header row
        for row in reader:
            if len(row) >= 2:  # skip blank or malformed rows
                labels_dict[row[0]] = row[1]
print(f"Loaded {len(labels_dict)} labels")
image_paths = sorted(glob(os.path.join(dataset_path, "IMAGES", "*.jpg")))
print(f"Found {len(image_paths)} images")
data = []
for img_path in image_paths:
img_name = Path(img_path).name
if img_name in labels_dict:
text = labels_dict[img_name].strip()
if text:
data.append({"image_path": img_path, "text": text})
print(f"Paired {len(data)} samples")
if max_samples and max_samples < len(data):
data = data[:max_samples]
dataset = Dataset.from_list(data)
# Split
train_test = dataset.train_test_split(test_size=(1 - train_size), seed=42)
val_test_ratio = val_size / (1 - train_size)
val_test = train_test['test'].train_test_split(test_size=(1 - val_test_ratio), seed=42)
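    # Worked example with the defaults (train_size=0.8, val_size=0.1): the first split
    # holds out 20%, and val_test_ratio = 0.1 / 0.2 = 0.5 halves that pool,
    # giving 80% train / 10% validation / 10% test.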
return DatasetDict({
'train': train_test['train'],
'validation': val_test['train'],
'test': val_test['test']
})
class ImageTransform:
"""Image transform for normalization."""
def __init__(self, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
self.mean = mean
self.std = std
self.transform = T.Compose([
T.ToTensor(),
T.Normalize(mean=mean, std=std)
])
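        # mean=std=0.5 rescales [0, 1] pixel values to [-1, 1]: (x - 0.5) / 0.5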
def __call__(self, image):
return self.transform(image).float()
@dataclass
class DeepSeekOCRDataCollator:
"""Custom data collator for DeepSeek-OCR training"""
tokenizer: Any
image_size: int = 640
base_size: int = 1024
prompt: str = "<image>\nFree OCR. "
def __post_init__(self):
self.image_transform = ImageTransform()
self.image_token_id = 128815
self.patch_size = 16
self.downsample_ratio = 4
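        # Constants taken as given from the DeepSeek-OCR release: 128815 is assumed to
        # be the <image> placeholder id; patch_size/downsample_ratio describe 16px
        # vision patches compressed 4x per side into query tokens.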
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
from torch.nn.utils.rnn import pad_sequence
import math
        batch_input_ids = []
        batch_labels = []
        batch_attention_mask = []
        batch_images = []
        batch_images_seq_mask = []
        batch_images_spatial_crop = []
for feature in features:
image_path = feature["image_path"]
text = feature["text"]
# Load and process image
image = Image.open(image_path).convert("RGB")
# Create global view
global_view = ImageOps.pad(
image,
(self.base_size, self.base_size),
color=(128, 128, 128)
)
image_tensor = self.image_transform(global_view)
# Create empty patches tensor (no local crops for simplicity)
empty_patches = torch.zeros(1, 3, self.image_size, self.image_size)
# Build prompt
full_text = f"<|User|>{self.prompt}<|Assistant|>{text}"
# Tokenize
tokens = self.tokenizer.encode(full_text, add_special_tokens=False)
# Calculate image token positions
num_queries = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio)
num_image_tokens = (num_queries + 1) * num_queries + 1
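            # Worked through with base_size=1024: 1024 // 16 = 64 patches per side,
            # ceil(64 / 4) = 16 queries per side, so (16 + 1) * 16 + 1 = 273 placeholder
            # tokens (16 rows of 16 queries, each row assumed to carry one separator
            # token, plus one trailing token).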
# Build input_ids with image tokens
            input_ids = [0]  # BOS (hardcoded; tokenizer.bos_token_id would be more robust)
            images_seq_mask = [False]
# Add image tokens
input_ids.extend([self.image_token_id] * num_image_tokens)
images_seq_mask.extend([True] * num_image_tokens)
# Add text tokens
input_ids.extend(tokens)
images_seq_mask.extend([False] * len(tokens))
            # Add EOS (hardcoded id 1, mirroring the BOS assumption above)
            input_ids.append(1)
            images_seq_mask.append(False)
            batch_input_ids.append(torch.tensor(input_ids, dtype=torch.long))
            batch_attention_mask.append(torch.ones(len(input_ids), dtype=torch.long))
            # Mask image placeholders so the loss is computed on text tokens only
            label_ids = [-100 if is_image else tok
                         for tok, is_image in zip(input_ids, images_seq_mask)]
            batch_labels.append(torch.tensor(label_ids, dtype=torch.long))
# Model expects (patches, original) tuple
batch_images.append((empty_patches, image_tensor.unsqueeze(0)))
batch_images_seq_mask.append(torch.tensor(images_seq_mask, dtype=torch.bool))
# Spatial crop shape: (height_crops, width_crops)
batch_images_spatial_crop.append(torch.tensor([1, 1], dtype=torch.long))
        # Pad sequences. The attention mask is built from true sequence lengths rather
        # than (input_ids != 0): id 0 doubles as BOS here, so that test would mask out
        # the first token of every sequence.
        input_ids = pad_sequence(batch_input_ids, batch_first=True, padding_value=0)
        labels = pad_sequence(batch_labels, batch_first=True, padding_value=-100)
        attention_mask = pad_sequence(batch_attention_mask, batch_first=True, padding_value=0)
images_seq_mask = pad_sequence(batch_images_seq_mask, batch_first=True, padding_value=False)
images_spatial_crop = torch.stack(batch_images_spatial_crop)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"images": batch_images,
"images_seq_mask": images_seq_mask,
"images_spatial_crop": images_spatial_crop,
}
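# Shape sketch of one collated batch (B samples, T tokens after padding):
#   input_ids / attention_mask / labels: (B, T); labels hold -100 at padding and
#     image-placeholder positions
#   images: list of B (patches, global_view) tuples with shapes
#     (1, 3, 640, 640) and (1, 3, 1024, 1024)
#   images_seq_mask: (B, T) bool, True at image-placeholder positions
#   images_spatial_crop: (B, 2), fixed at (1, 1) since local crops are disabled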
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model_dir", type=str, default="deepseek_ocr")
parser.add_argument("--dataset_path", type=str, required=True)
parser.add_argument("--output_dir", type=str, default="./results")
parser.add_argument("--lora_output", type=str, default="./lora_model_v2")
parser.add_argument("--epochs", type=int, default=2)
parser.add_argument("--batch_size", type=int, default=2)
parser.add_argument("--gradient_accumulation", type=int, default=4)
parser.add_argument("--learning_rate", type=float, default=2e-4)
parser.add_argument("--max_samples", type=int, default=None)
args = parser.parse_args()
# Load dataset
dataset = load_dataset_local(
args.dataset_path,
max_samples=args.max_samples
)
print(f"Train: {len(dataset['train'])}, Val: {len(dataset['validation'])}")
# Load model
print("Loading model...")
model = AutoModel.from_pretrained(
args.model_dir,
trust_remote_code=True,
torch_dtype=torch.bfloat16,
device_map="auto",
)
processor = AutoProcessor.from_pretrained(
args.model_dir,
trust_remote_code=True
)
# Setup LoRA
print("Setting up LoRA...")
lora_config = LoraConfig(
r=16,
lora_alpha=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
lora_dropout=0,
bias="none",
)
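    # Target modules above follow LLaMA-style naming: q/k/v/o_proj are the attention
    # projections, gate/up/down_proj the MLP projections of the language model.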
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
    # Ensure model is in training mode
    model.train()
    # Freeze all weights, then re-enable gradients only on the LoRA adapters
    # (get_peft_model already does this; the explicit loop is a safeguard)
    for param in model.parameters():
        param.requires_grad = False
    for name, param in model.named_parameters():
        if 'lora' in name.lower():
            param.requires_grad = True
# Training args
training_args = TrainingArguments(
output_dir=args.output_dir,
per_device_train_batch_size=args.batch_size,
gradient_accumulation_steps=args.gradient_accumulation,
num_train_epochs=args.epochs,
learning_rate=args.learning_rate,
bf16=True,
logging_steps=10,
save_strategy="epoch",
eval_strategy="epoch",
warmup_steps=50,
weight_decay=0.01,
lr_scheduler_type="cosine",
remove_unused_columns=False,
dataloader_num_workers=0, # Avoid multiprocessing issues
gradient_checkpointing=False, # Disable - causes issues with this model
)
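    # Effective batch size per device: batch_size * gradient_accumulation = 2 * 4 = 8
    # with the defaults.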
# Data collator - processor is the tokenizer for DeepSeek-OCR
collator = DeepSeekOCRDataCollator(processor)
# Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset['train'],
eval_dataset=dataset['validation'],
data_collator=collator,
)
# Train
print("Starting training...")
trainer.train()
# Save
print(f"Saving to {args.lora_output}...")
model.save_pretrained(args.lora_output)
processor.save_pretrained(args.lora_output)
print("Done!")
if __name__ == "__main__":
main()
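
# Reloading the adapter later (a minimal sketch; assumes the default paths above and
# that PeftModel composes with this trust_remote_code model):
#     from peft import PeftModel
#     base = AutoModel.from_pretrained("deepseek_ocr", trust_remote_code=True,
#                                      torch_dtype=torch.bfloat16, device_map="auto")
#     model = PeftModel.from_pretrained(base, "./lora_model_v2")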