iammraat committed on
Commit
b067fca
Β·
verified Β·
1 Parent(s): 186539d

Upload 2 files

Browse files
Files changed (2) hide show
  1. app (1).py +896 -0
  2. requirements.txt +19 -0
app (1).py ADDED
@@ -0,0 +1,896 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # import gradio as gr
3
+ # from ultralytics import YOLO
4
+ # from transformers import TrOCRProcessor, VisionEncoderDecoderModel
5
+ # from PIL import Image, ImageDraw
6
+ # import torch
7
+ # import logging
8
+ # from datetime import datetime
9
+ # import os
10
+ # import warnings
11
+ # import time
12
+
13
+ # # Suppress progress bar and unnecessary logs
14
+ # os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
15
+ # warnings.filterwarnings('ignore')
16
+ # logging.getLogger('transformers').setLevel(logging.ERROR)
17
+ # logging.getLogger('ultralytics').setLevel(logging.ERROR)
18
+
19
+ # # Setup logging
20
+ # logging.basicConfig(
21
+ # level=logging.INFO,
22
+ # format='%(asctime)s - %(levelname)s - %(message)s'
23
+ # )
24
+ # logger = logging.getLogger(__name__)
25
+
26
+ # logger.info("Starting model loading...")
27
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
28
+ # logger.info(f"Using device: {device}")
29
+
30
+ # # --- ROBUST MODEL LOADING FUNCTION ---
31
+ # def load_model_with_retry(model_class, model_name, token=None, retries=5, delay=5):
32
+ # """Attempts to load a HF model with retries to handle network timeouts."""
33
+ # for attempt in range(retries):
34
+ # try:
35
+ # logger.info(f"Loading {model_name} (Attempt {attempt + 1}/{retries})...")
36
+ # if "Processor" in str(model_class):
37
+ # return model_class.from_pretrained(model_name, token=token)
38
+ # else:
39
+ # return model_class.from_pretrained(model_name, token=token).to(device)
40
+ # except Exception as e:
41
+ # logger.warning(f"Failed to load {model_name}: {e}")
42
+ # if attempt < retries - 1:
43
+ # logger.info(f"Retrying in {delay} seconds...")
44
+ # time.sleep(delay)
45
+ # else:
46
+ # logger.error(f"Given up on loading {model_name} after {retries} attempts.")
47
+ # raise e
48
+
49
+ # try:
50
+ # # 1. Load YOLO Models (Local Files)
51
+ # region_model_file = 'regions.pt'
52
+ # line_model_file = 'lines.pt'
53
+
54
+ # # Simple check for local files
55
+ # if not os.path.exists(region_model_file):
56
+ # # Check current directory listing just in case
57
+ # for file in os.listdir('.'):
58
+ # if 'region' in file.lower() and file.endswith('.pt'): region_model_file = file
59
+ # elif 'line' in file.lower() and file.endswith('.pt'): line_model_file = file
60
+
61
+ # if not os.path.exists(region_model_file) or not os.path.exists(line_model_file):
62
+ # raise FileNotFoundError("YOLO .pt files (regions.pt/lines.pt) not found.")
63
+
64
+ # logger.info("Loading YOLO models...")
65
+ # region_model = YOLO(region_model_file)
66
+ # line_model = YOLO(line_model_file)
67
+ # logger.info("βœ“ YOLO models loaded")
68
+
69
+ # # 2. Load TrOCR with Retries
70
+ # hf_token = os.getenv("HF_TOKEN")
71
+
72
+ # processor = load_model_with_retry(TrOCRProcessor, "microsoft/trocr-base-handwritten", token=hf_token)
73
+ # logger.info("βœ“ TrOCR processor loaded")
74
+
75
+ # trocr_model = load_model_with_retry(VisionEncoderDecoderModel, "microsoft/trocr-base-handwritten", token=hf_token)
76
+ # logger.info("βœ“ TrOCR model loaded")
77
+
78
+ # logger.info("All models loaded successfully!")
79
+
80
+ # except Exception as e:
81
+ # logger.error(f"CRITICAL ERROR loading models: {str(e)}")
82
+ # raise
83
+
84
+ # # --- OCR HELPER ---
85
+ # def run_trocr(image_slice, processor, model, device):
86
+ # """Runs TrOCR on a single cropped image slice."""
87
+ # pixel_values = processor(images=image_slice, return_tensors="pt").pixel_values.to(device)
88
+ # generated_ids = model.generate(pixel_values)
89
+ # return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
90
+
91
+ # def process_document(image):
92
+ # """Process uploaded document image and extract handwritten text with visualization."""
93
+ # timestamp = datetime.now().strftime("%H:%M:%S")
94
+ # log_output = []
95
+
96
+ # def add_log(message, level="INFO"):
97
+ # log_msg = f"[{timestamp}] {level}: {message}"
98
+ # log_output.append(log_msg)
99
+ # if level == "ERROR":
100
+ # logger.error(message)
101
+ # else:
102
+ # logger.info(message)
103
+
104
+ # add_log("Starting document processing")
105
+
106
+ # if image is None:
107
+ # add_log("No image provided", "ERROR")
108
+ # return None, "Please upload an image", "\n".join(log_output)
109
+
110
+ # try:
111
+ # # Prepare Image
112
+ # if not isinstance(image, Image.Image):
113
+ # img = Image.open(image).convert("RGB")
114
+ # else:
115
+ # img = image.convert("RGB")
116
+
117
+ # # Create a drawing context for the debug image
118
+ # debug_img = img.copy()
119
+ # draw = ImageDraw.Draw(debug_img)
120
+
121
+ # width, height = img.size
122
+ # add_log(f"Image size: {width}x{height} pixels")
123
+
124
+ # all_lines = []
125
+
126
+ # # --- STRATEGY 1: Region Detection ---
127
+ # add_log("Strategy 1: Running region detection...")
128
+ # region_results = region_model(img, conf=0.2, imgsz=1024, verbose=False)
129
+ # regions = region_results[0].boxes
130
+ # num_regions = len(regions)
131
+ # add_log(f"βœ“ Found {num_regions} potential text region(s)")
132
+
133
+ # found_lines_in_regions = False
134
+
135
+ # if num_regions > 0:
136
+ # for region_idx, region in enumerate(regions):
137
+ # add_log(f"Processing region {region_idx + 1}/{num_regions}")
138
+
139
+ # # Get coordinates
140
+ # rx1, ry1, rx2, ry2 = map(int, region.xyxy[0])
141
+
142
+ # # Filter small artifacts
143
+ # if (rx2 - rx1) < 50 or (ry2 - ry1) < 50:
144
+ # add_log(f" Skipping tiny artifact: {rx2-rx1}x{ry2-ry1} px")
145
+ # continue
146
+
147
+ # # Draw GREEN box for Region
148
+ # draw.rectangle([rx1, ry1, rx2, ry2], outline="green", width=5)
149
+
150
+ # # Crop Region
151
+ # region_crop = img.crop((rx1, ry1, rx2, ry2))
152
+
153
+ # # Detect lines in this region
154
+ # line_results = line_model(region_crop, conf=0.2, imgsz=1024, verbose=False)
155
+ # lines = line_results[0].boxes
156
+ # num_lines = len(lines)
157
+ # add_log(f" βœ“ Found {num_lines} line(s) in region")
158
+
159
+ # if num_lines > 0:
160
+ # found_lines_in_regions = True
161
+
162
+ # # Sort lines by Y position
163
+ # lines_sorted = sorted(lines, key=lambda b: b.xyxy[0][1])
164
+
165
+ # for line_idx, line in enumerate(lines_sorted):
166
+ # lx1, ly1, lx2, ly2 = map(int, line.xyxy[0])
167
+
168
+ # # Translate line coordinates back to original image space for drawing
169
+ # global_lx1 = rx1 + lx1
170
+ # global_ly1 = ry1 + ly1
171
+ # global_lx2 = rx1 + lx2
172
+ # global_ly2 = ry1 + ly2
173
+
174
+ # # Draw RED box for Line
175
+ # draw.rectangle([global_lx1, global_ly1, global_lx2, global_ly2], outline="red", width=3)
176
+
177
+ # # OCR
178
+ # line_crop = region_crop.crop((lx1, ly1, lx2, ly2))
179
+ # text = run_trocr(line_crop, processor, trocr_model, device)
180
+ # add_log(f" Line {line_idx + 1}: '{text}'")
181
+ # all_lines.append(text)
182
+
183
+ # # --- STRATEGY 2: Fallback to Full Page ---
184
+ # if not found_lines_in_regions:
185
+ # add_log("⚠️ Region detection yielded no lines. Switching to Fallback Strategy...", "WARNING")
186
+ # add_log("Strategy 2: Running line detection on full page")
187
+
188
+ # line_results = line_model(img, conf=0.2, imgsz=1024, verbose=False)
189
+ # lines = line_results[0].boxes
190
+ # num_lines = len(lines)
191
+ # add_log(f"βœ“ Fallback found {num_lines} line(s) on full page")
192
+
193
+ # if num_lines > 0:
194
+ # lines_sorted = sorted(lines, key=lambda b: b.xyxy[0][1])
195
+
196
+ # for line_idx, line in enumerate(lines_sorted):
197
+ # lx1, ly1, lx2, ly2 = map(int, line.xyxy[0])
198
+
199
+ # # Draw RED box for Line (on full image)
200
+ # draw.rectangle([lx1, ly1, lx2, ly2], outline="red", width=3)
201
+
202
+ # line_crop = img.crop((lx1, ly1, lx2, ly2))
203
+ # text = run_trocr(line_crop, processor, trocr_model, device)
204
+ # add_log(f" Line {line_idx + 1}: '{text}'")
205
+ # all_lines.append(text)
206
+
207
+ # if not all_lines:
208
+ # add_log("Failed to detect any text lines in both strategies", "ERROR")
209
+ # return debug_img, "No text could be extracted.", "\n".join(log_output)
210
+
211
+ # add_log(f"βœ“ Success! Extracted {len(all_lines)} total line(s)")
212
+ # final_text = '\n'.join(all_lines)
213
+
214
+ # return debug_img, final_text, "\n".join(log_output)
215
+
216
+ # except Exception as e:
217
+ # error_msg = f"Error processing image: {str(e)}"
218
+ # add_log(error_msg, "ERROR")
219
+ # logger.exception("Full error traceback:")
220
+ # # Return the original image if debug creation failed
221
+ # return image, f"Error: {str(e)}", "\n".join(log_output)
222
+
223
+ # # Create Gradio interface
224
+ # demo = gr.Interface(
225
+ # fn=process_document,
226
+ # inputs=gr.Image(type="pil", label="Upload Handwritten Document"),
227
+ # outputs=[
228
+ # gr.Image(type="pil", label="Debug Visualization (Green=Region, Red=Lines)"),
229
+ # gr.Textbox(label="Extracted Text", lines=10),
230
+ # gr.Textbox(label="Processing Logs", lines=15)
231
+ # ],
232
+ # title="πŸ“ Handwritten Text Recognition (HTR) with Debugging",
233
+ # description="""
234
+ # Upload an image of a handwritten document.
235
+
236
+ # **Visualization Key:**
237
+ # - 🟩 **Green Box:** The broad region identified as containing text.
238
+ # - πŸŸ₯ **Red Box:** The specific line of text sent to the OCR engine.
239
+ # """,
240
+ # flagging_mode="never",
241
+ # theme=gr.themes.Soft()
242
+ # )
243
+
244
+ # if __name__ == "__main__":
245
+ # logger.info("Launching Gradio interface...")
246
+ # demo.launch()
247
+
248
+
249
+
250
+
251
+
252
+
253
+
254
+
255
+
256
+
257
+
258
+
259
+
260
+
261
+
262
+
263
+ # import gradio as gr
264
+ # from ultralytics import YOLO
265
+ # from transformers import TrOCRProcessor, VisionEncoderDecoderModel
266
+ # from PIL import Image, ImageDraw, ImageFont
267
+ # import torch
268
+ # import logging
269
+ # from datetime import datetime
270
+ # import os
271
+ # import warnings
272
+ # import time
273
+
274
+ # # Suppress progress bar and unnecessary logs
275
+ # os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
276
+ # warnings.filterwarnings('ignore')
277
+ # logging.getLogger('transformers').setLevel(logging.ERROR)
278
+ # logging.getLogger('ultralytics').setLevel(logging.ERROR)
279
+
280
+ # # Setup logging
281
+ # logging.basicConfig(
282
+ # level=logging.INFO,
283
+ # format='%(asctime)s - %(levelname)s - %(message)s'
284
+ # )
285
+ # logger = logging.getLogger(__name__)
286
+
287
+ # logger.info("Starting model loading...")
288
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
289
+ # logger.info(f"Using device: {device}")
290
+
291
+ # # --- ROBUST MODEL LOADING FUNCTION ---
292
+ # def load_model_with_retry(model_class, model_name, token=None, retries=5, delay=5):
293
+ # """Attempts to load a HF model with retries to handle network timeouts."""
294
+ # for attempt in range(retries):
295
+ # try:
296
+ # logger.info(f"Loading {model_name} (Attempt {attempt + 1}/{retries})...")
297
+ # if "Processor" in str(model_class):
298
+ # return model_class.from_pretrained(model_name, token=token)
299
+ # else:
300
+ # return model_class.from_pretrained(model_name, token=token).to(device)
301
+ # except Exception as e:
302
+ # logger.warning(f"Failed to load {model_name}: {e}")
303
+ # if attempt < retries - 1:
304
+ # logger.info(f"Retrying in {delay} seconds...")
305
+ # time.sleep(delay)
306
+ # else:
307
+ # logger.error(f"Given up on loading {model_name} after {retries} attempts.")
308
+ # raise e
309
+
310
+ # try:
311
+ # # 1. Load YOLO Models (Local Files)
312
+ # region_model_file = 'regions.pt'
313
+ # line_model_file = 'lines.pt'
314
+
315
+ # # Simple check for local files
316
+ # if not os.path.exists(region_model_file):
317
+ # for file in os.listdir('.'):
318
+ # if 'region' in file.lower() and file.endswith('.pt'): region_model_file = file
319
+ # elif 'line' in file.lower() and file.endswith('.pt'): line_model_file = file
320
+
321
+ # if not os.path.exists(region_model_file) or not os.path.exists(line_model_file):
322
+ # raise FileNotFoundError("YOLO .pt files (regions.pt/lines.pt) not found.")
323
+
324
+ # logger.info("Loading YOLO models...")
325
+ # region_model = YOLO(region_model_file)
326
+ # line_model = YOLO(line_model_file)
327
+ # logger.info("βœ“ YOLO models loaded")
328
+
329
+ # # 2. Load TrOCR with Retries
330
+ # hf_token = os.getenv("HF_TOKEN")
331
+
332
+ # processor = load_model_with_retry(TrOCRProcessor, "microsoft/trocr-base-handwritten", token=hf_token)
333
+ # logger.info("βœ“ TrOCR processor loaded")
334
+
335
+ # trocr_model = load_model_with_retry(VisionEncoderDecoderModel, "microsoft/trocr-base-handwritten", token=hf_token)
336
+ # logger.info("βœ“ TrOCR model loaded")
337
+
338
+ # logger.info("All models loaded successfully!")
339
+
340
+ # except Exception as e:
341
+ # logger.error(f"CRITICAL ERROR loading models: {str(e)}")
342
+ # raise
343
+
344
+ # # --- OCR HELPER ---
345
+ # def run_trocr(image_slice, processor, model, device):
346
+ # """Runs TrOCR on a single cropped image slice."""
347
+ # pixel_values = processor(images=image_slice, return_tensors="pt").pixel_values.to(device)
348
+ # generated_ids = model.generate(pixel_values)
349
+ # return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
350
+
351
+ # def process_document(image, enable_debug_crops=False):
352
+ # """Process uploaded document image and extract handwritten text with visualization."""
353
+ # timestamp = datetime.now().strftime("%H:%M:%S")
354
+ # log_output = []
355
+
356
+ # def add_log(message, level="INFO"):
357
+ # log_msg = f"[{timestamp}] {level}: {message}"
358
+ # log_output.append(log_msg)
359
+ # if level == "ERROR":
360
+ # logger.error(message)
361
+ # else:
362
+ # logger.info(message)
363
+
364
+ # add_log("Starting document processing")
365
+
366
+ # if image is None:
367
+ # add_log("No image provided", "ERROR")
368
+ # return None, "Please upload an image", "\n".join(log_output)
369
+
370
+ # try:
371
+ # # Prepare Image
372
+ # if not isinstance(image, Image.Image):
373
+ # img = Image.open(image).convert("RGB")
374
+ # else:
375
+ # img = image.convert("RGB")
376
+
377
+ # # Create a drawing context for the debug image
378
+ # debug_img = img.copy()
379
+ # draw = ImageDraw.Draw(debug_img)
380
+
381
+ # width, height = img.size
382
+ # add_log(f"Image size: {width}x{height} pixels")
383
+
384
+ # all_lines = []
385
+ # debug_crops_dir = "debug_crops"
386
+
387
+ # if enable_debug_crops:
388
+ # os.makedirs(debug_crops_dir, exist_ok=True)
389
+ # add_log(f"Debug crops will be saved to {debug_crops_dir}/")
390
+
391
+ # # --- STRATEGY 1: Region Detection ---
392
+ # add_log("Strategy 1: Running region detection...")
393
+ # region_results = region_model(img, conf=0.2, imgsz=1024, verbose=False)
394
+ # regions = region_results[0].boxes
395
+ # num_regions = len(regions)
396
+ # add_log(f"βœ“ Found {num_regions} potential text region(s)")
397
+
398
+ # found_lines_in_regions = False
399
+
400
+ # if num_regions > 0:
401
+ # for region_idx, region in enumerate(regions):
402
+ # add_log(f"Processing region {region_idx + 1}/{num_regions}")
403
+
404
+ # # FIX 1: Use round() instead of int() to minimize precision loss
405
+ # rx1, ry1, rx2, ry2 = map(round, region.xyxy[0].tolist())
406
+
407
+ # # Calculate region dimensions
408
+ # region_width = rx2 - rx1
409
+ # region_height = ry2 - ry1
410
+
411
+ # add_log(f" Region coords: ({rx1}, {ry1}) β†’ ({rx2}, {ry2}), size: {region_width}x{region_height}")
412
+
413
+ # # Filter small artifacts
414
+ # if region_width < 50 or region_height < 50:
415
+ # add_log(f" Skipping tiny artifact: {region_width}x{region_height} px")
416
+ # continue
417
+
418
+ # # FIX 2: Add padding to region crops to avoid edge effects
419
+ # padding = 10
420
+ # padded_rx1 = max(0, rx1 - padding)
421
+ # padded_ry1 = max(0, ry1 - padding)
422
+ # padded_rx2 = min(width, rx2 + padding)
423
+ # padded_ry2 = min(height, ry2 + padding)
424
+
425
+ # add_log(f" Padded coords: ({padded_rx1}, {padded_ry1}) β†’ ({padded_rx2}, {padded_ry2})")
426
+
427
+ # # Draw GREEN box for Region (original bounds, not padded)
428
+ # draw.rectangle([rx1, ry1, rx2, ry2], outline="green", width=5)
429
+
430
+ # # Crop Region with padding
431
+ # region_crop = img.crop((padded_rx1, padded_ry1, padded_rx2, padded_ry2))
432
+
433
+ # if enable_debug_crops:
434
+ # region_crop.save(f"{debug_crops_dir}/region_{region_idx:02d}.png")
435
+
436
+ # # Detect lines in this region
437
+ # add_log(f" Running line detection on region crop ({region_crop.size[0]}x{region_crop.size[1]})...")
438
+ # line_results = line_model(region_crop, conf=0.2, imgsz=1024, verbose=False)
439
+ # lines_data = line_results[0].boxes.xyxy.cpu().numpy()
440
+ # num_lines = len(lines_data)
441
+ # add_log(f" βœ“ Found {num_lines} line(s) in region")
442
+
443
+ # if num_lines > 0:
444
+ # found_lines_in_regions = True
445
+
446
+ # # Sort lines by Y position (index 1 of xyxy)
447
+ # sorted_indices = lines_data[:, 1].argsort()
448
+
449
+ # for line_idx, idx in enumerate(sorted_indices):
450
+ # # FIX 3: Use round() for line coordinates too
451
+ # lx1, ly1, lx2, ly2 = map(round, lines_data[idx].tolist())
452
+
453
+ # line_width = lx2 - lx1
454
+ # line_height = ly2 - ly1
455
+
456
+ # add_log(f" Line {line_idx + 1} (local coords): ({lx1}, {ly1}) β†’ ({lx2}, {ly2}), size: {line_width}x{line_height}")
457
+
458
+ # # FIX 4: Translate line coordinates back to original image space
459
+ # # Account for padding offset
460
+ # global_lx1 = padded_rx1 + lx1
461
+ # global_ly1 = padded_ry1 + ly1
462
+ # global_lx2 = padded_rx1 + lx2
463
+ # global_ly2 = padded_ry1 + ly2
464
+
465
+ # # FIX 5: Validate coordinates are within image bounds
466
+ # global_lx1 = max(0, min(width, global_lx1))
467
+ # global_ly1 = max(0, min(height, global_ly1))
468
+ # global_lx2 = max(0, min(width, global_lx2))
469
+ # global_ly2 = max(0, min(height, global_ly2))
470
+
471
+ # add_log(f" Line {line_idx + 1} (global coords): ({global_lx1}, {global_ly1}) β†’ ({global_lx2}, {global_ly2})")
472
+
473
+ # # Draw RED box for Line
474
+ # draw.rectangle([global_lx1, global_ly1, global_lx2, global_ly2], outline="red", width=3)
475
+
476
+ # # OCR on the line crop from region_crop
477
+ # line_crop = region_crop.crop((lx1, ly1, lx2, ly2))
478
+
479
+ # if enable_debug_crops:
480
+ # line_crop.save(f"{debug_crops_dir}/region_{region_idx:02d}_line_{line_idx:02d}.png")
481
+
482
+ # text = run_trocr(line_crop, processor, trocr_model, device)
483
+ # add_log(f" Line {line_idx + 1} OCR: '{text}'")
484
+ # all_lines.append(text)
485
+
486
+ # # --- STRATEGY 2: Fallback to Full Page ---
487
+ # if not found_lines_in_regions:
488
+ # add_log("⚠️ Region detection yielded no lines. Switching to Fallback Strategy...", "WARNING")
489
+ # add_log("Strategy 2: Running line detection on full page")
490
+
491
+ # line_results = line_model(img, conf=0.2, imgsz=1024, verbose=False)
492
+ # lines_data = line_results[0].boxes.xyxy.cpu().numpy()
493
+ # num_lines = len(lines_data)
494
+ # add_log(f"βœ“ Fallback found {num_lines} line(s) on full page")
495
+
496
+ # if num_lines > 0:
497
+ # sorted_indices = lines_data[:, 1].argsort()
498
+
499
+ # for line_idx, idx in enumerate(sorted_indices):
500
+ # # FIX 6: Use round() consistently
501
+ # lx1, ly1, lx2, ly2 = map(round, lines_data[idx].tolist())
502
+
503
+ # line_width = lx2 - lx1
504
+ # line_height = ly2 - ly1
505
+
506
+ # add_log(f" Fallback Line {line_idx + 1}: ({lx1}, {ly1}) β†’ ({lx2}, {ly2}), size: {line_width}x{line_height}")
507
+
508
+ # # FIX 7: Validate coordinates
509
+ # lx1 = max(0, min(width, lx1))
510
+ # ly1 = max(0, min(height, ly1))
511
+ # lx2 = max(0, min(width, lx2))
512
+ # ly2 = max(0, min(height, ly2))
513
+
514
+ # # Draw RED box for Line (on full image)
515
+ # draw.rectangle([lx1, ly1, lx2, ly2], outline="red", width=3)
516
+
517
+ # line_crop = img.crop((lx1, ly1, lx2, ly2))
518
+
519
+ # if enable_debug_crops:
520
+ # line_crop.save(f"{debug_crops_dir}/fullpage_line_{line_idx:02d}.png")
521
+
522
+ # text = run_trocr(line_crop, processor, trocr_model, device)
523
+ # add_log(f" Fallback Line {line_idx + 1} OCR: '{text}'")
524
+ # all_lines.append(text)
525
+
526
+ # if not all_lines:
527
+ # add_log("Failed to detect any text lines in both strategies", "ERROR")
528
+ # return debug_img, "No text could be extracted.", "\n".join(log_output)
529
+
530
+ # add_log(f"βœ“ Success! Extracted {len(all_lines)} total line(s)")
531
+
532
+ # if enable_debug_crops:
533
+ # add_log(f"βœ“ Debug crops saved to {debug_crops_dir}/")
534
+
535
+ # final_text = '\n'.join(all_lines)
536
+
537
+ # return debug_img, final_text, "\n".join(log_output)
538
+
539
+ # except Exception as e:
540
+ # error_msg = f"Error processing image: {str(e)}"
541
+ # add_log(error_msg, "ERROR")
542
+ # logger.exception("Full error traceback:")
543
+ # return image, f"Error: {str(e)}", "\n".join(log_output)
544
+
545
+ # # Create Gradio interface
546
+ # demo = gr.Interface(
547
+ # fn=process_document,
548
+ # inputs=[
549
+ # gr.Image(type="pil", label="Upload Handwritten Document"),
550
+ # gr.Checkbox(label="Save debug crops to disk", value=False)
551
+ # ],
552
+ # outputs=[
553
+ # gr.Image(type="pil", label="Debug Visualization (Green=Region, Red=Lines)"),
554
+ # gr.Textbox(label="Extracted Text", lines=10),
555
+ # gr.Textbox(label="Processing Logs", lines=15)
556
+ # ],
557
+ # title="πŸ“ Handwritten Text Recognition (HTR) with Enhanced Debugging",
558
+ # description="""
559
+ # Upload an image of a handwritten document.
560
+
561
+ # **Visualization Key:**
562
+ # - 🟩 **Green Box:** The broad region identified as containing text (original bounds).
563
+ # - πŸŸ₯ **Red Box:** The specific line of text sent to the OCR engine (with coordinate validation).
564
+
565
+ # **Improvements:**
566
+ # - Fixed coordinate rounding (eliminates truncation errors)
567
+ # - Added 10px padding to region crops (reduces edge effects)
568
+ # - Coordinate validation (ensures all boxes are within image bounds)
569
+ # - Enhanced logging with detailed coordinate tracking
570
+ # - Optional debug crop saving
571
+ # """,
572
+ # flagging_mode="never",
573
+ # theme=gr.themes.Soft()
574
+ # )
575
+
576
+ # if __name__ == "__main__":
577
+ # logger.info("Launching Gradio interface...")
578
+ # demo.launch()
579
+
580
+
581
+
582
+
583
+
584
+
585
+
586
+
587
+
588
+
589
+
590
+
591
+
592
+
593
+
594
+
595
+
596
+
597
+
598
+
599
+
600
+
601
+ import gradio as gr
602
+ from ultralytics import YOLO
603
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
604
+ from PIL import Image, ImageDraw
605
+ import torch
606
+ import logging
607
+ import os
608
+ import warnings
609
+ import time
610
+ from datetime import datetime
611
+
612
+ # Suppress noisy logs
613
+ os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
614
+ warnings.filterwarnings('ignore')
615
+ logging.getLogger('transformers').setLevel(logging.ERROR)
616
+ logging.getLogger('ultralytics').setLevel(logging.WARNING) # still allow important warnings
617
+
618
+ # Setup clean logging
619
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s | %(levelname)-5s | %(message)s')
620
+ logger = logging.getLogger(__name__)
621
+
622
+ logger.info("Initializing models...")
623
+ device = "cuda" if torch.cuda.is_available() else "cpu"
624
+ logger.info(f"Device: {device}")
625
+
626
def load_with_retry(cls, name, token=None, retries=4, delay=6):
    """Load a Hugging Face model or processor, retrying on transient failures.

    Args:
        cls: Class exposing ``from_pretrained`` (e.g. ``TrOCRProcessor`` or
            ``VisionEncoderDecoderModel``).
        name: Hub repo id to download.
        token: Optional Hugging Face auth token.
        retries: Total number of attempts before giving up.
        delay: Seconds to sleep between attempts.

    Returns:
        The loaded object. Models are moved to the module-level ``device``;
        processors (no ``.to()`` method) are returned as-is.

    Raises:
        RuntimeError: if every attempt fails; chained to the last error so
            the underlying cause is preserved in the traceback.
    """
    last_err = None
    for attempt in range(1, retries + 1):
        try:
            logger.info(f"Loading {name} (attempt {attempt}/{retries})")
            # Processors have no .to(); detect them by class name.
            if "Processor" in str(cls):
                return cls.from_pretrained(name, token=token)
            return cls.from_pretrained(name, token=token).to(device)
        except Exception as e:
            last_err = e
            logger.warning(f"Load failed: {e}")
            if attempt < retries:
                time.sleep(delay)
    # Chain the last failure so debugging info is not lost.
    raise RuntimeError(f"Failed to load {name} after {retries} attempts") from last_err
