iammraat commited on
Commit
a29782b
·
verified ·
1 Parent(s): b00ad18

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +167 -43
app.py CHANGED
@@ -1,65 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
- from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
  import torch
 
 
4
  from PIL import Image
 
 
5
 
6
- # --- Model Setup ---
7
- # We load the model outside the inference function to cache it on startup
8
- MODEL_ID = "microsoft/trocr-base-handwritten"
9
-
10
- print(f"Loading {MODEL_ID}...")
11
- processor = TrOCRProcessor.from_pretrained(MODEL_ID)
12
- model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID)
13
-
14
- # Check for GPU (Free Spaces are usually CPU-only, but this handles upgrades)
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
  model.to(device)
17
- print(f"Model loaded on device: {device}")
18
 
19
- # --- Inference Function ---
20
- def process_image(image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  if image is None:
22
  return "Please upload an image."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- try:
25
- # 1. Convert to RGB (standardizes input)
26
- image = image.convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- # 2. Preprocess
29
- pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)
 
 
 
 
 
 
30
 
31
- # 3. Generate text
 
32
  generated_ids = model.generate(pixel_values)
33
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
34
 
35
- return generated_text
36
- except Exception as e:
37
- return f"Error: {str(e)}"
38
 
39
- # --- Gradio Interface ---
40
- # Using the Blocks API for a clean layout
 
 
 
 
41
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
42
- gr.Markdown(
43
- """
44
- # ✍️ Handwritten Text Recognition
45
- Using Microsoft's **TrOCR Small** model. Upload a handwritten note to transcribe it.
46
- """
47
- )
48
 
49
  with gr.Row():
50
- with gr.Column():
51
- input_img = gr.Image(type="pil", label="Upload Image")
52
- submit_btn = gr.Button("Transcribe", variant="primary")
53
 
54
- with gr.Column():
55
- output_text = gr.Textbox(label="Result", interactive=False)
56
-
57
- # Examples help users test it immediately without uploading their own file
58
- # (Uncomment the list below if you upload example images to your repo)
59
- # gr.Examples(["sample1.jpg"], inputs=input_img)
60
-
61
- submit_btn.click(fn=process_image, inputs=input_img, outputs=output_text)
62
 
63
- # Launch for Spaces
64
  if __name__ == "__main__":
65
  demo.launch()
 
1
+ # import gradio as gr
2
+ # from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
+ # import torch
4
+ # from PIL import Image
5
+
6
+ # # --- Model Setup ---
7
+ # # We load the model outside the inference function to cache it on startup
8
+ # MODEL_ID = "microsoft/trocr-base-handwritten"
9
+
10
+ # print(f"Loading {MODEL_ID}...")
11
+ # processor = TrOCRProcessor.from_pretrained(MODEL_ID)
12
+ # model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID)
13
+
14
+ # # Check for GPU (Free Spaces are usually CPU-only, but this handles upgrades)
15
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
16
+ # model.to(device)
17
+ # print(f"Model loaded on device: {device}")
18
+
19
+ # # --- Inference Function ---
20
+ # def process_image(image):
21
+ # if image is None:
22
+ # return "Please upload an image."
23
+
24
+ # try:
25
+ # # 1. Convert to RGB (standardizes input)
26
+ # image = image.convert("RGB")
27
+
28
+ # # 2. Preprocess
29
+ # pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)
30
+
31
+ # # 3. Generate text
32
+ # generated_ids = model.generate(pixel_values)
33
+ # generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
34
+
35
+ # return generated_text
36
+ # except Exception as e:
37
+ # return f"Error: {str(e)}"
38
+
39
+ # # --- Gradio Interface ---
40
+ # # Using the Blocks API for a clean layout
41
+ # with gr.Blocks(theme=gr.themes.Soft()) as demo:
42
+ # gr.Markdown(
43
+ # """
44
+ # # ✍️ Handwritten Text Recognition
45
+ # Using Microsoft's **TrOCR Small** model. Upload a handwritten note to transcribe it.
46
+ # """
47
+ # )
48
+
49
+ # with gr.Row():
50
+ # with gr.Column():
51
+ # input_img = gr.Image(type="pil", label="Upload Image")
52
+ # submit_btn = gr.Button("Transcribe", variant="primary")
53
+
54
+ # with gr.Column():
55
+ # output_text = gr.Textbox(label="Result", interactive=False)
56
+
57
+ # # Examples help users test it immediately without uploading their own file
58
+ # # (Uncomment the list below if you upload example images to your repo)
59
+ # # gr.Examples(["sample1.jpg"], inputs=input_img)
60
+
61
+ # submit_btn.click(fn=process_image, inputs=input_img, outputs=output_text)
62
+
63
+ # # Launch for Spaces
64
+ # if __name__ == "__main__":
65
+ # demo.launch()
66
+
67
+
68
+
69
+
70
+
71
+
72
  import gradio as gr
 
73
  import torch
