iammraat commited on
Commit
479016e
Β·
verified Β·
1 Parent(s): b067fca

Update app (1).py

Browse files
Files changed (1) hide show
  1. app (1).py +129 -841
app (1).py CHANGED
@@ -1,896 +1,184 @@
1
-
2
- # import gradio as gr
3
- # from ultralytics import YOLO
4
- # from transformers import TrOCRProcessor, VisionEncoderDecoderModel
5
- # from PIL import Image, ImageDraw
6
- # import torch
7
- # import logging
8
- # from datetime import datetime
9
- # import os
10
- # import warnings
11
- # import time
12
-
13
- # # Suppress progress bar and unnecessary logs
14
- # os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
15
- # warnings.filterwarnings('ignore')
16
- # logging.getLogger('transformers').setLevel(logging.ERROR)
17
- # logging.getLogger('ultralytics').setLevel(logging.ERROR)
18
-
19
- # # Setup logging
20
- # logging.basicConfig(
21
- # level=logging.INFO,
22
- # format='%(asctime)s - %(levelname)s - %(message)s'
23
- # )
24
- # logger = logging.getLogger(__name__)
25
-
26
- # logger.info("Starting model loading...")
27
- # device = "cuda" if torch.cuda.is_available() else "cpu"
28
- # logger.info(f"Using device: {device}")
29
-
30
- # # --- ROBUST MODEL LOADING FUNCTION ---
31
- # def load_model_with_retry(model_class, model_name, token=None, retries=5, delay=5):
32
- # """Attempts to load a HF model with retries to handle network timeouts."""
33
- # for attempt in range(retries):
34
- # try:
35
- # logger.info(f"Loading {model_name} (Attempt {attempt + 1}/{retries})...")
36
- # if "Processor" in str(model_class):
37
- # return model_class.from_pretrained(model_name, token=token)
38
- # else:
39
- # return model_class.from_pretrained(model_name, token=token).to(device)
40
- # except Exception as e:
41
- # logger.warning(f"Failed to load {model_name}: {e}")
42
- # if attempt < retries - 1:
43
- # logger.info(f"Retrying in {delay} seconds...")
44
- # time.sleep(delay)
45
- # else:
46
- # logger.error(f"Given up on loading {model_name} after {retries} attempts.")
47
- # raise e
48
-
49
- # try:
50
- # # 1. Load YOLO Models (Local Files)
51
- # region_model_file = 'regions.pt'
52
- # line_model_file = 'lines.pt'
53
-
54
- # # Simple check for local files
55
- # if not os.path.exists(region_model_file):
56
- # # Check current directory listing just in case
57
- # for file in os.listdir('.'):
58
- # if 'region' in file.lower() and file.endswith('.pt'): region_model_file = file
59
- # elif 'line' in file.lower() and file.endswith('.pt'): line_model_file = file
60
-
61
- # if not os.path.exists(region_model_file) or not os.path.exists(line_model_file):
62
- # raise FileNotFoundError("YOLO .pt files (regions.pt/lines.pt) not found.")
63
-
64
- # logger.info("Loading YOLO models...")
65
- # region_model = YOLO(region_model_file)
66
- # line_model = YOLO(line_model_file)
67
- # logger.info("βœ“ YOLO models loaded")
68
-
69
- # # 2. Load TrOCR with Retries
70
- # hf_token = os.getenv("HF_TOKEN")
71
-
72
- # processor = load_model_with_retry(TrOCRProcessor, "microsoft/trocr-base-handwritten", token=hf_token)
73
- # logger.info("βœ“ TrOCR processor loaded")
74
-
75
- # trocr_model = load_model_with_retry(VisionEncoderDecoderModel, "microsoft/trocr-base-handwritten", token=hf_token)
76
- # logger.info("βœ“ TrOCR model loaded")
77
-
78
- # logger.info("All models loaded successfully!")
79
-
80
- # except Exception as e:
81
- # logger.error(f"CRITICAL ERROR loading models: {str(e)}")
82
- # raise
83
-
84
- # # --- OCR HELPER ---
85
- # def run_trocr(image_slice, processor, model, device):
86
- # """Runs TrOCR on a single cropped image slice."""
87
- # pixel_values = processor(images=image_slice, return_tensors="pt").pixel_values.to(device)
88
- # generated_ids = model.generate(pixel_values)
89
- # return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
90
-
91
- # def process_document(image):
92
- # """Process uploaded document image and extract handwritten text with visualization."""
93
- # timestamp = datetime.now().strftime("%H:%M:%S")
94
- # log_output = []
95
-
96
- # def add_log(message, level="INFO"):
97
- # log_msg = f"[{timestamp}] {level}: {message}"
98
- # log_output.append(log_msg)
99
- # if level == "ERROR":
100
- # logger.error(message)
101
- # else:
102
- # logger.info(message)
103
-
104
- # add_log("Starting document processing")
105
-
106
- # if image is None:
107
- # add_log("No image provided", "ERROR")
108
- # return None, "Please upload an image", "\n".join(log_output)
109
-
110
- # try:
111
- # # Prepare Image
112
- # if not isinstance(image, Image.Image):
113
- # img = Image.open(image).convert("RGB")
114
- # else:
115
- # img = image.convert("RGB")
116
-
117
- # # Create a drawing context for the debug image
118
- # debug_img = img.copy()
119
- # draw = ImageDraw.Draw(debug_img)
120
-
121
- # width, height = img.size
122
- # add_log(f"Image size: {width}x{height} pixels")
123
-
124
- # all_lines = []
125
-
126
- # # --- STRATEGY 1: Region Detection ---
127
- # add_log("Strategy 1: Running region detection...")
128
- # region_results = region_model(img, conf=0.2, imgsz=1024, verbose=False)
129
- # regions = region_results[0].boxes
130
- # num_regions = len(regions)
131
- # add_log(f"βœ“ Found {num_regions} potential text region(s)")
132
-
133
- # found_lines_in_regions = False
134
-
135
- # if num_regions > 0:
136
- # for region_idx, region in enumerate(regions):
137
- # add_log(f"Processing region {region_idx + 1}/{num_regions}")
138
-
139
- # # Get coordinates
140
- # rx1, ry1, rx2, ry2 = map(int, region.xyxy[0])
141
-
142
- # # Filter small artifacts
143
- # if (rx2 - rx1) < 50 or (ry2 - ry1) < 50:
144
- # add_log(f" Skipping tiny artifact: {rx2-rx1}x{ry2-ry1} px")
145
- # continue
146
-
147
- # # Draw GREEN box for Region
148
- # draw.rectangle([rx1, ry1, rx2, ry2], outline="green", width=5)
149
-
150
- # # Crop Region
151
- # region_crop = img.crop((rx1, ry1, rx2, ry2))
152
-
153
- # # Detect lines in this region
154
- # line_results = line_model(region_crop, conf=0.2, imgsz=1024, verbose=False)
155
- # lines = line_results[0].boxes
156
- # num_lines = len(lines)
157
- # add_log(f" βœ“ Found {num_lines} line(s) in region")
158
-
159
- # if num_lines > 0:
160
- # found_lines_in_regions = True
161
-
162
- # # Sort lines by Y position
163
- # lines_sorted = sorted(lines, key=lambda b: b.xyxy[0][1])
164
-
165
- # for line_idx, line in enumerate(lines_sorted):
166
- # lx1, ly1, lx2, ly2 = map(int, line.xyxy[0])
167
-
168
- # # Translate line coordinates back to original image space for drawing
169
- # global_lx1 = rx1 + lx1
170
- # global_ly1 = ry1 + ly1
171
- # global_lx2 = rx1 + lx2
172
- # global_ly2 = ry1 + ly2
173
-
174
- # # Draw RED box for Line
175
- # draw.rectangle([global_lx1, global_ly1, global_lx2, global_ly2], outline="red", width=3)
176
-
177
- # # OCR
178
- # line_crop = region_crop.crop((lx1, ly1, lx2, ly2))
179
- # text = run_trocr(line_crop, processor, trocr_model, device)
180
- # add_log(f" Line {line_idx + 1}: '{text}'")
181
- # all_lines.append(text)
182
-
183
- # # --- STRATEGY 2: Fallback to Full Page ---
184
- # if not found_lines_in_regions:
185
- # add_log("⚠️ Region detection yielded no lines. Switching to Fallback Strategy...", "WARNING")
186
- # add_log("Strategy 2: Running line detection on full page")
187
-
188
- # line_results = line_model(img, conf=0.2, imgsz=1024, verbose=False)
189
- # lines = line_results[0].boxes
190
- # num_lines = len(lines)
191
- # add_log(f"βœ“ Fallback found {num_lines} line(s) on full page")
192
-
193
- # if num_lines > 0:
194
- # lines_sorted = sorted(lines, key=lambda b: b.xyxy[0][1])
195
-
196
- # for line_idx, line in enumerate(lines_sorted):
197
- # lx1, ly1, lx2, ly2 = map(int, line.xyxy[0])
198
-
199
- # # Draw RED box for Line (on full image)
200
- # draw.rectangle([lx1, ly1, lx2, ly2], outline="red", width=3)
201
-
202
- # line_crop = img.crop((lx1, ly1, lx2, ly2))
203
- # text = run_trocr(line_crop, processor, trocr_model, device)
204
- # add_log(f" Line {line_idx + 1}: '{text}'")
205
- # all_lines.append(text)
206
-
207
- # if not all_lines:
208
- # add_log("Failed to detect any text lines in both strategies", "ERROR")
209
- # return debug_img, "No text could be extracted.", "\n".join(log_output)
210
-
211
- # add_log(f"βœ“ Success! Extracted {len(all_lines)} total line(s)")
212
- # final_text = '\n'.join(all_lines)
213
-
214
- # return debug_img, final_text, "\n".join(log_output)
215
-
216
- # except Exception as e:
217
- # error_msg = f"Error processing image: {str(e)}"
218
- # add_log(error_msg, "ERROR")
219
- # logger.exception("Full error traceback:")
220
- # # Return the original image if debug creation failed
221
- # return image, f"Error: {str(e)}", "\n".join(log_output)
222
-
223
- # # Create Gradio interface
224
- # demo = gr.Interface(
225
- # fn=process_document,
226
- # inputs=gr.Image(type="pil", label="Upload Handwritten Document"),
227
- # outputs=[
228
- # gr.Image(type="pil", label="Debug Visualization (Green=Region, Red=Lines)"),
229
- # gr.Textbox(label="Extracted Text", lines=10),
230
- # gr.Textbox(label="Processing Logs", lines=15)
231
- # ],
232
- # title="πŸ“ Handwritten Text Recognition (HTR) with Debugging",
233
- # description="""
234
- # Upload an image of a handwritten document.
235
-
236
- # **Visualization Key:**
237
- # - 🟩 **Green Box:** The broad region identified as containing text.
238
- # - πŸŸ₯ **Red Box:** The specific line of text sent to the OCR engine.
239
- # """,
240
- # flagging_mode="never",
241
- # theme=gr.themes.Soft()
242
- # )
243
-
244
- # if __name__ == "__main__":
245
- # logger.info("Launching Gradio interface...")
246
- # demo.launch()
247
-
248
-
249
-
250
-
251
-
252
-
253
-
254
-
255
-
256
-
257
-
258
-
259
-
260
-
261
-
262
-
263
- # import gradio as gr
264
- # from ultralytics import YOLO
265
- # from transformers import TrOCRProcessor, VisionEncoderDecoderModel
266
- # from PIL import Image, ImageDraw, ImageFont
267
- # import torch
268
- # import logging
269
- # from datetime import datetime
270
- # import os
271
- # import warnings
272
- # import time
273
-
274
- # # Suppress progress bar and unnecessary logs
275
- # os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
276
- # warnings.filterwarnings('ignore')
277
- # logging.getLogger('transformers').setLevel(logging.ERROR)
278
- # logging.getLogger('ultralytics').setLevel(logging.ERROR)
279
-
280
- # # Setup logging
281
- # logging.basicConfig(
282
- # level=logging.INFO,
283
- # format='%(asctime)s - %(levelname)s - %(message)s'
284
- # )
285
- # logger = logging.getLogger(__name__)
286
-
287
- # logger.info("Starting model loading...")
288
- # device = "cuda" if torch.cuda.is_available() else "cpu"
289
- # logger.info(f"Using device: {device}")
290
-
291
- # # --- ROBUST MODEL LOADING FUNCTION ---
292
- # def load_model_with_retry(model_class, model_name, token=None, retries=5, delay=5):
293
- # """Attempts to load a HF model with retries to handle network timeouts."""
294
- # for attempt in range(retries):
295
- # try:
296
- # logger.info(f"Loading {model_name} (Attempt {attempt + 1}/{retries})...")
297
- # if "Processor" in str(model_class):
298
- # return model_class.from_pretrained(model_name, token=token)
299
- # else:
300
- # return model_class.from_pretrained(model_name, token=token).to(device)
301
- # except Exception as e:
302
- # logger.warning(f"Failed to load {model_name}: {e}")
303
- # if attempt < retries - 1:
304
- # logger.info(f"Retrying in {delay} seconds...")
305
- # time.sleep(delay)
306
- # else:
307
- # logger.error(f"Given up on loading {model_name} after {retries} attempts.")
308
- # raise e
309
-
310
- # try:
311
- # # 1. Load YOLO Models (Local Files)
312
- # region_model_file = 'regions.pt'
313
- # line_model_file = 'lines.pt'
314
-
315
- # # Simple check for local files
316
- # if not os.path.exists(region_model_file):
317
- # for file in os.listdir('.'):
318
- # if 'region' in file.lower() and file.endswith('.pt'): region_model_file = file
319
- # elif 'line' in file.lower() and file.endswith('.pt'): line_model_file = file
320
-
321
- # if not os.path.exists(region_model_file) or not os.path.exists(line_model_file):
322
- # raise FileNotFoundError("YOLO .pt files (regions.pt/lines.pt) not found.")
323
-
324
- # logger.info("Loading YOLO models...")
325
- # region_model = YOLO(region_model_file)
326
- # line_model = YOLO(line_model_file)
327
- # logger.info("βœ“ YOLO models loaded")
328
-
329
- # # 2. Load TrOCR with Retries
330
- # hf_token = os.getenv("HF_TOKEN")
331
-
332
- # processor = load_model_with_retry(TrOCRProcessor, "microsoft/trocr-base-handwritten", token=hf_token)
333
- # logger.info("βœ“ TrOCR processor loaded")
334
-
335
- # trocr_model = load_model_with_retry(VisionEncoderDecoderModel, "microsoft/trocr-base-handwritten", token=hf_token)
336
- # logger.info("βœ“ TrOCR model loaded")
337
-
338
- # logger.info("All models loaded successfully!")
339
-
340
- # except Exception as e:
341
- # logger.error(f"CRITICAL ERROR loading models: {str(e)}")
342
- # raise
343
-
344
- # # --- OCR HELPER ---
345
- # def run_trocr(image_slice, processor, model, device):
346
- # """Runs TrOCR on a single cropped image slice."""
347
- # pixel_values = processor(images=image_slice, return_tensors="pt").pixel_values.to(device)
348
- # generated_ids = model.generate(pixel_values)
349
- # return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
350
-
351
- # def process_document(image, enable_debug_crops=False):
352
- # """Process uploaded document image and extract handwritten text with visualization."""
353
- # timestamp = datetime.now().strftime("%H:%M:%S")
354
- # log_output = []
355
-
356
- # def add_log(message, level="INFO"):
357
- # log_msg = f"[{timestamp}] {level}: {message}"
358
- # log_output.append(log_msg)
359
- # if level == "ERROR":
360
- # logger.error(message)
361
- # else:
362
- # logger.info(message)
363
-
364
- # add_log("Starting document processing")
365
-
366
- # if image is None:
367
- # add_log("No image provided", "ERROR")
368
- # return None, "Please upload an image", "\n".join(log_output)
369
-
370
- # try:
371
- # # Prepare Image
372
- # if not isinstance(image, Image.Image):
373
- # img = Image.open(image).convert("RGB")
374
- # else:
375
- # img = image.convert("RGB")
376
-
377
- # # Create a drawing context for the debug image
378
- # debug_img = img.copy()
379
- # draw = ImageDraw.Draw(debug_img)
380
-
381
- # width, height = img.size
382
- # add_log(f"Image size: {width}x{height} pixels")
383
-
384
- # all_lines = []
385
- # debug_crops_dir = "debug_crops"
386
-
387
- # if enable_debug_crops:
388
- # os.makedirs(debug_crops_dir, exist_ok=True)
389
- # add_log(f"Debug crops will be saved to {debug_crops_dir}/")
390
-
391
- # # --- STRATEGY 1: Region Detection ---
392
- # add_log("Strategy 1: Running region detection...")
393
- # region_results = region_model(img, conf=0.2, imgsz=1024, verbose=False)
394
- # regions = region_results[0].boxes
395
- # num_regions = len(regions)
396
- # add_log(f"βœ“ Found {num_regions} potential text region(s)")
397
-
398
- # found_lines_in_regions = False
399
-
400
- # if num_regions > 0:
401
- # for region_idx, region in enumerate(regions):
402
- # add_log(f"Processing region {region_idx + 1}/{num_regions}")
403
-
404
- # # FIX 1: Use round() instead of int() to minimize precision loss
405
- # rx1, ry1, rx2, ry2 = map(round, region.xyxy[0].tolist())
406
-
407
- # # Calculate region dimensions
408
- # region_width = rx2 - rx1
409
- # region_height = ry2 - ry1
410
-
411
- # add_log(f" Region coords: ({rx1}, {ry1}) β†’ ({rx2}, {ry2}), size: {region_width}x{region_height}")
412
-
413
- # # Filter small artifacts
414
- # if region_width < 50 or region_height < 50:
415
- # add_log(f" Skipping tiny artifact: {region_width}x{region_height} px")
416
- # continue
417
-
418
- # # FIX 2: Add padding to region crops to avoid edge effects
419
- # padding = 10
420
- # padded_rx1 = max(0, rx1 - padding)
421
- # padded_ry1 = max(0, ry1 - padding)
422
- # padded_rx2 = min(width, rx2 + padding)
423
- # padded_ry2 = min(height, ry2 + padding)
424
-
425
- # add_log(f" Padded coords: ({padded_rx1}, {padded_ry1}) β†’ ({padded_rx2}, {padded_ry2})")
426
-
427
- # # Draw GREEN box for Region (original bounds, not padded)
428
- # draw.rectangle([rx1, ry1, rx2, ry2], outline="green", width=5)
429
-
430
- # # Crop Region with padding
431
- # region_crop = img.crop((padded_rx1, padded_ry1, padded_rx2, padded_ry2))
432
-
433
- # if enable_debug_crops:
434
- # region_crop.save(f"{debug_crops_dir}/region_{region_idx:02d}.png")
435
-
436
- # # Detect lines in this region
437
- # add_log(f" Running line detection on region crop ({region_crop.size[0]}x{region_crop.size[1]})...")
438
- # line_results = line_model(region_crop, conf=0.2, imgsz=1024, verbose=False)
439
- # lines_data = line_results[0].boxes.xyxy.cpu().numpy()
440
- # num_lines = len(lines_data)
441
- # add_log(f" βœ“ Found {num_lines} line(s) in region")
442
-
443
- # if num_lines > 0:
444
- # found_lines_in_regions = True
445
-
446
- # # Sort lines by Y position (index 1 of xyxy)
447
- # sorted_indices = lines_data[:, 1].argsort()
448
-
449
- # for line_idx, idx in enumerate(sorted_indices):
450
- # # FIX 3: Use round() for line coordinates too
451
- # lx1, ly1, lx2, ly2 = map(round, lines_data[idx].tolist())
452
-
453
- # line_width = lx2 - lx1
454
- # line_height = ly2 - ly1
455
-
456
- # add_log(f" Line {line_idx + 1} (local coords): ({lx1}, {ly1}) β†’ ({lx2}, {ly2}), size: {line_width}x{line_height}")
457
-
458
- # # FIX 4: Translate line coordinates back to original image space
459
- # # Account for padding offset
460
- # global_lx1 = padded_rx1 + lx1
461
- # global_ly1 = padded_ry1 + ly1
462
- # global_lx2 = padded_rx1 + lx2
463
- # global_ly2 = padded_ry1 + ly2
464
-
465
- # # FIX 5: Validate coordinates are within image bounds
466
- # global_lx1 = max(0, min(width, global_lx1))
467
- # global_ly1 = max(0, min(height, global_ly1))
468
- # global_lx2 = max(0, min(width, global_lx2))
469
- # global_ly2 = max(0, min(height, global_ly2))
470
-
471
- # add_log(f" Line {line_idx + 1} (global coords): ({global_lx1}, {global_ly1}) β†’ ({global_lx2}, {global_ly2})")
472
-
473
- # # Draw RED box for Line
474
- # draw.rectangle([global_lx1, global_ly1, global_lx2, global_ly2], outline="red", width=3)
475
-
476
- # # OCR on the line crop from region_crop
477
- # line_crop = region_crop.crop((lx1, ly1, lx2, ly2))
478
-
479
- # if enable_debug_crops:
480
- # line_crop.save(f"{debug_crops_dir}/region_{region_idx:02d}_line_{line_idx:02d}.png")
481
-
482
- # text = run_trocr(line_crop, processor, trocr_model, device)
483
- # add_log(f" Line {line_idx + 1} OCR: '{text}'")
484
- # all_lines.append(text)
485
-
486
- # # --- STRATEGY 2: Fallback to Full Page ---
487
- # if not found_lines_in_regions:
488
- # add_log("⚠️ Region detection yielded no lines. Switching to Fallback Strategy...", "WARNING")
489
- # add_log("Strategy 2: Running line detection on full page")
490
-
491
- # line_results = line_model(img, conf=0.2, imgsz=1024, verbose=False)
492
- # lines_data = line_results[0].boxes.xyxy.cpu().numpy()
493
- # num_lines = len(lines_data)
494
- # add_log(f"βœ“ Fallback found {num_lines} line(s) on full page")
495
-
496
- # if num_lines > 0:
497
- # sorted_indices = lines_data[:, 1].argsort()
498
-
499
- # for line_idx, idx in enumerate(sorted_indices):
500
- # # FIX 6: Use round() consistently
501
- # lx1, ly1, lx2, ly2 = map(round, lines_data[idx].tolist())
502
-
503
- # line_width = lx2 - lx1
504
- # line_height = ly2 - ly1
505
-
506
- # add_log(f" Fallback Line {line_idx + 1}: ({lx1}, {ly1}) β†’ ({lx2}, {ly2}), size: {line_width}x{line_height}")
507
-
508
- # # FIX 7: Validate coordinates
509
- # lx1 = max(0, min(width, lx1))
510
- # ly1 = max(0, min(height, ly1))
511
- # lx2 = max(0, min(width, lx2))
512
- # ly2 = max(0, min(height, ly2))
513
-
514
- # # Draw RED box for Line (on full image)
515
- # draw.rectangle([lx1, ly1, lx2, ly2], outline="red", width=3)
516
-
517
- # line_crop = img.crop((lx1, ly1, lx2, ly2))
518
-
519
- # if enable_debug_crops:
520
- # line_crop.save(f"{debug_crops_dir}/fullpage_line_{line_idx:02d}.png")
521
-
522
- # text = run_trocr(line_crop, processor, trocr_model, device)
523
- # add_log(f" Fallback Line {line_idx + 1} OCR: '{text}'")
524
- # all_lines.append(text)
525
-
526
- # if not all_lines:
527
- # add_log("Failed to detect any text lines in both strategies", "ERROR")
528
- # return debug_img, "No text could be extracted.", "\n".join(log_output)
529
-
530
- # add_log(f"βœ“ Success! Extracted {len(all_lines)} total line(s)")
531
-
532
- # if enable_debug_crops:
533
- # add_log(f"βœ“ Debug crops saved to {debug_crops_dir}/")
534
-
535
- # final_text = '\n'.join(all_lines)
536
-
537
- # return debug_img, final_text, "\n".join(log_output)
538
-
539
- # except Exception as e:
540
- # error_msg = f"Error processing image: {str(e)}"
541
- # add_log(error_msg, "ERROR")
542
- # logger.exception("Full error traceback:")
543
- # return image, f"Error: {str(e)}", "\n".join(log_output)
544
-
545
- # # Create Gradio interface
546
- # demo = gr.Interface(
547
- # fn=process_document,
548
- # inputs=[
549
- # gr.Image(type="pil", label="Upload Handwritten Document"),
550
- # gr.Checkbox(label="Save debug crops to disk", value=False)
551
- # ],
552
- # outputs=[
553
- # gr.Image(type="pil", label="Debug Visualization (Green=Region, Red=Lines)"),
554
- # gr.Textbox(label="Extracted Text", lines=10),
555
- # gr.Textbox(label="Processing Logs", lines=15)
556
- # ],
557
- # title="πŸ“ Handwritten Text Recognition (HTR) with Enhanced Debugging",
558
- # description="""
559
- # Upload an image of a handwritten document.
560
-
561
- # **Visualization Key:**
562
- # - 🟩 **Green Box:** The broad region identified as containing text (original bounds).
563
- # - πŸŸ₯ **Red Box:** The specific line of text sent to the OCR engine (with coordinate validation).
564
-
565
- # **Improvements:**
566
- # - Fixed coordinate rounding (eliminates truncation errors)
567
- # - Added 10px padding to region crops (reduces edge effects)
568
- # - Coordinate validation (ensures all boxes are within image bounds)
569
- # - Enhanced logging with detailed coordinate tracking
570
- # - Optional debug crop saving
571
- # """,
572
- # flagging_mode="never",
573
- # theme=gr.themes.Soft()
574
- # )
575
-
576
- # if __name__ == "__main__":
577
- # logger.info("Launching Gradio interface...")
578
- # demo.launch()
579
-
580
-
581
-
582
-
583
-
584
-
585
-
586
-
587
-
588
-
589
-
590
-
591
-
592
-
593
-
594
-
595
-
596
-
597
-
598
-
599
-
600
-
601
  import gradio as gr
