PrabhatGupta786 committed on
Commit
a80d44c
·
verified ·
1 Parent(s): de5c1ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -8
app.py CHANGED
@@ -1,17 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  def full_pipeline(input_img):
2
  if input_img is None:
3
  return "Please upload an image."
4
 
5
  lines = get_lines_from_image(input_img)
6
  if not lines:
7
- return "No text lines detected. Try a clearer image."
8
 
9
  final_transcript = []
10
 
11
- # Process one line at a time to avoid CPU/RAM OOM (Out of Memory)
12
  for line_img in lines:
13
  try:
14
- # Resize line to 384px height (standard for TrOCR) to save processing time
15
  w, h = line_img.size
16
  new_h = 384
17
  new_w = int((new_h / h) * w)
@@ -26,9 +73,21 @@ def full_pipeline(input_img):
26
 
27
  if text.strip():
28
  final_transcript.append(text.strip())
29
-
30
- except Exception as e:
31
- print(f"Error processing line: {e}")
32
- continue
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- return " ".join(final_transcript)
 
 
1
+ import torch
2
+ import cv2
3
+ import numpy as np
4
+ import gradio as gr
5
+ from PIL import Image
6
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
7
+
8
+ # 1. Setup - Using 'base' instead of 'large' to prevent RAM crashes on Hugging Face
9
+ # This model is ~1GB smaller and significantly faster on CPUs.
10
+ device = "cpu"
11
+ model_id = 'microsoft/trocr-base-handwritten'
12
+
13
+ print(f"Loading model {model_id}...")
14
+ processor = TrOCRProcessor.from_pretrained(model_id)
15
+ model = VisionEncoderDecoderModel.from_pretrained(model_id).to(device)
16
+
17
def get_lines_from_image(img_array, *, pad=5, min_size=20):
    """Split a page image into cropped text-line images, ordered top to bottom.

    The image is binarized, dilated with a wide horizontal kernel so the
    characters of one line merge into a single blob, and each blob's bounding
    box (plus padding) is cropped from the original image.

    Args:
        img_array: NumPy image array. A 3-channel array is treated as RGB;
            2-D (already grayscale) and 4-channel (treated as RGBA) arrays
            are also accepted.
        pad: Pixels of margin added on every side of a detected line,
            clamped to the image borders.
        min_size: Bounding boxes whose width or height is not strictly
            greater than this are discarded as noise.

    Returns:
        A list of RGB ``PIL.Image`` crops, one per detected line, sorted by
        the top edge of their bounding boxes. Empty if nothing is detected.
    """
    # Pick the grayscale conversion that matches the channel layout instead
    # of assuming RGB, so grayscale/RGBA inputs do not crash cvtColor.
    if img_array.ndim == 2:
        gray = img_array
    elif img_array.shape[2] == 4:
        gray = cv2.cvtColor(img_array, cv2.COLOR_RGBA2GRAY)
    else:
        gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)

    # Adaptive thresholding handles uneven lighting better than a global
    # threshold; the inverted output makes ink the foreground (255).
    binary = cv2.adaptiveThreshold(
        gray,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV,
        11,
        2,
    )

    # A short, very wide kernel smears ink horizontally so the characters of
    # a line join into one connected component.
    kernel = np.ones((5, 80), np.uint8)
    dilation = cv2.dilate(binary, kernel, iterations=1)

    contours, _ = cv2.findContours(
        dilation, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )

    # Compute each bounding box once, then order boxes by their top edge.
    boxes = sorted(
        (cv2.boundingRect(ctr) for ctr in contours),
        key=lambda box: box[1],
    )

    img_h, img_w = img_array.shape[:2]
    line_images = []
    for x, y, w, h in boxes:
        # Filter out tiny noise blobs.
        if h <= min_size or w <= min_size:
            continue

        # Pad the box on all four sides, clamped to the image bounds.
        y_s, y_e = max(0, y - pad), min(img_h, y + h + pad)
        x_s, x_e = max(0, x - pad), min(img_w, x + w + pad)

        # Slice to x_e (not x_s + w): the latter stops short of the right
        # edge of the box and clips the end of every line.
        roi = img_array[y_s:y_e, x_s:x_e]
        # Ensure RGB for PIL regardless of the input channel layout.
        line_images.append(Image.fromarray(roi).convert("RGB"))

    return line_images
47
+
48
  def full_pipeline(input_img):
49
  if input_img is None:
50
  return "Please upload an image."
51
 
52
  lines = get_lines_from_image(input_img)
53
  if not lines:
54
+ return "No text lines detected. Please ensure your image is clear and not too dark."
55
 
56
  final_transcript = []
57
 
58
+ # Process sequentially to keep memory usage low and stable
59
  for line_img in lines:
60
  try:
61
+ # Resizing to 384 height helps TrOCR's internal attention mechanism
62
  w, h = line_img.size
63
  new_h = 384
64
  new_w = int((new_h / h) * w)
 
73
 
74
  if text.strip():
75
  final_transcript.append(text.strip())
76
+ except Exception:
77
+ continue # Skip lines that fail to avoid crashing the whole process
78
+
79
+ return "\n".join(final_transcript)
80
+
81
+ # Gradio Interface
82
+ demo = gr.Interface(
83
+ fn=full_pipeline,
84
+ # type="numpy" passes the upload to full_pipeline as a NumPy array; no in-browser editor is configured here, so any cropping must be done before upload
85
+ inputs=gr.Image(label="Upload Handwriting", type="numpy"),
86
+ outputs=gr.Textbox(label="Typed Text", show_copy_button=True),
87
+ title="Handwritten Paragraph to Typed Text",
88
+ description="Optimized for CPU. Upload a clear image of handwritten text. Tip: Crop the image to just the text area for best results.",
89
+ allow_flagging="never"
90
+ )
91
 
92
+ if __name__ == "__main__":
93
+ demo.launch()