iammraat committed on
Commit
adb25fe
·
verified ·
1 Parent(s): cd5090e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -104
app.py CHANGED
@@ -323,133 +323,94 @@ from PIL import Image
323
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
324
  from craft_text_detector import Craft
325
 
326
- # --- THE ULTIMATE MONKEY PATCH ---
327
  import craft_text_detector.craft_utils as craft_utils_module
328
 
329
  def fixed_adjustResultCoordinates(polys, ratio_w, ratio_h):
330
- if not polys or len(polys) == 0:
331
- return []
332
-
333
- adjusted_polys = []
334
  for poly in polys:
335
- try:
336
- # Convert to numpy and check if it's actually a coordinate list
337
- p = np.array(poly).astype(np.float32)
338
-
339
- # If p is empty or just a single point/scalar, skip it
340
- if p.ndim != 2 or p.shape[0] == 0:
341
- continue
342
-
343
- # Scale coordinates
344
- p[:, 0] *= ratio_w
345
- p[:, 1] *= ratio_h
346
- adjusted_polys.append(p)
347
- except (IndexError, TypeError, ValueError):
348
- # If anything goes wrong with a specific noise-box, just skip it
349
- continue
350
-
351
- return adjusted_polys
352
 
353
- # Apply the patch to the library in memory
354
  craft_utils_module.adjustResultCoordinates = fixed_adjustResultCoordinates
355
- # ----------------------------
356
 
357
- # Device Setup
358
  device = "cuda" if torch.cuda.is_available() else "cpu"
359
 
360
  # Load Models
361
- print("Loading TrOCR...")
362
  processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-handwritten")
363
- model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-small-handwritten")
364
- model.to(device).eval()
365
-
366
- print("Loading CRAFT...")
367
- # We use crop_type="box" for clean rectangles
368
  craft = Craft(output_dir=None, crop_type="box", cuda=(device == "cuda"))
369
 