602
  from ultralytics import YOLO
603
- from transformers import TrOCRProcessor, VisionEncoderDecoderModel
604
- from PIL import Image, ImageDraw
605
  import torch
606
  import logging
607
  import os
608
- import warnings
609
- import time
610
  from datetime import datetime
611
 
612
- # Suppress noisy logs
613
  os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
614
- warnings.filterwarnings('ignore')
615
- logging.getLogger('transformers').setLevel(logging.ERROR)
616
- logging.getLogger('ultralytics').setLevel(logging.WARNING) # still allow important warnings
617
 
618
- # Setup clean logging
619
- logging.basicConfig(level=logging.INFO, format='%(asctime)s | %(levelname)-5s | %(message)s')
 
 
620
  logger = logging.getLogger(__name__)
621
 
622
- logger.info("Initializing models...")
 
623
  device = "cuda" if torch.cuda.is_available() else "cpu"
624
  logger.info(f"Device: {device}")
625
 
626
- def load_with_retry(cls, name, token=None, retries=4, delay=6):
627
- for attempt in range(1, retries + 1):
628
- try:
629
- logger.info(f"Loading {name} (attempt {attempt}/{retries})")
630
- if "Processor" in str(cls):
631
- return cls.from_pretrained(name, token=token)
632
- return cls.from_pretrained(name, token=token).to(device)
633
- except Exception as e:
634
- logger.warning(f"Load failed: {e}")
635
- if attempt < retries:
636
- time.sleep(delay)
637
- raise RuntimeError(f"Failed to load {name} after {retries} attempts")
638
-
639
  try:
640
- # Locate local YOLO weights
641
  region_pt = 'regions.pt'
642
- line_pt = 'lines.pt'
643
-
644
  if not os.path.exists(region_pt):
645
  for f in os.listdir('.'):
646
  name = f.lower()
647
- if 'region' in name and name.endswith('.pt'): region_pt = f
648
- if 'line' in name and name.endswith('.pt'): line_pt = f
649
-
650
- if not all(os.path.exists(p) for p in [region_pt, line_pt]):
651
- raise FileNotFoundError("Could not find regions.pt and lines.pt (or similar)")
652
 
653
- logger.info("Loading YOLO models...")
654
- region_model = YOLO(region_pt)
655
- line_model = YOLO(line_pt)
656
- logger.info("YOLO models loaded")
657
 
658
- hf_token = os.getenv("HF_TOKEN")
659
- processor = load_with_retry(TrOCRProcessor, "microsoft/trocr-base-handwritten", hf_token)
660
- trocr = load_with_retry(VisionEncoderDecoderModel, "microsoft/trocr-base-handwritten", hf_token)
661
- logger.info("TrOCR loaded β†’ ready")
662
 
663
  except Exception as e:
664
- logger.error(f"Model loading failed: {e}", exc_info=True)
665
  raise
666
 
667
 
