iammraat committed on
Commit
00ef556
·
verified ·
1 Parent(s): 2a2ba6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +235 -63
app.py CHANGED
@@ -94,81 +94,258 @@
94
 
95
 
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  import gradio as gr
98
  import numpy as np
99
  import cv2
100
  import traceback
101
  import tempfile
102
  import os
103
- from PIL import Image
104
  from doctr.io import DocumentFile
105
  from doctr.models import ocr_predictor
106
- from transformers import pipeline
107
 
108
  # ------------------------------------------------------
109
- # 1. Load Models Globally
110
  # ------------------------------------------------------
111
  print("⏳ Loading models...")
112
 
113
  # A. Load DocTR (OCR)
114
  try:
115
- # 'fast_base' is lightweight for CPU
116
  ocr_model = ocr_predictor(det_arch='fast_base', reco_arch='crnn_vgg16_bn', pretrained=True)
117
  print("✅ DocTR loaded.")
118
  except Exception as e:
119
  print(f"❌ DocTR Load Error: {e}")
120
  raise e
121
 
122
- # B. Load Corrector (Small Language Model)
 
 
 
 
123
  try:
124
- # 'google/flan-t5-small' is ~250MB, well under the 1GB limit.
125
- # We use a text2text-generation pipeline.
126
- corrector = pipeline(
127
- "text2text-generation",
128
- model="google/flan-t5-small",
129
- device=-1 # -1 forces CPU
130
  )
131
- print("✅ Correction model (Flan-T5-Small) loaded.")
132
  except Exception as e:
133
- print(f"❌ Corrector Load Error: {e}")
134
- corrector = None
 
135
 
136
  # ------------------------------------------------------
137
- # 2. Correction Logic
138
  # ------------------------------------------------------
139
  def smart_correction(text):
140
- if not text or not text.strip() or corrector is None:
141
  return text
 
 
142
 
143
- # DocTR returns text with newlines. LLMs often prefer line-by-line or chunked input
144
- # if the context isn't massive. For a small model, processing line-by-line is safer.
145
- lines = text.split('\n')
146
- corrected_lines = []
147
-
148
- print("--- Starting Correction ---")
149
- for line in lines:
150
- if len(line.strip()) < 3: # Skip empty/tiny lines
151
- corrected_lines.append(line)
152
- continue
153
-
154
- try:
155
- # Prompt engineering for Flan-T5
156
- prompt = f"Fix grammar and OCR errors: {line}"
157
-
158
- # max_length ensures it doesn't ramble.
159
- result = corrector(prompt, max_length=128)
160
- fixed_text = result[0]['generated_text']
161
-
162
- # Fallback: if model returns empty, keep original
163
- corrected_lines.append(fixed_text if fixed_text else line)
164
- except Exception as e:
165
- print(f"Correction failed for line '{line}': {e}")
166
- corrected_lines.append(line)
167
-
168
- return "\n".join(corrected_lines)
 
 
 
 
 
 
 
 
 
169
 
170
  # ------------------------------------------------------
171
- # 3. Main Processing Function
172
  # ------------------------------------------------------
173
  def run_ocr(input_image):
174
  tmp_path = None
@@ -176,22 +353,21 @@ def run_ocr(input_image):
176
  if input_image is None:
177
  return None, "No image uploaded", None, None
178
 
179
- # -- Save temp file for DocTR robustness --
180
  with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
181
  input_image.save(tmp.name)
182
  tmp_path = tmp.name
183
 
184
- # -- Run OCR --
185
  doc = DocumentFile.from_images(tmp_path)
186
  result = ocr_model(doc)
187
-
188
- # -- Raw Text --
189
  raw_text = result.render()
190
 
191
- # -- Correction Step --
 
192
  corrected_text = smart_correction(raw_text)
193
 
194
- # -- Visualization --
195
  image_np = np.array(input_image)
196
  viz_image = image_np.copy()
197
 
@@ -209,40 +385,36 @@ def run_ocr(input_image):
209
 
210
  except Exception as e:
211
  error_log = traceback.format_exc()
212
- return None, f"Error: {e}", f"Error Log:\n{error_log}", {"error": str(e)}
213
 
214
  finally:
215
  if tmp_path and os.path.exists(tmp_path):
216
  os.remove(tmp_path)
217
 
218
  # ------------------------------------------------------
219
- # 4. Gradio UI
220
  # ------------------------------------------------------
221
- with gr.Blocks(title="DocTR OCR + Correction") as demo:
222
- gr.Markdown("## 📄 AI OCR with Grammar Correction")
223
- gr.Markdown("Using `DocTR` for extraction and `Flan-T5-Small` for correction.")
224
 
