iammraat committed on
Commit
bcd5374
·
verified ·
1 Parent(s): 1978033

Update app (1).py

Browse files
Files changed (1) hide show
  1. app (1).py +51 -217
app (1).py CHANGED
@@ -1,231 +1,65 @@
1
-
2
- # import os
3
- # import sys
4
- # import numpy as np
5
- # import cv2
6
-
7
- # # ==========================================
8
- # # 🔧 PATCH 1: Fix Torchvision Compatibility
9
- # # ==========================================
10
- # import torchvision.models.vgg
11
- # if not hasattr(torchvision.models.vgg, 'model_urls'):
12
- # torchvision.models.vgg.model_urls = {
13
- # 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'
14
- # }
15
-
16
- # # ==========================================
17
- # # 🔧 PATCH 2: Fix NumPy Crash AND Coordinates
18
- # # ==========================================
19
- # import craft_text_detector.craft_utils as craft_utils_module
20
-
21
- # # CRITICAL FIX: Added 'ratio_net=2' to the signature and calculation.
22
- # # Without this, your boxes are 2x smaller than they should be.
23
- # def fixed_adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=2):
24
- # if not polys:
25
- # return []
26
-
27
- # adjusted = []
28
- # for poly in polys:
29
- # if poly is None or len(poly) == 0:
30
- # continue
31
-
32
- # # Convert to numpy array safely
33
- # poly = np.array(poly).reshape(-1, 2)
34
-
35
- # # Scale coordinates:
36
- # # 1. ratio_net scales from Heatmap -> Resized Image
37
- # # 2. ratio_w/h scales from Resized Image -> Original Image
38
- # poly[:, 0] *= (ratio_w * ratio_net)
39
- # poly[:, 1] *= (ratio_h * ratio_net)
40
-
41
- # adjusted.append(poly)
42
-
43
- # return adjusted
44
-
45
- # # Apply the patch
46
- # craft_utils_module.adjustResultCoordinates = fixed_adjustResultCoordinates
47
- # # ==========================================
48
-
49
- # import gradio as gr
50
- # from craft_hw_ocr import OCR
51
-
52
- # print("⏳ Loading OCR models...")
53
- # ocr = OCR.load_models()
54
- # print("✅ Models loaded!")
55
-
56
- # def do_ocr(inp):
57
- # if inp is None:
58
- # return None, "No image uploaded."
59
-
60
- # try:
61
- # # Detection
62
- # # 'detected_img' is usually just the input image in this library
63
- # detected_img, results = OCR.detection(inp, ocr[2])
64
-
65
- # # Visualization
66
- # # Now that coordinates are scaled correctly, this should look right
67
- # viz_img = OCR.visualize(inp, results)
68
-
69
- # # Recognition
70
- # try:
71
- # # Note: The library has a typo 'recoginition' (extra 'i')
72
- # bboxes, text = OCR.recoginition(detected_img, results, ocr[0], ocr[1])
73
- # except Exception as e:
74
- # print(f"Recognition error: {e}")
75
- # text = f"Detection successful, but recognition failed: {str(e)}"
76
-
77
- # return viz_img, text
78
-
79
- # except Exception as e:
80
- # print(f"OCR error: {e}")
81
- # return None, f"Error processing image: {str(e)}"
82
-
83
- # inputs = gr.Image(label="Upload Image")
84
- # o1 = gr.Image(label="Detections")
85
- # o2 = gr.Textbox(label="Text")
86
-
87
- # title = "CRAFT-OCR (Fixed Coords)"
88
- # description = "Handwriting OCR using CRAFT + TrOCR. Patched for NumPy and Coordinates."
89
-
90
- # gr.Interface(
91
- # fn=do_ocr,
92
- # inputs=inputs,
93
- # outputs=[o1, o2],
94
- # title=title,
95
- # description=description
96
- # ).launch()
97
-
98
-
99
-
100
-
101
-
102
-
103
-
104
- import os
105
- import sys
106
- import numpy as np
107
- import cv2
108
-
109
- # ==========================================
110
- # 🔧 PATCH 1: Fix Torchvision Compatibility
111
- # ==========================================
112
- import torchvision.models.vgg
113
- if not hasattr(torchvision.models.vgg, 'model_urls'):
114
- torchvision.models.vgg.model_urls = {
115
- 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'
116
- }
117
-
118
- # ==========================================
119
- # 🔧 PATCH 2: Fix NumPy Crash AND Coordinates
120
- # ==========================================
121
- import craft_text_detector.craft_utils as craft_utils_module
122
-
123
- def fixed_adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=2):
124
- if not polys:
125
- return []
126
-
127
- adjusted = []
128
- for poly in polys:
129
- if poly is None or len(poly) == 0:
130
- continue
131
-
132
- # Convert to numpy array safely
133
- poly = np.array(poly).reshape(-1, 2)
134
-
135
- # Scale coordinates correctly
136
- poly[:, 0] *= (ratio_w * ratio_net)
137
- poly[:, 1] *= (ratio_h * ratio_net)
138
-
139
- adjusted.append(poly)
140
-
141
- return adjusted
142
-
143
- craft_utils_module.adjustResultCoordinates = fixed_adjustResultCoordinates
144
- # ==========================================
145
-
146
  import gradio as gr