668
-
669
-
670
-
671
- def run_ocr(crop: Image.Image) -> str:
672
- if crop.width < 20 or crop.height < 12:
673
- return ""
674
- pixels = processor(images=crop, return_tensors="pt").pixel_values.to(device)
675
- ids = trocr.generate(pixels, max_new_tokens=128)
676
- return processor.batch_decode(ids, skip_special_tokens=True)[0].strip()
677
-
678
-
679
- def process_document(
680
  image,
681
- enable_debug_crops: bool = False,
682
- region_imgsz: int = 1024,
683
- line_imgsz_base: int = 768,
684
  conf_thresh: float = 0.25,
 
 
 
 
 
685
  ):
686
- start_ts = datetime.now().strftime("%H:%M:%S")
687
- logs = []
688
-
689
- def log(msg: str, level: str = "INFO"):
690
- line = f"[{start_ts}] {level:5} {msg}"
691
- logs.append(line)
692
- if level == "ERROR":
693
- logger.error(msg)
694
- else:
695
- logger.info(msg)
696
-
697
- log("Start processing")
698
 
699
  if image is None:
700
- log("No image uploaded", "ERROR")
701
- return None, "Upload an image", "\n".join(logs)
702
-
703
- try:
704
- # ── Prepare ─────────────────────────────────────────────────────────────
705
- if not isinstance(image, Image.Image):
706
- img = Image.open(image).convert("RGB")
707
- else:
708
- img = image.convert("RGB")
709
-
710
- debug_img = img.copy()
711
- draw = ImageDraw.Draw(debug_img)
712
- w, h = img.size
713
- log(f"Input image: {w} Γ— {h} px")
714
-
715
- debug_dir = "debug_crops"
716
- if enable_debug_crops:
717
- os.makedirs(debug_dir, exist_ok=True)
718
- log(f"Debug crops β†’ {debug_dir}/")
719
 