225
  with gr.Row():
226
  input_img = gr.Image(type="pil", label="Upload Document")
227
 
228
  with gr.Row():
229
- btn = gr.Button("Run Extraction & Correction", variant="primary")
230
 
231
  with gr.Row():
232
  out_img = gr.Image(label="Detections")
233
 
234
  with gr.Row():
235
- out_raw = gr.Textbox(label="Raw OCR Text", lines=8, placeholder="Raw output appears here...")
236
- out_corrected = gr.Textbox(label=" Corrected Text", lines=8, placeholder="AI corrected output appears here...")
237
 
238
  with gr.Row():
239
- out_json = gr.JSON(label="Full JSON Data")
240
 
241
- btn.click(
242
- fn=run_ocr,
243
- inputs=input_img,
244
- outputs=[out_img, out_raw, out_corrected, out_json]
245
- )
246
 
247
  if __name__ == "__main__":
248
  demo.launch()
 
94
 
95
 
96
 
97
+
98
+
99
+
100
+ # import gradio as gr
101
+ # import numpy as np
102
+ # import cv2
103
+ # import traceback
104
+ # import tempfile
105
+ # import os
106
+ # from PIL import Image
107
+ # from doctr.io import DocumentFile
108
+ # from doctr.models import ocr_predictor
109
+ # from transformers import pipeline
110
+
111
+ # # ------------------------------------------------------
112
+ # # 1. Load Models Globally
113
+ # # ------------------------------------------------------
114
+ # print("⏳ Loading models...")
115
+
116
+ # # A. Load DocTR (OCR)
117
+ # try:
118
+ # # 'fast_base' is lightweight for CPU
119
+ # ocr_model = ocr_predictor(det_arch='fast_base', reco_arch='crnn_vgg16_bn', pretrained=True)
120
+ # print("✅ DocTR loaded.")
121
+ # except Exception as e:
122
+ # print(f"❌ DocTR Load Error: {e}")
123
+ # raise e
124
+
125
+ # # B. Load Corrector (Small Language Model)
126
+ # try:
127
+ # # 'google/flan-t5-small' is ~250MB, well under the 1GB limit.
128
+ # # We use a text2text-generation pipeline.
129
+ # corrector = pipeline(
130
+ # "text2text-generation",
131
+ # model="google/flan-t5-small",
132
+ # device=-1 # -1 forces CPU
133
+ # )
134
+ # print("✅ Correction model (Flan-T5-Small) loaded.")
135
+ # except Exception as e:
136
+ # print(f"❌ Corrector Load Error: {e}")
137
+ # corrector = None
138
+
139
+ # # ------------------------------------------------------
140
+ # # 2. Correction Logic
141
+ # # ------------------------------------------------------
142
+ # def smart_correction(text):
143
+ # if not text or not text.strip() or corrector is None:
144
+ # return text
145
+
146
+ # # DocTR returns text with newlines. LLMs often prefer line-by-line or chunked input
147
+ # # if the context isn't massive. For a small model, processing line-by-line is safer.
148
+ # lines = text.split('\n')
149
+ # corrected_lines = []
150
+
151
+ # print("--- Starting Correction ---")
152
+ # for line in lines:
153
+ # if len(line.strip()) < 3: # Skip empty/tiny lines
154
+ # corrected_lines.append(line)
155
+ # continue
156
+
157
+ # try:
158
+ # # Prompt engineering for Flan-T5
159
+ # prompt = f"Fix grammar and OCR errors: {line}"
160
+
161
+ # # max_length ensures it doesn't ramble.
162
+ # result = corrector(prompt, max_length=128)
163
+ # fixed_text = result[0]['generated_text']
164
+
165
+ # # Fallback: if model returns empty, keep original
166
+ # corrected_lines.append(fixed_text if fixed_text else line)
167
+ # except Exception as e:
168
+ # print(f"Correction failed for line '{line}': {e}")
169
+ # corrected_lines.append(line)
170
+
171
+ # return "\n".join(corrected_lines)
172
+
173
+ # # ------------------------------------------------------
174
+ # # 3. Main Processing Function
175
+ # # ------------------------------------------------------
176
+ # def run_ocr(input_image):
177
+ # tmp_path = None
178
+ # try:
179
+ # if input_image is None:
180
+ # return None, "No image uploaded", None, None
181
+
182
+ # # -- Save temp file for DocTR robustness --
183
+ # with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
184
+ # input_image.save(tmp.name)
185
+ # tmp_path = tmp.name
186
+
187
+ # # -- Run OCR --
188
+ # doc = DocumentFile.from_images(tmp_path)
189
+ # result = ocr_model(doc)
190
+
191
+ # # -- Raw Text --
192
+ # raw_text = result.render()
193
+
194
+ # # -- Correction Step --
195
+ # corrected_text = smart_correction(raw_text)
196
+
197
+ # # -- Visualization --
198
+ # image_np = np.array(input_image)
199
+ # viz_image = image_np.copy()
200
+
201
+ # for page in result.pages:
202
+ # for block in page.blocks:
203
+ # for line in block.lines:
204
+ # for word in line.words:
205
+ # h, w = viz_image.shape[:2]
206
+ # (x_min, y_min), (x_max, y_max) = word.geometry
207
+ # x1, y1 = int(x_min * w), int(y_min * h)
208
+ # x2, y2 = int(x_max * w), int(y_max * h)
209
+ # cv2.rectangle(viz_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
210
+
211
+ # return viz_image, raw_text, corrected_text, result.export()
212
+
213
+ # except Exception as e:
214
+ # error_log = traceback.format_exc()
215
+ # return None, f"Error: {e}", f"Error Log:\n{error_log}", {"error": str(e)}
216
+
217
+ # finally:
218
+ # if tmp_path and os.path.exists(tmp_path):
219
+ # os.remove(tmp_path)
220
+
221
+ # # ------------------------------------------------------
222
+ # # 4. Gradio UI
223
+ # # ------------------------------------------------------
224
+ # with gr.Blocks(title="DocTR OCR + Correction") as demo:
225
+ # gr.Markdown("## 📄 AI OCR with Grammar Correction")
226
+ # gr.Markdown("Using `DocTR` for extraction and `Flan-T5-Small` for correction.")
227
+
228
+ # with gr.Row():
229
+ # input_img = gr.Image(type="pil", label="Upload Document")
230
+
231
+ # with gr.Row():
232
+ # btn = gr.Button("Run Extraction & Correction", variant="primary")
233
+
234
+ # with gr.Row():
235
+ # out_img = gr.Image(label="Detections")
236
+
237
+ # with gr.Row():
238
+ # out_raw = gr.Textbox(label="Raw OCR Text", lines=8, placeholder="Raw output appears here...")
239
+ # out_corrected = gr.Textbox(label="✨ Corrected Text", lines=8, placeholder="AI corrected output appears here...")
240
+
241
+ # with gr.Row():
242
+ # out_json = gr.JSON(label="Full JSON Data")
243
+
244
+ # btn.click(
245
+ # fn=run_ocr,
246
+ # inputs=input_img,
247
+ # outputs=[out_img, out_raw, out_corrected, out_json]
248
+ # )
249
+
250
+ # if __name__ == "__main__":
251
+ # demo.launch()
252
+
253
+
254
+
255
+
256
+
257
+
258
+
import gradio as gr
import numpy as np
import cv2
import traceback
import tempfile
import os
import torch  # required by transformers model loading/inference backend
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from transformers import AutoModelForCausalLM, AutoTokenizer