147
- from craft_hw_ocr import OCR
148
- # Import the core prediction function to bypass the wrapper limitations
149
- from craft_text_detector.predict import get_prediction
150
-
151
- print("⏳ Loading OCR models...")
152
- ocr = OCR.load_models()
153
- # ocr[2] is the Craft object wrapper
154
- craft_wrapper = ocr[2]
155
- # We extract the actual networks to run them manually
156
- craft_net = craft_wrapper.craft_net
157
- refine_net = craft_wrapper.refine_net
158
- print("✅ Models loaded!")
159
-
160
- def do_ocr(inp, text_threshold, link_threshold, low_text):
161
- if inp is None:
162
- return None, "No image uploaded."
 
 
 
 
 
163
 
164
  try:
165
- print(f"⚙️ Running Direct Inference: Text={text_threshold}, Link={link_threshold}, Low={low_text}")
166
-
167
- # 1. Direct Detection (Bypassing Craft.detect_text)
168
- # This calls the engine directly, which accepts all our sliders.
169
- prediction_result = get_prediction(
170
- image=inp,
171
- craft_net=craft_net,
172
- refine_net=refine_net,
173
- text_threshold=text_threshold,
174
- link_threshold=link_threshold,
175
- low_text=low_text,
176
- cuda=False, # Space is CPU
177
- poly=True
178
- )
179
-
180
- # 2. Visualization
181
- # OCR.visualize expects (image, results_dict)
182
- viz_img = OCR.visualize(inp, prediction_result)
183
 
184
- # 3. Recognition
185
- try:
186
- bboxes, text = OCR.recoginition(inp, prediction_result, ocr[0], ocr[1])
187
- except Exception as e:
188
- text = f"Detection successful, but recognition failed: {e}"
189
 
190
- return viz_img, text
 
 
191
 
 
192
  except Exception as e:
193
- print(f"OCR error: {e}")
194
- return None, f"Error processing image: {str(e)}"
195
-
196
- # ------------------------------------------------------
197
- # 🎛️ UI with Tuning Sliders
198
- # ------------------------------------------------------
199
- with gr.Blocks(title="CRAFT-OCR Tuner") as demo:
200
- gr.Markdown("## 🔧 CRAFT-OCR Parameter Tuner")
201
- gr.Markdown("Adjust sliders to fix issues like merged words or background noise.")
 
 
202
 
203
  with gr.Row():
