| | """ |
| | DeepSeek OCR Fine-tuning for Sanskrit - Simplified Version |
| | Works with transformers 4.45.0, peft, accelerate |
| | """ |
| |
|
| | import os |
| | import csv |
| | import torch |
| | import torchvision.transforms as T |
| | from glob import glob |
| | from pathlib import Path |
| | from PIL import Image, ImageOps |
| | from io import BytesIO |
| | from dataclasses import dataclass |
| | from typing import Any, Dict, List |
| |
|
| | from peft import LoraConfig, get_peft_model |
| | from transformers import AutoModel, AutoProcessor, Trainer, TrainingArguments |
| | from datasets import Dataset, DatasetDict |
| | import argparse |
| |
|
| | os.environ["TOKENIZERS_PARALLELISM"] = "false" |
| |
|
| |
|
| | def load_dataset_local(dataset_path, train_size=0.8, val_size=0.1, max_samples=None): |
| | """Load dataset from local path""" |
| | print(f"Loading dataset from: {dataset_path}") |
| | |
| | labels_csv = os.path.join(dataset_path, "LABELS", "labels.csv") |
| | labels_dict = {} |
| | |
| | with open(labels_csv, 'r', encoding='utf-8') as f: |
| | reader = csv.reader(f) |
| | header = next(reader) |
| | for row in reader: |
| | if row: |
| | labels_dict[row[0]] = row[1] |
| | |
| | print(f"Loaded {len(labels_dict)} labels") |
| | |
| | image_paths = sorted(glob(os.path.join(dataset_path, "IMAGES", "*.jpg"))) |
| | print(f"Found {len(image_paths)} images") |
| | |
| | data = [] |
| | for img_path in image_paths: |
| | img_name = Path(img_path).name |
| | if img_name in labels_dict: |
| | text = labels_dict[img_name].strip() |
| | if text: |
| | data.append({"image_path": img_path, "text": text}) |
| | |
| | print(f"Paired {len(data)} samples") |
| | |
| | if max_samples and max_samples < len(data): |
| | data = data[:max_samples] |
| | |
| | dataset = Dataset.from_list(data) |
| | |
| | |
| | train_test = dataset.train_test_split(test_size=(1 - train_size), seed=42) |
| | val_test_ratio = val_size / (1 - train_size) |
| | val_test = train_test['test'].train_test_split(test_size=(1 - val_test_ratio), seed=42) |
| | |
| | return DatasetDict({ |
| | 'train': train_test['train'], |
| | 'validation': val_test['train'], |
| | 'test': val_test['test'] |
| | }) |
| |
|
| |
|
| | class ImageTransform: |
| | """Image transform for normalization.""" |
| | def __init__(self, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)): |
| | self.mean = mean |
| | self.std = std |
| | self.transform = T.Compose([ |
| | T.ToTensor(), |
| | T.Normalize(mean=mean, std=std) |
| | ]) |
| | |
| | def __call__(self, image): |
| | return self.transform(image).float() |
| |
|
| |
|
| | @dataclass |
| | class DeepSeekOCRDataCollator: |
| | """Custom data collator for DeepSeek-OCR training""" |
| | tokenizer: Any |
| | image_size: int = 640 |
| | base_size: int = 1024 |
| | prompt: str = "<image>\nFree OCR. " |
| | |
| | def __post_init__(self): |
| | self.image_transform = ImageTransform() |
| | self.image_token_id = 128815 |
| | self.patch_size = 16 |
| | self.downsample_ratio = 4 |
| | |
| | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: |
| | from torch.nn.utils.rnn import pad_sequence |
| | import math |
| | |
| | batch_input_ids = [] |
| | batch_labels = [] |
| | batch_images = [] |
| | batch_images_seq_mask = [] |
| | batch_images_spatial_crop = [] |
| | |
| | for feature in features: |
| | image_path = feature["image_path"] |
| | text = feature["text"] |
| | |
| | |
| | image = Image.open(image_path).convert("RGB") |
| | |
| | |
| | global_view = ImageOps.pad( |
| | image, |
| | (self.base_size, self.base_size), |
| | color=(128, 128, 128) |
| | ) |
| | image_tensor = self.image_transform(global_view) |
| | |
| | |
| | empty_patches = torch.zeros(1, 3, self.image_size, self.image_size) |
| | |
| | |
| | full_text = f"<|User|>{self.prompt}<|Assistant|>{text}" |
| | |
| | |
| | tokens = self.tokenizer.encode(full_text, add_special_tokens=False) |
| | |
| | |
| | num_queries = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio) |
| | num_image_tokens = (num_queries + 1) * num_queries + 1 |
| | |
| | |
| | input_ids = [0] |
| | images_seq_mask = [False] |
| | |
| | |
| | input_ids.extend([self.image_token_id] * num_image_tokens) |
| | images_seq_mask.extend([True] * num_image_tokens) |
| | |
| | |
| | input_ids.extend(tokens) |
| | images_seq_mask.extend([False] * len(tokens)) |
| | |
| | |
| | input_ids.append(1) |
| | images_seq_mask.append(False) |
| | |
| | batch_input_ids.append(torch.tensor(input_ids, dtype=torch.long)) |
| | batch_labels.append(torch.tensor(input_ids, dtype=torch.long)) |
| | |
| | batch_images.append((empty_patches, image_tensor.unsqueeze(0))) |
| | batch_images_seq_mask.append(torch.tensor(images_seq_mask, dtype=torch.bool)) |
| | |
| | batch_images_spatial_crop.append(torch.tensor([1, 1], dtype=torch.long)) |
| | |
| | |
| | input_ids = pad_sequence(batch_input_ids, batch_first=True, padding_value=0) |
| | labels = pad_sequence(batch_labels, batch_first=True, padding_value=-100) |
| | attention_mask = (input_ids != 0).long() |
| | images_seq_mask = pad_sequence(batch_images_seq_mask, batch_first=True, padding_value=False) |
| | images_spatial_crop = torch.stack(batch_images_spatial_crop) |
| | |
| | return { |
| | "input_ids": input_ids, |
| | "attention_mask": attention_mask, |
| | "labels": labels, |
| | "images": batch_images, |
| | "images_seq_mask": images_seq_mask, |
| | "images_spatial_crop": images_spatial_crop, |
| | } |
| |
|
| |
|
| | def main(): |
| | parser = argparse.ArgumentParser() |
| | parser.add_argument("--model_dir", type=str, default="deepseek_ocr") |
| | parser.add_argument("--dataset_path", type=str, required=True) |
| | parser.add_argument("--output_dir", type=str, default="./results") |
| | parser.add_argument("--lora_output", type=str, default="./lora_model_v2") |
| | parser.add_argument("--epochs", type=int, default=2) |
| | parser.add_argument("--batch_size", type=int, default=2) |
| | parser.add_argument("--gradient_accumulation", type=int, default=4) |
| | parser.add_argument("--learning_rate", type=float, default=2e-4) |
| | parser.add_argument("--max_samples", type=int, default=None) |
| | args = parser.parse_args() |
| | |
| | |
| | dataset = load_dataset_local( |
| | args.dataset_path, |
| | max_samples=args.max_samples |
| | ) |
| | print(f"Train: {len(dataset['train'])}, Val: {len(dataset['validation'])}") |
| | |
| | |
| | print("Loading model...") |
| | model = AutoModel.from_pretrained( |
| | args.model_dir, |
| | trust_remote_code=True, |
| | torch_dtype=torch.bfloat16, |
| | device_map="auto", |
| | ) |
| | |
| | processor = AutoProcessor.from_pretrained( |
| | args.model_dir, |
| | trust_remote_code=True |
| | ) |
| | |
| | |
| | print("Setting up LoRA...") |
| | lora_config = LoraConfig( |
| | r=16, |
| | lora_alpha=16, |
| | target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], |
| | lora_dropout=0, |
| | bias="none", |
| | ) |
| | |
| | model = get_peft_model(model, lora_config) |
| | model.print_trainable_parameters() |
| | |
| | |
| | model.train() |
| | |
| | |
| | for param in model.parameters(): |
| | param.requires_grad = False |
| | for name, param in model.named_parameters(): |
| | if 'lora' in name.lower(): |
| | param.requires_grad = True |
| | |
| | |
| | training_args = TrainingArguments( |
| | output_dir=args.output_dir, |
| | per_device_train_batch_size=args.batch_size, |
| | gradient_accumulation_steps=args.gradient_accumulation, |
| | num_train_epochs=args.epochs, |
| | learning_rate=args.learning_rate, |
| | bf16=True, |
| | logging_steps=10, |
| | save_strategy="epoch", |
| | eval_strategy="epoch", |
| | warmup_steps=50, |
| | weight_decay=0.01, |
| | lr_scheduler_type="cosine", |
| | remove_unused_columns=False, |
| | dataloader_num_workers=0, |
| | gradient_checkpointing=False, |
| | ) |
| | |
| | |
| | collator = DeepSeekOCRDataCollator(processor) |
| | |
| | |
| | trainer = Trainer( |
| | model=model, |
| | args=training_args, |
| | train_dataset=dataset['train'], |
| | eval_dataset=dataset['validation'], |
| | data_collator=collator, |
| | ) |
| | |
| | |
| | print("Starting training...") |
| | trainer.train() |
| | |
| | |
| | print(f"Saving to {args.lora_output}...") |
| | model.save_pretrained(args.lora_output) |
| | processor.save_pretrained(args.lora_output) |
| | |
| | print("Done!") |
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|