Sathiyapramod commited on
Commit
0cc296e
·
verified ·
1 Parent(s): a0cc850

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -23
app.py CHANGED
@@ -2,17 +2,39 @@ import gradio as gr
2
  from PIL import Image
3
  import numpy as np
4
  import cv2
 
5
 
6
- from transformers import AutoModel
7
- model = AutoModel.from_pretrained("deepseek-ai/DeepSeek-OCR-2", trust_remote_code=True, dtype="auto")
8
 
 
 
 
 
 
 
9
 
10
- def segment_lines(image):
11
- # Convert to OpenCV format
12
- img = np.array(image.convert("L"))
13
 
14
- # Threshold
15
- _, thresh = cv2.threshold(img, 150, 255, cv2.THRESH_BINARY_INV)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  # Horizontal projection
18
  horizontal_sum = np.sum(thresh, axis=1)
@@ -28,35 +50,76 @@ def segment_lines(image):
28
  lines.append((start, end))
29
  start = None
30
 
31
- # Extract line images
 
 
 
 
32
  line_images = []
33
  for (s, e) in lines:
34
- cropped = image.crop((0, s, image.width, e))
35
- line_images.append(cropped)
 
 
 
 
 
 
 
36
 
37
  return line_images
38
 
39
 
 
 
 
40
  def predict(image):
 
 
41
  if image is None:
42
- return "Upload an image"
 
 
 
 
 
 
 
 
 
43
 
44
- lines = segment_lines(image)
 
 
 
 
45
 
46
- results = []
 
 
 
 
47
 
48
- for line_img in lines:
49
- pixel_values = processor(images=line_img, return_tensors="pt").pixel_values
50
- generated_ids = model.generate(pixel_values)
51
- text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
52
- results.append(text)
53
 
54
- return "\n".join(results)
55
 
 
56
 
57
- gr.Interface(
 
 
 
 
 
 
 
58
  fn=predict,
59
- inputs=gr.Image(type="pil"),
60
  outputs=gr.Textbox(label="Extracted Text"),
61
- title="📝 Multi-line Handwritten OCR",
62
- ).launch()
 
 
 
 
 
2
  from PIL import Image
3
  import numpy as np
4
  import cv2
5
+ import torch
6
 
7
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
 
8
 
9
+ # =========================
10
+ # Model Loader (cached)
11
+ # =========================
12
+ processor = None
13
+ model = None
14
+ device = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
+ def load_model():
17
+ global processor, model
 
18
 
19
+ if processor is None or model is None:
20
+ processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
21
+ model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
22
+ model.to(device)
23
+
24
+
25
+ # =========================
26
+ # Line Segmentation Logic
27
+ # =========================
28
+ def segment_lines(image: Image.Image):
29
+ """
30
+ Splits image into individual text lines using horizontal projection
31
+ """
32
+
33
+ # Convert to grayscale
34
+ gray = np.array(image.convert("L"))
35
+
36
+ # Apply thresholding
37
+ _, thresh = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY_INV)
38
 
39
  # Horizontal projection
40
  horizontal_sum = np.sum(thresh, axis=1)
 
50
  lines.append((start, end))
51
  start = None
52
 
53
+ # Edge case: last line
54
+ if start is not None:
55
+ lines.append((start, len(horizontal_sum)))
56
+
57
+ # Crop line images
58
  line_images = []
59
  for (s, e) in lines:
60
+ # Add small padding
61
+ top = max(0, s - 5)
62
+ bottom = min(image.height, e + 5)
63
+
64
+ cropped = image.crop((0, top, image.width, bottom))
65
+
66
+ # Skip very small/noise regions
67
+ if bottom - top > 10:
68
+ line_images.append(cropped)
69
 
70
  return line_images
71
 
72
 
73
+ # =========================
74
+ # OCR Prediction
75
+ # =========================
76
  def predict(image):
77
+ load_model()
78
+
79
  if image is None:
80
+ return "⚠️ Please upload an image."
81
+
82
+ try:
83
+ # Segment into lines
84
+ lines = segment_lines(image)
85
+
86
+ if not lines:
87
+ return "⚠️ No text detected. Try a clearer image."
88
+
89
+ results = []
90
 
91
+ for line_img in lines:
92
+ pixel_values = processor(
93
+ images=line_img,
94
+ return_tensors="pt"
95
+ ).pixel_values.to(device)
96
 
97
+ generated_ids = model.generate(pixel_values)
98
+ text = processor.batch_decode(
99
+ generated_ids,
100
+ skip_special_tokens=True
101
+ )[0]
102
 
103
+ results.append(text)
 
 
 
 
104
 
105
+ final_text = "\n".join(results)
106
 
107
+ return final_text if final_text.strip() else "⚠️ Could not extract text."
108
 
109
+ except Exception as e:
110
+ return f"❌ Error occurred: {str(e)}"
111
+
112
+
113
+ # =========================
114
+ # Gradio UI
115
+ # =========================
116
+ demo = gr.Interface(
117
  fn=predict,
118
+ inputs=gr.Image(type="pil", label="Upload Handwritten Image"),
119
  outputs=gr.Textbox(label="Extracted Text"),
120
+ title="📝 Handwritten OCR (Multi-line)",
121
+ description="Upload a handwritten note image. The model will extract text line by line.",
122
+ )
123
+
124
+ if __name__ == "__main__":
125
+ demo.launch()