638
+
639
try:
    # Locate local YOLO weights. Default filenames first; if EITHER is
    # missing, scan the working directory for similarly named .pt files.
    # (Previously the scan only ran when regions.pt was missing, so a
    # present regions.pt with a missing lines.pt failed outright.)
    region_pt = 'regions.pt'
    line_pt = 'lines.pt'

    if not (os.path.exists(region_pt) and os.path.exists(line_pt)):
        for f in os.listdir('.'):
            name = f.lower()
            if not name.endswith('.pt'):
                continue
            # elif so one file cannot claim both roles (e.g. "region_lines.pt")
            if 'region' in name:
                region_pt = f
            elif 'line' in name:
                line_pt = f

    if not all(os.path.exists(p) for p in [region_pt, line_pt]):
        raise FileNotFoundError("Could not find regions.pt and lines.pt (or similar)")

    logger.info("Loading YOLO models...")
    region_model = YOLO(region_pt)   # detects broad text regions
    line_model = YOLO(line_pt)       # detects individual text lines
    logger.info("YOLO models loaded")

    # TrOCR weights come from the Hub; retried because downloads can flake.
    hf_token = os.getenv("HF_TOKEN")
    processor = load_with_retry(TrOCRProcessor, "microsoft/trocr-base-handwritten", hf_token)
    trocr = load_with_retry(VisionEncoderDecoderModel, "microsoft/trocr-base-handwritten", hf_token)
    logger.info("TrOCR loaded β†’ ready")

