arpitingle committed on
Commit
9aba698
·
1 Parent(s): bd3da16

v2: retrained on full dataset (6265 samples, 2 epochs, loss 4.7)

Browse files

- Retrained LoRA on complete Sanskrit_OCR_Parallel_Corpus (previously only 55%)
- Added train_v2.py: simplified training script without Unsloth dependency
- Fixed inference.py: use absolute model paths
- Updated run.py: added --local_dataset flag, fixed model_dir references

adapter_config.json CHANGED
@@ -4,10 +4,9 @@
4
  "arrow_config": null,
5
  "auto_mapping": {
6
  "base_model_class": "DeepseekOCRForCausalLM",
7
- "parent_library": "transformers_modules.deepseek_ocr.modeling_deepseekocr",
8
- "unsloth_fixed": true
9
  },
10
- "base_model_name_or_path": "deepseek_ocr",
11
  "bias": "none",
12
  "corda_config": null,
13
  "ensure_weight_tying": false,
@@ -33,16 +32,16 @@
33
  "rank_pattern": {},
34
  "revision": null,
35
  "target_modules": [
 
 
 
36
  "v_proj",
37
  "gate_proj",
38
- "q_proj",
39
  "k_proj",
40
- "o_proj",
41
- "down_proj",
42
- "up_proj"
43
  ],
44
  "target_parameters": null,
45
- "task_type": "CAUSAL_LM",
46
  "trainable_token_indices": null,
47
  "use_dora": false,
48
  "use_qalora": false,
 
4
  "arrow_config": null,
5
  "auto_mapping": {
6
  "base_model_class": "DeepseekOCRForCausalLM",
7
+ "parent_library": "transformers_modules.deepseek_ocr.modeling_deepseekocr"
 
8
  },
9
+ "base_model_name_or_path": "/home/ubuntu/deepseek_ocr",
10
  "bias": "none",
11
  "corda_config": null,
12
  "ensure_weight_tying": false,
 
32
  "rank_pattern": {},
33
  "revision": null,
34
  "target_modules": [
35
+ "up_proj",
36
+ "q_proj",
37
+ "o_proj",
38
  "v_proj",
39
  "gate_proj",
 
40
  "k_proj",
41
+ "down_proj"
 
 
42
  ],
43
  "target_parameters": null,
44
+ "task_type": null,
45
  "trainable_token_indices": null,
46
  "use_dora": false,
47
  "use_qalora": false,
adapter_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:997072bf68d91539c958713abe5e5a3b1baf2cbd5af3749919256d2b0c34bbcf
3
  size 310662536
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f95dfcbb52e9a0e95dfdc7754457c66b117e2f486ec214554b102aea6b78b9c
3
  size 310662536
inference.py CHANGED
@@ -58,11 +58,11 @@ def load_model_with_lora(base_model_path="deepseek_ocr", lora_path="./lora_model
58
  return model
59
 
60
 
61
- def run_inference(model, image_path, prompt="<image>\nFree OCR. "):
62
  print(f"Running inference on: {image_path}")
63
 
64
  processor = AutoProcessor.from_pretrained(
65
- "deepseek_ocr",
66
  trust_remote_code=True,
67
  )
68
 
@@ -99,7 +99,7 @@ if __name__ == "__main__":
99
 
100
  model = load_model_with_lora(args.base_model, args.lora)
101
 
102
- raw = run_inference(model, args.image)
103
 
104
  cleaned = clean_text(raw)
105
 
 
58
  return model
59
 
60
 
61
+ def run_inference(model, image_path, base_model_path="deepseek_ocr", prompt="<image>\nFree OCR. "):
62
  print(f"Running inference on: {image_path}")
63
 
64
  processor = AutoProcessor.from_pretrained(
65
+ base_model_path,
66
  trust_remote_code=True,
67
  )
68
 
 
99
 
100
  model = load_model_with_lora(args.base_model, args.lora)
101
 
102
+ raw = run_inference(model, args.image, args.base_model)
103
 
104
  cleaned = clean_text(raw)
105
 
run.py CHANGED
@@ -61,6 +61,7 @@ def load_model(model_path="deepseek_ocr", load_in_4bit=False):
61
  trust_remote_code=True,
62
  unsloth_force_compile=True,
63
  use_gradient_checkpointing="unsloth",
 
64
  )
65
 
66
  print("Model and tokenizer loaded successfully!")
@@ -88,36 +89,42 @@ def setup_lora(model):
88
  return model