# ------------------------------------------------------
# 1. Configuration & Global Loading
#
# Models are loaded once at module import so every Gradio
# request reuses them. A DocTR load failure is fatal (re-raised);
# an LLM load failure degrades gracefully: llm_model/tokenizer are
# set to None and correction becomes a no-op downstream.
# ------------------------------------------------------
print("⏳ Loading models...")

# A. Load DocTR (OCR)
# 'fast_base' detector + 'crnn_vgg16_bn' recognizer: lightweight combo for CPU.
try:
    ocr_model = ocr_predictor(det_arch='fast_base', reco_arch='crnn_vgg16_bn', pretrained=True)
    print("✅ DocTR loaded.")
except Exception as e:
    print(f"❌ DocTR Load Error: {e}")
    # OCR is the core feature — without it the app is useless, so abort startup.
    raise e

# B. Load correction LLM (currently Qwen2.5-1.5B-Instruct).
# Larger variants ("Qwen/Qwen2.5-3B-Instruct", "Qwen/Qwen2.5-7B-Instruct")
# improve quality at the cost of RAM and CPU latency — swap MODEL_ID to trade off.
MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"

try:
    print(f"⬇️ Downloading & Loading {MODEL_ID}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    llm_model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype="auto",  # let transformers pick the checkpoint's native dtype
        device_map="cpu"     # CPU-only deployment; relies on system RAM
    )
    print(f"✅ {MODEL_ID} loaded successfully.")