720
- extracted = []
721
- used_fallback = False
 
 
 
722
 
723
- # ── Strategy 1: Region β†’ Lines ──────────────────────────────────────────
724
- log(f"Running region detection (imgsz={region_imgsz}) …")
725
- res_region = region_model(img, conf=conf_thresh, imgsz=region_imgsz, verbose=False)[0]
726
- boxes_region = res_region.boxes
727
 
728
- log(f"β†’ {len(boxes_region)} region candidate(s) (conf β‰₯ {conf_thresh})")
 
729
 
730
- found_any_line = False
731
-
732
- for i, box in enumerate(boxes_region, 1):
733
- conf = float(box.conf)
734
- xyxy = box.xyxy[0].cpu().tolist()
735
- rx1, ry1, rx2, ry2 = map(round, xyxy)
736
-
737
- rw, rh = rx2 - rx1, ry2 - ry1
738
- log(f"Region {i}/{len(boxes_region)} conf={conf:.3f} {rx1},{ry1} β†’ {rx2},{ry2} ({rw}Γ—{rh})")
739
-
740
- if rw < 60 or rh < 40:
741
- log(f" β†’ skipped (too small)")
742
- continue
743
-
744
- # Padding
745
- pad = 12
746
- px1 = max(0, rx1 - pad)
747
- py1 = max(0, ry1 - pad)
748
- px2 = min(w, rx2 + pad)
749
- py2 = min(h, ry2 + pad)
750
-
751
- log(f" Padded crop: {px1},{py1} β†’ {px2},{py2}")
752
-
753
- draw.rectangle((rx1, ry1, rx2, ry2), outline="green", width=4)
754
-
755
- crop_region = img.crop((px1, py1, px2, py2))
756
- crop_w, crop_h = crop_region.size
757
-
758
- if enable_debug_crops:
759
- crop_region.save(f"{debug_dir}/region_{i:02d}.png")
760
-
761
- # Adaptive line imgsz: bigger crops β†’ bigger inference size
762
- line_sz = line_imgsz_base
763
- if max(crop_w, crop_h) > 1400:
764
- line_sz = 1280
765
- elif max(crop_w, crop_h) < 400:
766
- line_sz = 640
767
-
768
- log(f" β†’ line detection (imgsz={line_sz}) on {crop_w}Γ—{crop_h} crop …")
769
- res_line = line_model(crop_region, conf=conf_thresh, imgsz=line_sz, verbose=False)[0]
770
- line_boxes = res_line.boxes
771
-
772
- log(f" β†’ {len(line_boxes)} line candidate(s)")
773
-
774
- if len(line_boxes) == 0:
775
- continue
776
-
777
- found_any_line = True
778
-
779
- # Sort top β†’ bottom
780
- ys = line_boxes.xyxy[:, 1].cpu().numpy()
781
  order = ys.argsort()