89
 
90
  def load_and_prepare_dataset(dataset_name="snskrt/Sanskrit_OCR_Parallel_Corpus",
91
- train_size=0.8, val_size=0.1, max_samples=None, token=None):
 
92
  """
93
  Load and prepare the Sanskrit OCR dataset.
94
- This function downloads the entire repo, reads 'LABELS/labels.csv',
95
- and pairs images from the 'IMAGES/' folder.
96
  """
97
  import time
98
  print(f"Loading dataset: {dataset_name}")
99
 
100
  try:
101
- # 1. Download the entire dataset as a snapshot first with retry logic
102
- print("Downloading dataset snapshot (this may take a while)...")
103
- max_retries = 5
104
- for attempt in range(max_retries):
105
- try:
106
- dataset_path = snapshot_download(
107
- repo_id=dataset_name,
108
- repo_type="dataset",
109
- token=token,
110
- max_workers=1
111
- )
112
- break
113
- except Exception as e:
114
- if "429" in str(e) and attempt < max_retries - 1:
115
- wait_time = 60 * (attempt + 1)
116
- print(f"Rate limited. Waiting {wait_time} seconds before retry {attempt + 2}/{max_retries}...")
117
- time.sleep(wait_time)
118
- else:
119
- raise
120
- print(f"Dataset downloaded to: {dataset_path}")
 
 
 
 
 
121
 
122
  # 2. Read the labels.csv file from the LABELS directory
123
  labels_csv_path = os.path.join(dataset_path, "LABELS", "labels.csv")
@@ -294,7 +301,8 @@ def train_model(model, tokenizer, train_data, val_data,
294
  per_device_train_batch_size=2,
295
  gradient_accumulation_steps=4,
296
  learning_rate=2e-4,
297
- max_steps=None):
 
298
  """Train the model"""
299
  print("Starting training...")
300
 
@@ -337,7 +345,7 @@ def train_model(model, tokenizer, train_data, val_data,
337
  )
338
 
339
  # Load tokenizer for the data collator
340
- tokenizer_for_collator = AutoProcessor.from_pretrained("deepseek_ocr", trust_remote_code=True)
341
 
342
  # Import preprocessing functions from the cached model
343
  import sys
@@ -657,6 +665,8 @@ def main():
657
  help="HuggingFace token for authenticated access")
658
  parser.add_argument("--inspect_only", action="store_true",
659
  help="Only inspect dataset structure without training")
 
 
660
 
661
  args = parser.parse_args()
662
 
@@ -683,7 +693,8 @@ def main():
683
  train_size=args.train_size,
684
  val_size=args.val_size,
685
  max_samples=args.max_samples,
686
- token=hf_token
 
687
  )
688
 
689
  # If inspect only, exit here
@@ -714,7 +725,8 @@ def main():
714
  per_device_train_batch_size=args.batch_size,
715
  gradient_accumulation_steps=args.gradient_accumulation,
716
  learning_rate=args.learning_rate,
717
- max_steps=args.max_steps
 
718
  )
719
 
720
  # Step 7: Save model
 
61
  trust_remote_code=True,
62
  unsloth_force_compile=True,
63
  use_gradient_checkpointing="unsloth",
64
+ attn_implementation="eager",
65
  )
66
 
67
  print("Model and tokenizer loaded successfully!")
 
89
  return model
90
 
91
  def load_and_prepare_dataset(dataset_name="snskrt/Sanskrit_OCR_Parallel_Corpus",
92
+ train_size=0.8, val_size=0.1, max_samples=None, token=None,
93
+ local_path=None):
94
  """
95
  Load and prepare the Sanskrit OCR dataset.
96
+ This function reads 'LABELS/labels.csv' and pairs images from the 'IMAGES/' folder.
97
+ If local_path is provided, uses local dataset instead of downloading.
98
  """
99
  import time
100
  print(f"Loading dataset: {dataset_name}")
101
 
102
  try:
103
+ if local_path and os.path.exists(local_path):
104
+ # Use local dataset
105
+ dataset_path = local_path
106
+ print(f"Using local dataset from: {dataset_path}")
107
+ else:
108
+ # Download the entire dataset as a snapshot first with retry logic
109
+ print("Downloading dataset snapshot (this may take a while)...")
110
+ max_retries = 5
111
+ for attempt in range(max_retries):
112
+ try:
113
+ dataset_path = snapshot_download(
114
+ repo_id=dataset_name,
115
+ repo_type="dataset",
116
+ token=token,
117
+ max_workers=1
118
+ )
119
+ break
120
+ except Exception as e:
121
+ if "429" in str(e) and attempt < max_retries - 1:
122
+ wait_time = 60 * (attempt + 1)
123
+ print(f"Rate limited. Waiting {wait_time} seconds before retry {attempt + 2}/{max_retries}...")
124
+ time.sleep(wait_time)
125
+ else:
126
+ raise
127
+ print(f"Dataset downloaded to: {dataset_path}")
128
 
129
  # 2. Read the labels.csv file from the LABELS directory
130
  labels_csv_path = os.path.join(dataset_path, "LABELS", "labels.csv")
 
301
  per_device_train_batch_size=2,
302
  gradient_accumulation_steps=4,
303
  learning_rate=2e-4,
304
+ max_steps=None,
305
+ model_dir="deepseek_ocr"):
306
  """Train the model"""
307
  print("Starting training...")
308
 
 
345
  )
346
 
347
  # Load tokenizer for the data collator
348
+ tokenizer_for_collator = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)
349
 
350
  # Import preprocessing functions from the cached model
351
  import sys
 
665
  help="HuggingFace token for authenticated access")
666
  parser.add_argument("--inspect_only", action="store_true",
667
  help="Only inspect dataset structure without training")
668
+ parser.add_argument("--local_dataset", type=str, default=None,
669
+ help="Path to local dataset directory (avoids re-downloading)")
670
 
671
  args = parser.parse_args()
672
 
 
693
  train_size=args.train_size,
694
  val_size=args.val_size,
695
  max_samples=args.max_samples,
696
+ token=hf_token,
697
+ local_path=args.local_dataset
698
  )
699
 
700
  # If inspect only, exit here
 
725
  per_device_train_batch_size=args.batch_size,
726
  gradient_accumulation_steps=args.gradient_accumulation,
727
  learning_rate=args.learning_rate,
728
+ max_steps=args.max_steps,
729
+ model_dir=model_path
730
  )
731
 
732
  # Step 7: Save model
tokenizer_config.json CHANGED
@@ -6655,7 +6655,7 @@
6655
  "legacy": true,
6656
  "model_max_length": 1000000000000000019884624838656,
6657
  "pad_token": "<|▁pad▁|>",
6658
- "tokenizer_class": "LlamaTokenizerFast",
6659
  "unk_token": null,
6660
  "use_default_system_prompt": false
6661
  }
 
6655
  "legacy": true,
6656
  "model_max_length": 1000000000000000019884624838656,
6657
  "pad_token": "<|▁pad▁|>",
6658
+ "tokenizer_class": "LlamaTokenizer",
6659
  "unk_token": null,
6660
  "use_default_system_prompt": false
6661
  }