except Exception as e:
    print(f"❌ LLM Load Error: {e}")
    # Degrade gracefully: smart_correction() checks for None and skips correction.
    llm_model = None
    tokenizer = None
 
302
  # ------------------------------------------------------
303
+ # 2. Correction Logic (The "Smart" Fix)
304
  # ------------------------------------------------------
305
  def smart_correction(text):
306
+ if not text or not llm_model:
307
  return text
308
+
309
+ print("--- Starting AI Correction ---")
310
 
311
+ # 1. Construct the Prompt
312
+ # We ask the model to act as a text editor.
313
+ system_prompt = "You are a helpful assistant that corrects OCR text. Fix typos, capitalization, and grammar. Maintain the original line structure. Do not add any conversational text like 'Here is the corrected text'."
314
+ user_prompt = f"Correct the following OCR text:\n\n{text}"
315
+
316
+ messages = [
317
+ {"role": "system", "content": system_prompt},
318
+ {"role": "user", "content": user_prompt}
319
+ ]
320
+
321
+ text_input = tokenizer.apply_chat_template(
322
+ messages,
323
+ tokenize=False,
324
+ add_generation_prompt=True
325
+ )
326
+
327
+ model_inputs = tokenizer([text_input], return_tensors="pt").to("cpu")
328
+
329
+ # 2. Run Inference
330
+ # max_new_tokens limits the output length to avoid infinite loops
331
+ generated_ids = llm_model.generate(
332
+ model_inputs.input_ids,
333
+ max_new_tokens=1024,
334
+ temperature=0.1, # Low temp for factual/consistent results
335
+ do_sample=False # Greedy decoding is faster and more deterministic
336
+ )
337
+
338
+ # 3. Decode Output
339
+ # We strip the input tokens to get only the new (corrected) text
340
+ generated_ids = [
341
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
342
+ ]
343
+ response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
344
+
345
+ return response
346
 
347
  # ------------------------------------------------------
348
+ # 3. Processing Pipeline
349
  # ------------------------------------------------------
350
  def run_ocr(input_image):
351
  tmp_path = None
 
353
  if input_image is None:
354
  return None, "No image uploaded", None, None
355
 
356
+ # Robust Temp File Handling
357
  with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
358
  input_image.save(tmp.name)
359
  tmp_path = tmp.name
360
 
361
+ # 1. Run OCR
362
  doc = DocumentFile.from_images(tmp_path)
363
  result = ocr_model(doc)
 
 
364
  raw_text = result.render()
365
 
366
+ # 2. Run AI Correction
367
+ # We pass the WHOLE text block at once. Context helps the AI.
368
  corrected_text = smart_correction(raw_text)
369
 
370
+ # 3. Visualization
371
  image_np = np.array(input_image)
372
  viz_image = image_np.copy()
373
 
 
385
 
386
  except Exception as e:
387
  error_log = traceback.format_exc()
388
+ return None, f"Error: {e}", f"Logs:\n{error_log}", {"error": str(e)}
389
 
390
  finally:
391
  if tmp_path and os.path.exists(tmp_path):
392
  os.remove(tmp_path)
393
 
# ------------------------------------------------------
# 4. Gradio Interface
#
# Layout: upload row -> action button -> detection preview ->
# raw/corrected text side by side -> full JSON export.
# ------------------------------------------------------
with gr.Blocks(title="Next-Gen OCR") as demo:
    gr.Markdown("## 📄 Next-Gen AI OCR")
    gr.Markdown(f"Using **DocTR** for extraction and **{MODEL_ID}** for smart correction.")

    with gr.Row():
        input_img = gr.Image(type="pil", label="Upload Document")

    with gr.Row():
        btn = gr.Button("Run Extraction & Smart Correction", variant="primary")

    with gr.Row():
        out_img = gr.Image(label="Detections")

    with gr.Row():
        out_raw = gr.Textbox(label="Raw OCR Output", lines=10)
        # Derive the label from MODEL_ID instead of hardcoding "Qwen 7B":
        # the configured model is 1.5B, and the label must track future swaps.
        out_corrected = gr.Textbox(label=f"🤖 AI Corrected ({MODEL_ID.split('/')[-1]})", lines=10)

    with gr.Row():
        out_json = gr.JSON(label="JSON Data")

    # Single pipeline entry point: OCR -> correction -> visualization -> JSON.
    btn.click(fn=run_ocr, inputs=input_img, outputs=[out_img, out_raw, out_corrected, out_json])


if __name__ == "__main__":
    demo.launch()