782
 
783
- for j, idx in enumerate(order, 1):
784
- conf_line = float(line_boxes.conf[idx])
785
- lx1, ly1, lx2, ly2 = map(round, line_boxes.xyxy[idx].cpu().tolist())
786
-
787
- lw, lh = lx2 - lx1, ly2 - ly1
788
- log(f" Line {j} conf={conf_line:.3f} local {lx1},{ly1} β†’ {lx2},{ly2} ({lw}Γ—{lh})")
789
-
790
- # Back to global coordinates
791
- gx1 = px1 + lx1
792
- gy1 = py1 + ly1
793
- gx2 = px1 + lx2
794
- gy2 = py1 + ly2
795
-
796
- # Safety clamp
797
- gx1, gy1 = max(0, gx1), max(0, gy1)
798
- gx2, gy2 = min(w, gx2), min(h, gy2)
799
-
800
- log(f" β†’ global {gx1},{gy1} β†’ {gx2},{gy2}")
801
-
802
- draw.rectangle((gx1, gy1, gx2, gy2), outline="red", width=3)
803
-
804
- line_crop = crop_region.crop((lx1, ly1, lx2, ly2))
805
-
806
- if enable_debug_crops:
807
- line_crop.save(f"{debug_dir}/reg{i:02d}_line{j:02d}_conf{conf_line:.2f}.png")
808
-
809
- text = run_ocr(line_crop)
810
- log(f" OCR β†’ '{text}'")
811
- if text:
812
- extracted.append(text)
813
-
814
- # ── Strategy 2: Fallback full-page line detection ───────────────────────
815
- if not found_any_line:
816
- used_fallback = True
817
- log("No lines found in regions β†’ fallback: full-page line detection")
818
-
819
- line_sz = 1024 if max(w, h) > 1800 else line_imgsz_base
820
- log(f"Full-page line detection (imgsz={line_sz}) …")
821
-
822
- res = line_model(img, conf=conf_thresh, imgsz=line_sz, verbose=False)[0]
823
- boxes = res.boxes
824
-
825
- log(f"β†’ {len(boxes)} line(s) on full page")
826
-
827
- if len(boxes) > 0:
828
- ys = boxes.xyxy[:, 1].cpu().numpy()
829
- order = ys.argsort()
830
-
831
- for j, idx in enumerate(order, 1):
832
- conf = float(boxes.conf[idx])
833
- x1, y1, x2, y2 = map(round, boxes.xyxy[idx].cpu().tolist())
834
- log(f" Line {j} conf={conf:.3f} {x1},{y1} β†’ {x2},{y2}")
835
-
836
- draw.rectangle((x1,y1,x2,y2), outline="red", width=3)
837
-
838
- crop = img.crop((x1,y1,x2,y2))
839
-
840
- if enable_debug_crops:
841
- crop.save(f"{debug_dir}/fallback_line{j:02d}_conf{conf:.2f}.png")
842
-
843
- text = run_ocr(crop)
844
- log(f" OCR β†’ '{text}'")
845
- if text:
846
- extracted.append(text)
847
-
848
- # ── Finalize ────────────────────────────────────────────────────────────
849
- if not extracted:
850
- msg = "No readable text lines detected in either strategy"
851
- log(msg, "WARNING")
852
- return debug_img, msg, "\n".join(logs)
853
 