204
- with gr.Column(scale=1):
205
- input_img = gr.Image(label="Upload Image")
206
-
207
- gr.Markdown("### 🎛️ Fine-Tune Detection")
208
-
209
- text_thres = gr.Slider(0.1, 0.9, value=0.7, step=0.05, label="Text Threshold",
210
- info="Confidence to consider a pixel as text. Higher = Less Noise.")
211
-
212
- link_thres = gr.Slider(0.1, 0.9, value=0.4, step=0.05, label="Link Threshold",
213
- info="Confidence to link characters. HIGHER value splits merged words (Fixes 'Hamburgthen').")
214
-
215
- low_text = gr.Slider(0.1, 0.9, value=0.4, step=0.05, label="Low Text Threshold",
216
- info="Filters background noise. Higher = Cleaner background.")
217
-
218
- btn = gr.Button("Run OCR", variant="primary")
219
 
220
- with gr.Column(scale=1):
221
- viz_output = gr.Image(label="Detections (Verify Boxes)")
222
- text_output = gr.Textbox(label="Recognized Text", lines=10)
 
 
 
223
 
224
- btn.click(
225
- fn=do_ocr,
226
- inputs=[input_img, text_thres, link_thres, low_text],
227
- outputs=[viz_output, text_output]
228
- )
229
 
 
230
  if __name__ == "__main__":
231
  demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
+ import torch
4
+ from PIL import Image
5
+
6
+ # --- Model Setup ---
7
+ # We load the model outside the inference function to cache it on startup
8
+ MODEL_ID = "microsoft/trocr-small-handwritten"
9
+
10
+ print(f"Loading {MODEL_ID}...")
11
+ processor = TrOCRProcessor.from_pretrained(MODEL_ID)
12
+ model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID)
13
+
14
+ # Check for GPU (Free Spaces are usually CPU-only, but this handles upgrades)
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+ model.to(device)
17
+ print(f"Model loaded on device: {device}")
18
+
19
# --- Inference Function ---
def process_image(image):
    """Transcribe handwritten text from an uploaded image with TrOCR.

    Args:
        image: PIL image from the Gradio input component, or None when
            the user clicked Transcribe without uploading anything.

    Returns:
        str: The recognized text, a prompt to upload an image, or an
        error message. The UI has a single output textbox, so failures
        are reported as text instead of being raised.
    """
    if image is None:
        return "Please upload an image."

    try:
        # 1. Standardize the input: the TrOCR processor expects 3-channel RGB.
        image = image.convert("RGB")

        # 2. Preprocess into the model's pixel tensor, on the model's device.
        pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)

        # 3. Generate text. no_grad() skips autograd bookkeeping, cutting
        #    memory use and speeding up inference (nothing is trained here).
        with torch.no_grad():
            generated_ids = model.generate(pixel_values)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        return generated_text
    except Exception as e:
        # Broad catch is deliberate: surface any failure in the textbox
        # rather than crashing the Space.
        return f"Error: {str(e)}"
38
# --- Gradio Interface ---
# Blocks API gives a two-column layout: input + button on the left,
# transcription on the right.
HEADER_MD = """
# ✍️ Handwritten Text Recognition
Using Microsoft's **TrOCR Small** model. Upload a handwritten note to transcribe it.
"""

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(HEADER_MD)

    with gr.Row():
        # Left column: image upload and the action button.
        with gr.Column():
            input_img = gr.Image(type="pil", label="Upload Image")
            submit_btn = gr.Button("Transcribe", variant="primary")

        # Right column: read-only result textbox.
        with gr.Column():
            output_text = gr.Textbox(label="Result", interactive=False)

    # Examples let users try the demo without uploading their own file.
    # (Uncomment once example images are added to the repo.)
    # gr.Examples(["sample1.jpg"], inputs=input_img)

    # Wire the button to the inference function.
    submit_btn.click(fn=process_image, inputs=input_img, outputs=output_text)

# Entry point when run directly (how Spaces starts the app).
if __name__ == "__main__":
    demo.launch()