74
+ import numpy as np
75
+ import cv2
76
  from PIL import Image
77
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
78
+ from craft_text_detector import Craft
79
 
80
+ # --- 1. Load TrOCR (Recognition) ---
81
+ print("Loading TrOCR model...")
82
+ processor = TrOCRProcessor.from_pretrained('microsoft/trocr-small-handwritten')
83
+ model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-small-handwritten')
 
 
 
 
 
84
  device = "cuda" if torch.cuda.is_available() else "cpu"
85
  model.to(device)
 
86
 
87
+ # --- 2. Load CRAFT (Detection) ---
88
+ print("Loading CRAFT text detector...")
89
+ # refine_net=True helps connect individual characters into words/lines
90
+ craft = Craft(output_dir=None, crop_type="poly", cuda=(device == "cuda"))
91
+
92
+ # --- Helper: Sort Boxes (Reading Order) ---
93
+ def get_sorted_boxes(boxes):
94
+ """
95
+ Sort boxes from top-to-bottom, then left-to-right.
96
+ This simple sorting assumes lines are roughly horizontal.
97
+ """
98
+ # Calculate centroids
99
+ centroids = []
100
+ for box in boxes:
101
+ # box is usually [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
102
+ # Get center x and y
103
+ x_center = np.mean(box[:, 0])
104
+ y_center = np.mean(box[:, 1])
105
+ centroids.append([x_center, y_center, box])
106
+
107
+ # Sort by Y first (with a tolerance to group items on same line)
108
+ # This is a naive sort; for complex layouts, more advanced logic is needed.
109
+ centroids.sort(key=lambda k: (int(k[1] // 20), k[0]))
110
+
111
+ return [item[2] for item in centroids]
112
+
113
+ # --- Main Inference Pipeline ---
114
+ def process_full_page(image):
115
  if image is None:
116
  return "Please upload an image."
117
+
118
+ # Convert PIL to Numpy (OpenCV format)
119
+ image_np = np.array(image)
120
+
121
+ # 1. DETECT TEXT REGIONS
122
+ # prediction_result returns: {"boxes": [...], "polys": [...], "heatmaps": ...}
123
+ prediction_result = craft.detect_text(image_np)
124
+ boxes = prediction_result["boxes"]
125
+
126
+ if len(boxes) == 0:
127
+ return "No text detected."
128
+
129
+ # 2. SORT BOXES (Reading Order)
130
+ sorted_boxes = get_sorted_boxes(boxes)
131
+
132
+ # 3. RECOGNIZE TEXT (Iterate through crops)
133
+ full_text = []
134
 
135
+ # Optional: Draw boxes on image for visualization
136
+ annotated_img = image_np.copy()
137
+
138
+ for box in sorted_boxes:
139
+ # Get coordinates for cropping
140
+ # box points are float, convert to int
141
+ box = box.astype(int)
142
+
143
+ # Draw box on visualization
144
+ cv2.polylines(annotated_img, [box], True, (255, 0, 0), 2)
145
+
146
+ # Crop the region
147
+ x_min = max(0, np.min(box[:, 0]))
148
+ x_max = min(image_np.shape[1], np.max(box[:, 0]))
149
+ y_min = max(0, np.min(box[:, 1]))
150
+ y_max = min(image_np.shape[0], np.max(box[:, 1]))
151
 
152
+ # Safety check for empty crops
153
+ if x_max - x_min < 5 or y_max - y_min < 5:
154
+ continue
155
+
156
+ cropped_region = image_np[y_min:y_max, x_min:x_max]
157
+
158
+ # Convert crop back to PIL for TrOCR
159
+ pil_crop = Image.fromarray(cropped_region).convert("RGB")
160
 
161
+ # Run TrOCR
162
+ pixel_values = processor(images=pil_crop, return_tensors="pt").pixel_values.to(device)
163
  generated_ids = model.generate(pixel_values)
164
+ text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
165
 
166
+ full_text.append(text)
 
 
167
 
168
+ # Join detected pieces
169
+ final_output = " ".join(full_text)
170
+
171
+ return Image.fromarray(annotated_img), final_output
172
+
173
+ # --- Gradio UI ---
174
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
175
+ gr.Markdown("# 🕵️‍♀️ Full-Page Handwritten OCR")
176
+ gr.Markdown("Pipeline: **CRAFT** (Detection) ➡️ **TrOCR** (Recognition)")
 
 
 
 
177
 
178
  with gr.Row():
179
+ input_img = gr.Image(type="pil", label="Upload Full Page")
 
 
180
 
181
+ with gr.Row():
182
+ vis_output = gr.Image(label="Detections", type="pil")
183
+ text_output = gr.Textbox(label="Extracted Text", lines=10)
184
+
185
+ submit_btn = gr.Button("Process Page", variant="primary")
186
+ submit_btn.click(fn=process_full_page, inputs=input_img, outputs=[vis_output, text_output])
 
 
187
 
 
188
  if __name__ == "__main__":
189
  demo.launch()