train_v2.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DeepSeek OCR Fine-tuning for Sanskrit - Simplified Version
3
+ Works with transformers 4.45.0, peft, accelerate
4
+ """
5
+
6
+ import os
7
+ import csv
8
+ import torch
9
+ import torchvision.transforms as T
10
+ from glob import glob
11
+ from pathlib import Path
12
+ from PIL import Image, ImageOps
13
+ from io import BytesIO
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List
16
+
17
+ from peft import LoraConfig, get_peft_model
18
+ from transformers import AutoModel, AutoProcessor, Trainer, TrainingArguments
19
+ from datasets import Dataset, DatasetDict
20
+ import argparse
21
+
22
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
23
+
24
+
25
def load_dataset_local(dataset_path, train_size=0.8, val_size=0.1, max_samples=None):
    """Load the Sanskrit OCR dataset from a local directory.

    Expects the Hugging Face snapshot layout:
        <dataset_path>/LABELS/labels.csv   (rows of: image_filename, label_text)
        <dataset_path>/IMAGES/*.jpg

    Args:
        dataset_path: Root directory of the dataset.
        train_size: Fraction of paired samples used for the train split.
        val_size: Fraction used for the validation split (the remainder
            of the holdout becomes the test split).
        max_samples: Optional cap on the number of paired samples.

    Returns:
        DatasetDict with 'train', 'validation' and 'test' splits.
    """
    print(f"Loading dataset from: {dataset_path}")

    labels_csv = os.path.join(dataset_path, "LABELS", "labels.csv")
    labels_dict = {}

    with open(labels_csv, 'r', encoding='utf-8') as f:
        reader = csv.reader(f)
        # Skip the header row; next(reader, None) tolerates an empty file
        # (bare next(reader) raised StopIteration on an empty CSV).
        next(reader, None)
        for row in reader:
            # Guard against blank or malformed rows: the original checked
            # only `if row:`, so a single-column row raised IndexError.
            if len(row) >= 2:
                labels_dict[row[0]] = row[1]

    print(f"Loaded {len(labels_dict)} labels")

    image_paths = sorted(glob(os.path.join(dataset_path, "IMAGES", "*.jpg")))
    print(f"Found {len(image_paths)} images")

    data = []
    for img_path in image_paths:
        img_name = Path(img_path).name
        if img_name in labels_dict:
            text = labels_dict[img_name].strip()
            if text:  # drop samples whose label is empty/whitespace-only
                data.append({"image_path": img_path, "text": text})

    print(f"Paired {len(data)} samples")

    if max_samples and max_samples < len(data):
        data = data[:max_samples]

    dataset = Dataset.from_list(data)

    # Carve off the train split first, then split the holdout into
    # validation and test. val_test_ratio is the share of the holdout
    # that becomes validation data (0.5 with the defaults: 10%/10%).
    train_test = dataset.train_test_split(test_size=(1 - train_size), seed=42)
    val_test_ratio = val_size / (1 - train_size)
    val_test = train_test['test'].train_test_split(test_size=(1 - val_test_ratio), seed=42)

    return DatasetDict({
        'train': train_test['train'],
        'validation': val_test['train'],
        'test': val_test['test']
    })
69
+
70
+
71
class ImageTransform:
    """Convert a PIL image into a normalized float tensor.

    Applies ToTensor followed by per-channel normalization with the
    given mean/std; the (0.5, 0.5, 0.5) defaults map pixel values into
    roughly the [-1, 1] range.
    """

    def __init__(self, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
        self.mean = mean
        self.std = std
        steps = [
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ]
        self.transform = T.Compose(steps)

    def __call__(self, image):
        tensor = self.transform(image)
        return tensor.float()
83
+
84
+
85
@dataclass
class DeepSeekOCRDataCollator:
    """Custom data collator for DeepSeek-OCR LoRA training.

    Builds one padded batch from raw (image_path, text) samples: loads
    each image, pads it to a square "global view", tokenizes the
    prompt+target text, and inserts the placeholder image tokens that
    the vision encoder output is spliced into by the model.
    """
    tokenizer: Any              # processor/tokenizer exposing .encode()
    image_size: int = 640       # side length of the (empty) local-crop patches
    base_size: int = 1024       # side length of the padded global view
    prompt: str = "<image>\nFree OCR. "

    def __post_init__(self):
        self.image_transform = ImageTransform()
        # NOTE(review): 128815 is assumed to be the model's <image>
        # placeholder token id — confirm against the tokenizer config.
        self.image_token_id = 128815
        self.patch_size = 16
        self.downsample_ratio = 4

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        from torch.nn.utils.rnn import pad_sequence
        import math

        batch_input_ids = []
        batch_labels = []
        batch_images = []
        batch_images_seq_mask = []
        batch_images_spatial_crop = []

        for feature in features:
            image_path = feature["image_path"]
            text = feature["text"]

            # Load and process image
            image = Image.open(image_path).convert("RGB")

            # Pad to a square global view on a neutral gray background.
            global_view = ImageOps.pad(
                image,
                (self.base_size, self.base_size),
                color=(128, 128, 128)
            )
            image_tensor = self.image_transform(global_view)

            # Create empty patches tensor (no local crops for simplicity)
            empty_patches = torch.zeros(1, 3, self.image_size, self.image_size)

            # Build prompt
            full_text = f"<|User|>{self.prompt}<|Assistant|>{text}"

            # Tokenize
            tokens = self.tokenizer.encode(full_text, add_special_tokens=False)

            # Vision-query positions per side after downsampling, plus one
            # extra token per row and a trailing separator (273 total for
            # base_size=1024, patch_size=16, downsample_ratio=4).
            num_queries = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio)
            num_image_tokens = (num_queries + 1) * num_queries + 1

            # Build input_ids with image tokens
            input_ids = [0]  # BOS (assumed id 0 — confirm with tokenizer)
            images_seq_mask = [False]

            # Add image tokens
            input_ids.extend([self.image_token_id] * num_image_tokens)
            images_seq_mask.extend([True] * num_image_tokens)

            # Add text tokens
            input_ids.extend(tokens)
            images_seq_mask.extend([False] * len(tokens))

            # Add EOS
            input_ids.append(1)
            images_seq_mask.append(False)

            batch_input_ids.append(torch.tensor(input_ids, dtype=torch.long))
            # NOTE(review): labels currently include the prompt and image
            # placeholder tokens, so loss is computed over them as well.
            # Masking them to -100 would train on the answer only —
            # confirm which behavior is intended.
            batch_labels.append(torch.tensor(input_ids, dtype=torch.long))
            # Model expects (patches, original) tuple
            batch_images.append((empty_patches, image_tensor.unsqueeze(0)))
            batch_images_seq_mask.append(torch.tensor(images_seq_mask, dtype=torch.bool))
            # Spatial crop shape: (height_crops, width_crops)
            batch_images_spatial_crop.append(torch.tensor([1, 1], dtype=torch.long))

        # Pad sequences
        lengths = [ids.size(0) for ids in batch_input_ids]
        input_ids = pad_sequence(batch_input_ids, batch_first=True, padding_value=0)
        labels = pad_sequence(batch_labels, batch_first=True, padding_value=-100)
        # BUGFIX: the original computed attention_mask as (input_ids != 0),
        # which also masked the BOS token (id 0) at position 0 of every
        # sample, since BOS shares its id with the padding value. Build the
        # mask from the true, pre-padding sequence lengths instead.
        attention_mask = torch.zeros_like(input_ids)
        for i, seq_len in enumerate(lengths):
            attention_mask[i, :seq_len] = 1
        images_seq_mask = pad_sequence(batch_images_seq_mask, batch_first=True, padding_value=False)
        images_spatial_crop = torch.stack(batch_images_spatial_crop)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "images": batch_images,
            "images_seq_mask": images_seq_mask,
            "images_spatial_crop": images_spatial_crop,
        }
176
+
177
+
178
def main():
    """CLI entry point: LoRA fine-tuning of DeepSeek-OCR on a local Sanskrit OCR dataset."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", type=str, default="deepseek_ocr")
    parser.add_argument("--dataset_path", type=str, required=True)
    parser.add_argument("--output_dir", type=str, default="./results")
    parser.add_argument("--lora_output", type=str, default="./lora_model_v2")
    parser.add_argument("--epochs", type=int, default=2)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--gradient_accumulation", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--max_samples", type=int, default=None)
    args = parser.parse_args()

    # Dataset: local LABELS/IMAGES layout, optionally capped.
    dataset = load_dataset_local(args.dataset_path, max_samples=args.max_samples)
    print(f"Train: {len(dataset['train'])}, Val: {len(dataset['validation'])}")

    # Base model + processor (custom code in the model repo).
    print("Loading model...")
    model = AutoModel.from_pretrained(
        args.model_dir,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    processor = AutoProcessor.from_pretrained(args.model_dir, trust_remote_code=True)

    # Attach LoRA adapters to all attention and MLP projections.
    print("Setting up LoRA...")
    adapter_targets = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    peft_config = LoraConfig(
        r=16,
        lora_alpha=16,
        target_modules=adapter_targets,
        lora_dropout=0,
        bias="none",
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    model.train()

    # Freeze everything except the LoRA adapter weights.
    for param_name, param in model.named_parameters():
        param.requires_grad = 'lora' in param_name.lower()

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation,
        num_train_epochs=args.epochs,
        learning_rate=args.learning_rate,
        bf16=True,
        logging_steps=10,
        save_strategy="epoch",
        eval_strategy="epoch",
        warmup_steps=50,
        weight_decay=0.01,
        lr_scheduler_type="cosine",
        remove_unused_columns=False,
        dataloader_num_workers=0,  # Avoid multiprocessing issues
        gradient_checkpointing=False,  # Disable - causes issues with this model
    )

    # The processor doubles as the tokenizer for DeepSeek-OCR.
    collator = DeepSeekOCRDataCollator(processor)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset['train'],
        eval_dataset=dataset['validation'],
        data_collator=collator,
    )

    print("Starting training...")
    trainer.train()

    print(f"Saving to {args.lora_output}...")
    model.save_pretrained(args.lora_output)
    processor.save_pretrained(args.lora_output)

    print("Done!")
276
+
277
+
278
# Script entry point: only run training when executed directly, not on import.
if __name__ == "__main__":
    main()