370
- def get_sorted_boxes(boxes):
371
- """Sorts detected boxes into reading order: top-to-bottom, left-to-right."""
372
- items = []
373
- for box in boxes:
374
- # Avoid empty boxes
375
- if box is None or len(box) == 0:
376
- continue
377
- cx = np.mean(box[:, 0])
378
- cy = np.mean(box[:, 1])
379
- items.append((cy, cx, box))
380
-
381
- # Sort by Y (line grouping) then X
382
- items.sort(key=lambda x: (int(x[0] // 30), x[1]))
383
- return [b for _, _, b in items]
384
-
385
  def process_page(image):
386
- if image is None:
387
- return None, "Please upload an image."
388
-
389
- # Convert PIL to standard RGB format
390
- image_rgb = image.convert("RGB")
391
- image_np = np.array(image_rgb)
 
 
 
 
392
 
393
- # 1. Run Detection (Patched function handles coordinate mapping)
394
- prediction = craft.detect_text(image_np)
 
395
  boxes = prediction.get("boxes", [])
396
 
397
- if not boxes or len(boxes) == 0:
398
- return image_rgb, "No text detected."
399
 
400
- # 2. Sort and Draw
401
- sorted_boxes = get_sorted_boxes(boxes)
402
- annotated = image_np.copy()
403
- transcriptions = []
 
 
 
 
 
404
 
405
- for box in sorted_boxes:
406
- # Cast to integer for CV2 and slicing
407
- box_int = box.astype(np.int32)
408
 
409
- # 3. Draw on the visualization image
410
- cv2.polylines(annotated, [box_int], True, (255, 0, 0), 2)
411
 
412
- # 4. Extract Crop for OCR
413
- # Get axis-aligned bounding box from points
414
- x_min, y_min = np.min(box_int, axis=0)
415
- x_max, y_max = np.max(box_int, axis=0)
416
 
417
- # Keep within image dimensions
418
  x_min, y_min = max(0, x_min), max(0, y_min)
419
- x_max, y_max = min(image_np.shape[1], x_max), min(image_np.shape[0], y_max)
420
 
421
- # Skip boxes that are too small to contain a character
422
- if (x_max - x_min) < 10 or (y_max - y_min) < 10:
423
- continue
424
-
425
- crop_region = image_np[y_min:y_max, x_min:x_max]
426
- crop_pil = Image.fromarray(crop_region)
427
 
428
- # 5. Inference with TrOCR
 
429
  with torch.no_grad():
430
- pixel_values = processor(images=crop_pil, return_tensors="pt").pixel_values.to(device)
431
- generated_ids = model.generate(pixel_values)
432
- text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
433
-
434
- if text.strip():
435
- transcriptions.append(text)
436
-
437
- # Output the result
438
- final_text = " ".join(transcriptions)
439
- return Image.fromarray(annotated), final_text
440
-
441
- # Gradio Interface
442
- demo = gr.Interface(
443
- fn=process_page,
444
- inputs=gr.Image(type="pil", label="Upload Page Image"),
445
- outputs=[
446
- gr.Image(label="Detection Visualization"),
447
- gr.Textbox(label="Transcribed Text", lines=15)
448
- ],
449
- title="✍️ Full-Page Handwritten Recognition",
450
- description="Combines CRAFT detection with TrOCR recognition. Use high-contrast images for best results.",
451
- theme="soft"
452
- )
453
-
454
- if __name__ == "__main__":
455
- demo.launch()
 
323
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
324
  from craft_text_detector import Craft
325
 
326
+ # --- DEFENSIVE MONKEY PATCH ---
327
  import craft_text_detector.craft_utils as craft_utils_module
328
 
329
def fixed_adjustResultCoordinates(polys, ratio_w, ratio_h):
    """Monkey-patch replacement for craft_utils.adjustResultCoordinates.

    Scales each detected polygon from the network's coordinate space back
    to the (resized) image space and returns a plain list of float32
    (N, 2) arrays.

    Defensive on purpose: the stock implementation crashes on the noise
    the detector sometimes emits, so this version
    - accepts None / empty input, including numpy arrays (whose truth
      value is ambiguous, so ``not polys`` would raise ValueError), and
    - skips malformed entries (odd element counts, non-numeric junk)
      instead of raising.

    Args:
        polys: iterable of polygons (each an iterable of x, y pairs),
            or a numpy array of them, or None.
        ratio_w: horizontal scale factor back to image space.
        ratio_h: vertical scale factor back to image space.

    Returns:
        list of np.float32 arrays of shape (N, 2); possibly empty.
    """
    # `polys` may arrive as a list OR a numpy ndarray; test length
    # explicitly instead of truthiness.
    if polys is None or len(polys) == 0:
        return []

    adjusted = []
    for poly in polys:
        try:
            p = np.array(poly, dtype=np.float32).reshape(-1, 2)
        except (TypeError, ValueError):
            # Malformed noise box — skip it rather than kill detection.
            continue
        if p.shape[0] == 0:
            continue
        p[:, 0] *= ratio_w
        p[:, 1] *= ratio_h
        adjusted.append(p)
    return adjusted
 
 
 
 
 
 
 
 
 
 
 
 
339
 
 
340
# Install the defensive replacement into the library's own module namespace
# so every internal call inside craft_text_detector uses it.
craft_utils_module.adjustResultCoordinates = fixed_adjustResultCoordinates
# ------------------------------

# Prefer GPU when available; the recognizer and detector both follow this.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load Models
# TrOCR (small, handwritten) performs text *recognition* on each crop.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-small-handwritten").to(device).eval()

# CRAFT performs text *detection*; crop_type="box" yields rectangular boxes.
# NOTE(review): output_dir=None presumably keeps results in memory only —
# confirm against the craft_text_detector docs.
craft = Craft(output_dir=None, crop_type="box", cuda=(device == "cuda"))
349
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
def process_page(image):
    """Full-page OCR pipeline: resize -> detect -> sort -> recognize.

    Args:
        image: PIL.Image uploaded by the user, or None (Gradio passes
            None when nothing was uploaded).

    Returns:
        Tuple of (annotated PIL.Image or None, transcription string).
    """
    if image is None:
        return None, "No image."

    # 1. FORCE RESIZE for coordinate 1:1 mapping.
    # We resize to 1280 width to match CRAFT's default canvas, so the
    # detector's ratio_w / ratio_h stay ~1.0 and coordinates cannot drift.
    base_width = 1280
    w_percent = base_width / float(image.size[0])
    h_size = int(float(image.size[1]) * w_percent)
    image = image.resize((base_width, h_size), Image.Resampling.LANCZOS)

    img_np = np.array(image.convert("RGB"))

    # 2. DETECT
    prediction = craft.detect_text(img_np)
    boxes = prediction.get("boxes", [])

    # `boxes` may be a list or a numpy array depending on the library
    # version; `not boxes` raises ValueError on a multi-element array,
    # so check emptiness via length instead.
    if boxes is None or len(boxes) == 0:
        return image, "No text found."

    # 3. SORT into reading order: boxes whose vertical centers fall in
    # the same 30-pixel band count as one 'line', ordered left-to-right
    # within the band.
    items = [
        {'cy': np.mean(box[:, 1]), 'cx': np.mean(box[:, 0]), 'box': box}
        for box in boxes
    ]
    items.sort(key=lambda x: (int(x['cy'] // 30), x['cx']))

    annotated = img_np.copy()
    full_text = []

    # 4. RECOGNIZE each detected region with TrOCR
    for item in items:
        box = item['box'].astype(np.int32)

        # Draw the detection polygon (red) on the visualization image.
        cv2.polylines(annotated, [box], True, (255, 0, 0), 2)

        # Axis-aligned crop spanning the polygon's extremes.
        x_min, y_min = np.min(box, axis=0)
        x_max, y_max = np.max(box, axis=0)

        # Clamp to image boundaries.
        x_min, y_min = max(0, x_min), max(0, y_min)
        x_max, y_max = min(img_np.shape[1], x_max), min(img_np.shape[0], y_max)

        # Skip slivers too small to contain a glyph.
        if (x_max - x_min) < 5 or (y_max - y_min) < 5:
            continue

        crop = Image.fromarray(img_np[y_min:y_max, x_min:x_max])

        # Inference only — no gradients needed.
        with torch.no_grad():
            pixel_values = processor(images=crop, return_tensors="pt").pixel_values.to(device)
            out_ids = model.generate(pixel_values)
            txt = processor.batch_decode(out_ids, skip_special_tokens=True)[0]

        if txt.strip():
            full_text.append(txt)

    return Image.fromarray(annotated), " ".join(full_text)
405
+
406
# UI — two-column layout: input image next to the annotated detection
# image, transcription textbox and trigger button below.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 Final Fix: Full-Page OCR")
    with gr.Row():
        input_i = gr.Image(type="pil", label="Input")
        output_i = gr.Image(label="Detections (Scale Fixed)")
    output_t = gr.Textbox(label="Result", lines=10)
    btn = gr.Button("Transcribe", variant="primary")
    btn.click(process_page, input_i, [output_i, output_t])

# Guard the launch so importing this module (e.g. from a test) does not
# start the web server; running `python app.py` still launches as before.
if __name__ == "__main__":
    demo.launch()