except Exception as e:
    # A Space without models is useless β€” log with traceback and abort startup.
    logger.error(f"Model loading failed: {e}", exc_info=True)
    raise
666
+
667
+
668
+
669
+
670
+
671
def run_ocr(crop: Image.Image) -> str:
    """Transcribe one cropped line image with TrOCR.

    Crops narrower than 20 px or shorter than 12 px are treated as noise
    and yield an empty string without invoking the model.
    """
    too_small = crop.width < 20 or crop.height < 12
    if too_small:
        return ""
    # Preprocess β†’ generate token ids β†’ decode back to text.
    batch = processor(images=crop, return_tensors="pt")
    pixel_values = batch.pixel_values.to(device)
    token_ids = trocr.generate(pixel_values, max_new_tokens=128)
    decoded = processor.batch_decode(token_ids, skip_special_tokens=True)
    return decoded[0].strip()
677
+
678
+
679
def process_document(
    image,
    enable_debug_crops: bool = False,
    region_imgsz: int = 1024,
    line_imgsz_base: int = 768,
    conf_thresh: float = 0.25,
):
    """Detect and OCR handwritten text in a document image.

    Pipeline: (1) YOLO region detection, then line detection inside each
    region; (2) if no lines are found that way, fall back to line detection
    on the full page. Every detected line crop is sent to TrOCR.

    Args:
        image: PIL image or a path/file-like object accepted by Image.open.
        enable_debug_crops: When True, save every region/line crop to
            ``debug_crops/`` for offline inspection.
        region_imgsz: YOLO inference size for region detection.
        line_imgsz_base: Base YOLO inference size for line detection
            (adapted per crop size).
        conf_thresh: Confidence threshold for both detectors.

    Returns:
        (debug_image, extracted_text, log_text) β€” debug_image is an annotated
        copy of the input (green = regions, red = lines), or None if the
        failure happened before it could be built.
    """
    start_ts = datetime.now().strftime("%H:%M:%S")
    logs = []

    def log(msg: str, level: str = "INFO"):
        # Mirror every UI log line to the server logger.
        line = f"[{start_ts}] {level:5} {msg}"
        logs.append(line)
        if level == "ERROR":
            logger.error(msg)
        else:
            logger.info(msg)

    log("Start processing")

    if image is None:
        log("No image uploaded", "ERROR")
        return None, "Upload an image", "\n".join(logs)

    # Initialized before the try so the except handler below can always
    # reference it (previously it was unbound if e.g. Image.open raised).
    debug_img = None

    try:
        # --- Prepare ---------------------------------------------------------
        if not isinstance(image, Image.Image):
            img = Image.open(image).convert("RGB")
        else:
            img = image.convert("RGB")

        debug_img = img.copy()
        draw = ImageDraw.Draw(debug_img)
        w, h = img.size
        log(f"Input image: {w} Γ— {h} px")

        debug_dir = "debug_crops"
        if enable_debug_crops:
            os.makedirs(debug_dir, exist_ok=True)
            log(f"Debug crops β†’ {debug_dir}/")

        extracted = []

        # --- Strategy 1: Region β†’ Lines --------------------------------------
        log(f"Running region detection (imgsz={region_imgsz}) …")
        res_region = region_model(img, conf=conf_thresh, imgsz=region_imgsz, verbose=False)[0]
        boxes_region = res_region.boxes

        log(f"β†’ {len(boxes_region)} region candidate(s) (conf β‰₯ {conf_thresh})")

        found_any_line = False

        for i, box in enumerate(boxes_region, 1):
            conf = float(box.conf)
            xyxy = box.xyxy[0].cpu().tolist()
            # round() instead of int() to avoid systematic truncation bias
            rx1, ry1, rx2, ry2 = map(round, xyxy)

            rw, rh = rx2 - rx1, ry2 - ry1
            log(f"Region {i}/{len(boxes_region)} conf={conf:.3f} {rx1},{ry1} β†’ {rx2},{ry2} ({rw}Γ—{rh})")

            # Skip tiny detections that are almost certainly artifacts.
            if rw < 60 or rh < 40:
                log(f"  β†’ skipped (too small)")
                continue

            # Pad the crop so strokes at the region edge are not cut off.
            pad = 12
            px1 = max(0, rx1 - pad)
            py1 = max(0, ry1 - pad)
            px2 = min(w, rx2 + pad)
            py2 = min(h, ry2 + pad)

            log(f"  Padded crop: {px1},{py1} β†’ {px2},{py2}")

            # Green box marks the (unpadded) region on the debug image.
            draw.rectangle((rx1, ry1, rx2, ry2), outline="green", width=4)

            crop_region = img.crop((px1, py1, px2, py2))
            crop_w, crop_h = crop_region.size

            if enable_debug_crops:
                crop_region.save(f"{debug_dir}/region_{i:02d}.png")

            # Adaptive line imgsz: bigger crops β†’ bigger inference size.
            line_sz = line_imgsz_base
            if max(crop_w, crop_h) > 1400:
                line_sz = 1280
            elif max(crop_w, crop_h) < 400:
                line_sz = 640

            log(f"  β†’ line detection (imgsz={line_sz}) on {crop_w}Γ—{crop_h} crop …")
            res_line = line_model(crop_region, conf=conf_thresh, imgsz=line_sz, verbose=False)[0]
            line_boxes = res_line.boxes

            log(f"  β†’ {len(line_boxes)} line candidate(s)")

            if len(line_boxes) == 0:
                continue

            found_any_line = True

            # Sort lines top β†’ bottom so the transcript reads in order.
            ys = line_boxes.xyxy[:, 1].cpu().numpy()
            order = ys.argsort()

            for j, idx in enumerate(order, 1):
                conf_line = float(line_boxes.conf[idx])
                lx1, ly1, lx2, ly2 = map(round, line_boxes.xyxy[idx].cpu().tolist())

                lw, lh = lx2 - lx1, ly2 - ly1
                log(f"    Line {j} conf={conf_line:.3f} local {lx1},{ly1} β†’ {lx2},{ly2} ({lw}Γ—{lh})")

                # Translate crop-local coordinates back to full-image space
                # (the crop origin is the PADDED corner, not the region corner).
                gx1 = px1 + lx1
                gy1 = py1 + ly1
                gx2 = px1 + lx2
                gy2 = py1 + ly2

                # Safety clamp to image bounds.
                gx1, gy1 = max(0, gx1), max(0, gy1)
                gx2, gy2 = min(w, gx2), min(h, gy2)

                log(f"    β†’ global {gx1},{gy1} β†’ {gx2},{gy2}")

                # Red box marks the line actually fed to the OCR engine.
                draw.rectangle((gx1, gy1, gx2, gy2), outline="red", width=3)

                line_crop = crop_region.crop((lx1, ly1, lx2, ly2))

                if enable_debug_crops:
                    line_crop.save(f"{debug_dir}/reg{i:02d}_line{j:02d}_conf{conf_line:.2f}.png")

                text = run_ocr(line_crop)
                log(f"    OCR β†’ '{text}'")
                if text:
                    extracted.append(text)

        # --- Strategy 2: Fallback full-page line detection --------------------
        if not found_any_line:
            log("No lines found in regions β†’ fallback: full-page line detection")

            line_sz = 1024 if max(w, h) > 1800 else line_imgsz_base
            log(f"Full-page line detection (imgsz={line_sz}) …")

            res = line_model(img, conf=conf_thresh, imgsz=line_sz, verbose=False)[0]
            boxes = res.boxes

            log(f"β†’ {len(boxes)} line(s) on full page")

            if len(boxes) > 0:
                ys = boxes.xyxy[:, 1].cpu().numpy()
                order = ys.argsort()

                for j, idx in enumerate(order, 1):
                    conf = float(boxes.conf[idx])
                    x1, y1, x2, y2 = map(round, boxes.xyxy[idx].cpu().tolist())
                    log(f"  Line {j} conf={conf:.3f} {x1},{y1} β†’ {x2},{y2}")

                    draw.rectangle((x1,y1,x2,y2), outline="red", width=3)

                    crop = img.crop((x1,y1,x2,y2))

                    if enable_debug_crops:
                        crop.save(f"{debug_dir}/fallback_line{j:02d}_conf{conf:.2f}.png")

                    text = run_ocr(crop)
                    log(f"  OCR β†’ '{text}'")
                    if text:
                        extracted.append(text)

        # --- Finalize ---------------------------------------------------------
        if not extracted:
            msg = "No readable text lines detected in either strategy"
            log(msg, "WARNING")
            return debug_img, msg, "\n".join(logs)

        log(f"Success β€” extracted {len(extracted)} line(s)")
        if enable_debug_crops:
            log(f"Debug crops saved to {debug_dir}/")

        return debug_img, "\n".join(extracted), "\n".join(logs)

    except Exception as e:
        log(f"Processing failed: {e}", "ERROR")
        logger.exception("Traceback:")
        # debug_img is None if the failure happened before it was created.
        return debug_img, f"Error: {str(e)}", "\n".join(logs)