854
- log(f"Success β€” extracted {len(extracted)} line(s)")
855
- if enable_debug_crops:
856
- log(f"Debug crops saved to {debug_dir}/")
857
 
858
- return debug_img, "\n".join(extracted), "\n".join(logs)
859
 
860
  except Exception as e:
861
- log(f"Processing failed: {e}", "ERROR")
862
- logger.exception("Traceback:")
863
- return debug_img, f"Error: {str(e)}", "\n".join(logs)
864
-
865
-
866
-
867
-
868
-
869
 
870
 
871
  demo = gr.Interface(
872
- fn=process_document,
873
  inputs=[
874
- gr.Image(type="pil", label="Handwritten document"),
875
- gr.Checkbox(label="Save debug crops", value=False),
876
- gr.Slider(640, 1600, step=64, value=1024, label="Region detection size (imgsz)"),
877
- gr.Slider(512, 1280, step=64, value=768, label="Base line detection size"),
878
- gr.Slider(0.15, 0.5, step=0.05, value=0.25, label="Confidence threshold"),
 
 
879
  ],
880
  outputs=[
881
- gr.Image(label="Debug (green=region, red=line)"),
882
- gr.Textbox(label="Extracted Text", lines=10),
883
- gr.Textbox(label="Detailed Logs (copy these if boxes look wrong)", lines=18),
884
  ],
885
- title="Handwritten Text β†’ OCR + Debug",
886
  description=(
887
- "Green = detected text regions β€’ Red = individual text lines sent to TrOCR\n\n"
888
- "Copy the **Detailed Logs** if alignment still looks off β€” especially coords, sizes & confidences."
 
 
 
889
  ),
890
  theme=gr.themes.Soft(),
891
- flagging_mode="never",
892
  )
893
 
894
  if __name__ == "__main__":
895
- logger.info("Launching interface…")
896
  demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  from ultralytics import YOLO
3
+ from PIL import Image, ImageDraw, ImageFont
 
4
  import torch
5
  import logging
6
  import os
 
 
7
  from datetime import datetime
8
 
9
# ── Quiet startup ───────────────────────────────────────────────────────
# Silence HF Hub progress bars and chatty ultralytics INFO logs.
os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
logging.getLogger('ultralytics').setLevel(logging.WARNING)

logging.basicConfig(
    level=logging.INFO,
    # The LogRecord attribute is 'levelname' — '%(level)s' is not a valid
    # field and would make every log call emit a "--- Logging error ---".
    format='%(asctime)s | %(levelname)-5s | %(message)s'
)
logger = logging.getLogger(__name__)

logger.info("Initializing region detector...")

# Prefer GPU when available; inference also works on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Device: {device}")
23
 
24
# ── Load YOLO ───────────────────────────────────────────────────────────
# Load the region-detection weights: prefer 'regions.pt', otherwise fall
# back to the first *.pt file in the CWD whose name mentions 'region'.
try:
    region_pt = 'regions.pt'

    if not os.path.exists(region_pt):
        candidates = [
            entry for entry in os.listdir('.')
            if entry.lower().endswith('.pt') and 'region' in entry.lower()
        ]
        if candidates:
            region_pt = candidates[0]

    if not os.path.exists(region_pt):
        raise FileNotFoundError("No regions.pt (or similar *.pt) found in current directory")

    logger.info(f"Loading model: {region_pt}")
    model = YOLO(region_pt)
    logger.info("Region detector loaded")

except Exception as e:
    # Log with traceback, then re-raise so the app fails fast at startup.
    logger.error(f"Model loading failed → {e}", exc_info=True)
    raise
44
 
45
 