864
+
865
+
866
+
867
+
868
+
869
+
870
+
871
# Gradio UI wiring: the three sliders map to process_document's tuning
# parameters (region imgsz, base line imgsz, confidence threshold).
demo = gr.Interface(
    fn=process_document,
    inputs=[
        gr.Image(type="pil", label="Handwritten document"),
        gr.Checkbox(label="Save debug crops", value=False),
        gr.Slider(640, 1600, step=64, value=1024, label="Region detection size (imgsz)"),
        gr.Slider(512, 1280, step=64, value=768, label="Base line detection size"),
        gr.Slider(0.15, 0.5, step=0.05, value=0.25, label="Confidence threshold"),
    ],
    outputs=[
        gr.Image(label="Debug (green=region, red=line)"),
        gr.Textbox(label="Extracted Text", lines=10),
        gr.Textbox(label="Detailed Logs (copy these if boxes look wrong)", lines=18),
    ],
    title="Handwritten Text β†’ OCR + Debug",
    description=(
        "Green = detected text regions β€’ Red = individual text lines sent to TrOCR\n\n"
        "Copy the **Detailed Logs** if alignment still looks off β€” especially coords, sizes & confidences."
    ),
    theme=gr.themes.Soft(),
    flagging_mode="never",
)
893
+
894
if __name__ == "__main__":
    # Start the Gradio server (blocking call).
    logger.info("Launching interface…")
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # requirements.txt
2
+ #ultralytics
3
+ #transformers
4
+ #torch
5
+ #pillow
6
+ #numpy
7
+ #gradio
8
+
9
+ #pytz
10
+ #huggingface_hub
11
+
12
+
13
+ gradio
14
+ ultralytics
15
+ transformers
16
+ torch
17
+ torchvision
18
+ pillow
19
+ pytz