46
def visualize_regions(
    image,
    conf_thresh: float = 0.25,
    min_size: int = 60,
    padding: int = 0,
    show_labels: bool = True,
    save_debug_crops: bool = False,
    imgsz: int = 1024,
):
    """Run the region YOLO model on *image* and return an annotated copy.

    Parameters
    ----------
    image : PIL.Image.Image | str | None
        Input document image, or a filesystem path to one. ``None``
        (nothing uploaded) short-circuits with a log message.
    conf_thresh : float
        Minimum detection confidence for a box to be kept.
    min_size : int
        Minimum box width AND height in pixels; smaller boxes are dropped.
    padding : int
        Extra pixels added around a box when saving debug crops only;
        the rectangle drawn on the image is never padded.
    show_labels : bool
        Draw a "conf WxH" label above each kept box.
    save_debug_crops : bool
        Save each kept region as a PNG under ``debug_regions/``.
    imgsz : int
        YOLO inference image size.

    Returns
    -------
    tuple
        ``(annotated PIL image or None, newline-joined log text)``.
    """
    start = datetime.now().strftime("%H:%M:%S")
    logs = [f"[{start}] Processing started"]

    if image is None:
        logs.append("No image uploaded")
        return None, "\n".join(logs)

    # Accept both file paths and in-memory PIL images.
    if isinstance(image, str):
        img = Image.open(image).convert("RGB")
    else:
        img = image.convert("RGB")

    w, h = img.size
    logs.append(f"Image size: {w} × {h}")

    debug_img = img.copy()
    draw = ImageDraw.Draw(debug_img)

    try:
        # Font for label text; fall back to PIL's built-in bitmap font.
        # truetype() raises OSError when the font file cannot be found/read.
        try:
            font = ImageFont.truetype("arial.ttf", 18)
        except OSError:
            font = ImageFont.load_default()

        # ── Run detection ───────────────────────────────────────────────
        results = model(
            img,
            conf=conf_thresh,
            imgsz=imgsz,
            verbose=False
        )[0]

        boxes = results.boxes
        logs.append(f"Detected {len(boxes)} region candidate(s)")

        kept = 0

        if len(boxes) > 0:
            # Sort top → bottom so logs and crops follow reading order.
            ys = boxes.xyxy[:, 1].cpu().numpy()
            order = ys.argsort()

            for idx in order:
                box = boxes[idx]
                conf = float(box.conf)
                if conf < conf_thresh:  # defensive; YOLO already filtered
                    continue

                x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
                bw, bh = x2 - x1, y2 - y1

                if bw < min_size or bh < min_size:
                    continue

                # Padded coordinates, clamped to the image — crops only.
                px1 = max(0, x1 - padding)
                py1 = max(0, y1 - padding)
                px2 = min(w, x2 + padding)
                py2 = min(h, y2 + padding)

                draw.rectangle((x1, y1, x2, y2), outline="lime", width=3)

                if show_labels:
                    label = f"conf {conf:.2f} {bw}×{bh}"
                    # textbbox gives (left, top, right, bottom); the label
                    # size is the bbox extent, not the raw right/bottom.
                    bl, bt, br, bb = draw.textbbox((0, 0), label, font=font)
                    tw, th = br - bl, bb - bt
                    # Solid RGB fill: the canvas is RGB, so an alpha
                    # channel would have no effect anyway.
                    draw.rectangle(
                        (x1, y1 - th - 4, x1 + tw + 8, y1),
                        fill=(0, 180, 0)
                    )
                    draw.text((x1 + 4, y1 - th - 2), label, fill="white", font=font)

                kept += 1

                # Optional: save individual crops for manual inspection.
                if save_debug_crops:
                    os.makedirs("debug_regions", exist_ok=True)
                    crop = img.crop((px1, py1, px2, py2))
                    fname = f"debug_regions/r{kept:02d}_conf{conf:.2f}_{bw}x{bh}.png"
                    crop.save(fname)
                    logs.append(f"Saved crop → {fname}")

        if kept == 0:
            msg = f"No regions kept after filters (conf ≥ {conf_thresh}, size ≥ {min_size}px)"
            logs.append(msg)
        else:
            logs.append(f"Visualized {kept} region(s)")

        logs.append("Finished.")

        return debug_img, "\n".join(logs)

    except Exception as e:
        # Surface the error in the UI log pane and keep the partial drawing.
        logs.append(f"Error during inference: {str(e)}")
        logger.exception("Inference failed")
        return debug_img, "\n".join(logs)
 
 
 
 
 
153
 
154
 
155
# Gradio UI: one image in, annotated image + text log out.
demo = gr.Interface(
    fn=visualize_regions,
    inputs=[
        gr.Image(type="pil", label="Upload image (handwritten document)"),
        gr.Slider(0.10, 0.60, step=0.02, value=0.25, label="Confidence threshold"),
        gr.Slider(30, 300, step=10, value=60, label="Minimum region width/height (px)"),
        gr.Slider(0, 40, step=4, value=0, label="Padding around box (for crops only)"),
        gr.Checkbox(label="Draw confidence + size labels on boxes", value=True),
        gr.Checkbox(label="Save individual region crops to debug_regions/", value=False),
        gr.Slider(640, 1280, step=64, value=1024, label="Inference image size (imgsz)"),
    ],
    outputs=[
        gr.Image(label="Detected text regions (green boxes)"),
        gr.Textbox(label="Log / debug info", lines=14),
    ],
    title="Region Detector Debug View",
    description=(
        "Only shows what the region YOLO model sees.\n\n"
        "• Green boxes = detected text regions\n"
        "• Tune confidence and min size until boxes look reasonable\n"
        "• Use logs to see exact confidences and sizes\n"
        "• Save crops if you want to manually check what is being detected"
    ),
    theme=gr.themes.Soft(),
    # Gradio ≥ 5 renamed `allow_flagging` to `flagging_mode`; the previous
    # revision of this file already used the new name, so `allow_flagging`
    # would raise a TypeError here.
    flagging_mode="never",
)

if __name__ == "__main__":
    logger.info("Launching debug interface...")
    demo.launch()