iammraat committed on
Commit
79ff71f
·
verified ·
1 Parent(s): 6bf2b4d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +217 -29
app.py CHANGED
@@ -256,6 +256,184 @@
256
 
257
 
258
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
  import gradio as gr
260
  import numpy as np
261
  import cv2
@@ -280,9 +458,8 @@ except Exception as e:
280
  print(f"❌ DocTR Load Error: {e}")
281
  raise e
282
 
283
- # B. Load LLM (Qwen2.5-7B-Instruct)
284
- # With 50GB RAM, we can load this comfortably.
285
- # If it is too slow, change MODEL_ID to "Qwen/Qwen2.5-3B-Instruct" or "Qwen/Qwen2.5-1.5B-Instruct"
286
  MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
287
 
288
  try:
@@ -291,7 +468,7 @@ try:
291
  llm_model = AutoModelForCausalLM.from_pretrained(
292
  MODEL_ID,
293
  torch_dtype="auto",
294
- device_map="cpu" # Uses your 50GB System RAM
295
  )
296
  print(f"✅ {MODEL_ID} loaded successfully.")
297
  except Exception as e:
@@ -300,7 +477,7 @@ except Exception as e:
300
  tokenizer = None
301
 
302
  # ------------------------------------------------------
303
- # 2. Correction Logic (The "Smart" Fix)
304
  # ------------------------------------------------------
305
  def smart_correction(text):
306
  if not text or not llm_model:
@@ -309,8 +486,13 @@ def smart_correction(text):
309
  print("--- Starting AI Correction ---")
310
 
311
  # 1. Construct the Prompt
312
- # We ask the model to act as a text editor.
313
- system_prompt = "You are a helpful assistant that corrects OCR text. Fix typos, capitalization, and grammar. Maintain the original line structure. Do not add any conversational text like 'Here is the corrected text'."
 
 
 
 
 
314
  user_prompt = f"Correct the following OCR text:\n\n{text}"
315
 
316
  messages = [
@@ -318,31 +500,37 @@ def smart_correction(text):
318
  {"role": "user", "content": user_prompt}
319
  ]
320
 
 
321
  text_input = tokenizer.apply_chat_template(
322
  messages,
323
  tokenize=False,
324
  add_generation_prompt=True
325
  )
326
 
 
327
  model_inputs = tokenizer([text_input], return_tensors="pt").to("cpu")
328
 
329
  # 2. Run Inference
330
- # max_new_tokens limits the output length to avoid infinite loops
331
- generated_ids = llm_model.generate(
332
- model_inputs.input_ids,
333
- max_new_tokens=1024,
334
- temperature=0.1, # Low temp for factual/consistent results
335
- do_sample=False # Greedy decoding is faster and more deterministic
336
- )
337
-
338
- # 3. Decode Output
339
- # We strip the input tokens to get only the new (corrected) text
340
- generated_ids = [
341
- output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
342
- ]
343
- response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
344
-
345
- return response
 
 
 
 
346
 
347
  # ------------------------------------------------------
348
  # 3. Processing Pipeline
@@ -353,7 +541,7 @@ def run_ocr(input_image):
353
  if input_image is None:
354
  return None, "No image uploaded", None, None
355
 
356
- # Robust Temp File Handling
357
  with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
358
  input_image.save(tmp.name)
359
  tmp_path = tmp.name
@@ -364,7 +552,7 @@ def run_ocr(input_image):
364
  raw_text = result.render()
365
 
366
  # 2. Run AI Correction
367
- # We pass the WHOLE text block at once. Context helps the AI.
368
  corrected_text = smart_correction(raw_text)
369
 
370
  # 3. Visualization
@@ -394,9 +582,9 @@ def run_ocr(input_image):
394
  # ------------------------------------------------------
395
  # 4. Gradio Interface
396
  # ------------------------------------------------------
397
- with gr.Blocks(title="Next-Gen OCR") as demo:
398
- gr.Markdown("## 📄 Next-Gen AI OCR")
399
- gr.Markdown(f"Using **DocTR** for extraction and **{MODEL_ID}** for smart correction.")
400
 
401
  with gr.Row():
402
  input_img = gr.Image(type="pil", label="Upload Document")
@@ -409,7 +597,7 @@ with gr.Blocks(title="Next-Gen OCR") as demo:
409
 
410
  with gr.Row():
411
  out_raw = gr.Textbox(label="Raw OCR Output", lines=10)
412
- out_corrected = gr.Textbox(label="🤖 AI Corrected (Qwen 7B)", lines=10)
413
 
414
  with gr.Row():
415
  out_json = gr.JSON(label="JSON Data")
 
256
 
257
 
258
 
259
+ # import gradio as gr
260
+ # import numpy as np
261
+ # import cv2
262
+ # import traceback
263
+ # import tempfile
264
+ # import os
265
+ # import torch
266
+ # from doctr.io import DocumentFile
267
+ # from doctr.models import ocr_predictor
268
+ # from transformers import AutoModelForCausalLM, AutoTokenizer
269
+
270
+ # # ------------------------------------------------------
271
+ # # 1. Configuration & Global Loading
272
+ # # ------------------------------------------------------
273
+ # print("⏳ Loading models...")
274
+
275
+ # # A. Load DocTR (OCR)
276
+ # try:
277
+ # ocr_model = ocr_predictor(det_arch='fast_base', reco_arch='crnn_vgg16_bn', pretrained=True)
278
+ # print("✅ DocTR loaded.")
279
+ # except Exception as e:
280
+ # print(f"❌ DocTR Load Error: {e}")
281
+ # raise e
282
+
283
+ # # B. Load LLM (Qwen2.5-7B-Instruct)
284
+ # # With 50GB RAM, we can load this comfortably.
285
+ # # If it is too slow, change MODEL_ID to "Qwen/Qwen2.5-3B-Instruct" or "Qwen/Qwen2.5-1.5B-Instruct"
286
+ # MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
287
+
288
+ # try:
289
+ # print(f"⬇️ Downloading & Loading {MODEL_ID}...")
290
+ # tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
291
+ # llm_model = AutoModelForCausalLM.from_pretrained(
292
+ # MODEL_ID,
293
+ # torch_dtype="auto",
294
+ # device_map="cpu" # Uses your 50GB System RAM
295
+ # )
296
+ # print(f"✅ {MODEL_ID} loaded successfully.")
297
+ # except Exception as e:
298
+ # print(f"❌ LLM Load Error: {e}")
299
+ # llm_model = None
300
+ # tokenizer = None
301
+
302
+ # # ------------------------------------------------------
303
+ # # 2. Correction Logic (The "Smart" Fix)
304
+ # # ------------------------------------------------------
305
+ # def smart_correction(text):
306
+ # if not text or not llm_model:
307
+ # return text
308
+
309
+ # print("--- Starting AI Correction ---")
310
+
311
+ # # 1. Construct the Prompt
312
+ # # We ask the model to act as a text editor.
313
+ # system_prompt = "You are a helpful assistant that corrects OCR text. Fix typos, capitalization, and grammar. Maintain the original line structure. Do not add any conversational text like 'Here is the corrected text'."
314
+ # user_prompt = f"Correct the following OCR text:\n\n{text}"
315
+
316
+ # messages = [
317
+ # {"role": "system", "content": system_prompt},
318
+ # {"role": "user", "content": user_prompt}
319
+ # ]
320
+
321
+ # text_input = tokenizer.apply_chat_template(
322
+ # messages,
323
+ # tokenize=False,
324
+ # add_generation_prompt=True
325
+ # )
326
+
327
+ # model_inputs = tokenizer([text_input], return_tensors="pt").to("cpu")
328
+
329
+ # # 2. Run Inference
330
+ # # max_new_tokens limits the output length to avoid infinite loops
331
+ # generated_ids = llm_model.generate(
332
+ # model_inputs.input_ids,
333
+ # max_new_tokens=1024,
334
+ # temperature=0.1, # Low temp for factual/consistent results
335
+ # do_sample=False # Greedy decoding is faster and more deterministic
336
+ # )
337
+
338
+ # # 3. Decode Output
339
+ # # We strip the input tokens to get only the new (corrected) text
340
+ # generated_ids = [
341
+ # output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
342
+ # ]
343
+ # response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
344
+
345
+ # return response
346
+
347
+ # # ------------------------------------------------------
348
+ # # 3. Processing Pipeline
349
+ # # ------------------------------------------------------
350
+ # def run_ocr(input_image):
351
+ # tmp_path = None
352
+ # try:
353
+ # if input_image is None:
354
+ # return None, "No image uploaded", None, None
355
+
356
+ # # Robust Temp File Handling
357
+ # with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
358
+ # input_image.save(tmp.name)
359
+ # tmp_path = tmp.name
360
+
361
+ # # 1. Run OCR
362
+ # doc = DocumentFile.from_images(tmp_path)
363
+ # result = ocr_model(doc)
364
+ # raw_text = result.render()
365
+
366
+ # # 2. Run AI Correction
367
+ # # We pass the WHOLE text block at once. Context helps the AI.
368
+ # corrected_text = smart_correction(raw_text)
369
+
370
+ # # 3. Visualization
371
+ # image_np = np.array(input_image)
372
+ # viz_image = image_np.copy()
373
+
374
+ # for page in result.pages:
375
+ # for block in page.blocks:
376
+ # for line in block.lines:
377
+ # for word in line.words:
378
+ # h, w = viz_image.shape[:2]
379
+ # (x_min, y_min), (x_max, y_max) = word.geometry
380
+ # x1, y1 = int(x_min * w), int(y_min * h)
381
+ # x2, y2 = int(x_max * w), int(y_max * h)
382
+ # cv2.rectangle(viz_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
383
+
384
+ # return viz_image, raw_text, corrected_text, result.export()
385
+
386
+ # except Exception as e:
387
+ # error_log = traceback.format_exc()
388
+ # return None, f"Error: {e}", f"Logs:\n{error_log}", {"error": str(e)}
389
+
390
+ # finally:
391
+ # if tmp_path and os.path.exists(tmp_path):
392
+ # os.remove(tmp_path)
393
+
394
+ # # ------------------------------------------------------
395
+ # # 4. Gradio Interface
396
+ # # ------------------------------------------------------
397
+ # with gr.Blocks(title="Next-Gen OCR") as demo:
398
+ # gr.Markdown("## 📄 Next-Gen AI OCR")
399
+ # gr.Markdown(f"Using **DocTR** for extraction and **{MODEL_ID}** for smart correction.")
400
+
401
+ # with gr.Row():
402
+ # input_img = gr.Image(type="pil", label="Upload Document")
403
+
404
+ # with gr.Row():
405
+ # btn = gr.Button("Run Extraction & Smart Correction", variant="primary")
406
+
407
+ # with gr.Row():
408
+ # out_img = gr.Image(label="Detections")
409
+
410
+ # with gr.Row():
411
+ # out_raw = gr.Textbox(label="Raw OCR Output", lines=10)
412
+ # out_corrected = gr.Textbox(label="🤖 AI Corrected (Qwen 7B)", lines=10)
413
+
414
+ # with gr.Row():
415
+ # out_json = gr.JSON(label="JSON Data")
416
+
417
+ # btn.click(fn=run_ocr, inputs=input_img, outputs=[out_img, out_raw, out_corrected, out_json])
418
+
419
+ # if __name__ == "__main__":
420
+ # demo.launch()
421
+
422
+
423
+
424
+
425
+
426
+
427
+
428
+
429
+
430
+
431
+
432
+
433
+
434
+
435
+
436
+
437
  import gradio as gr
438
  import numpy as np
439
  import cv2
 
458
  print(f"❌ DocTR Load Error: {e}")
459
  raise e
460
 
461
+ # B. Load LLM (Qwen2.5-1.5B-Instruct)
462
+ # 1.5B is lightweight, leaving plenty of RAM for the OS + the OCR model.
 
463
  MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
464
 
465
  try:
 
468
  llm_model = AutoModelForCausalLM.from_pretrained(
469
  MODEL_ID,
470
  torch_dtype="auto",
471
+ device_map="cpu" # Efficiently uses RAM
472
  )
473
  print(f"✅ {MODEL_ID} loaded successfully.")
474
  except Exception as e:
 
477
  tokenizer = None
478
 
479
  # ------------------------------------------------------
480
+ # 2. Correction Logic (Context-Aware)
481
  # ------------------------------------------------------
482
  def smart_correction(text):
483
  if not text or not llm_model:
 
486
  print("--- Starting AI Correction ---")
487
 
488
  # 1. Construct the Prompt
489
+ # We explicitly tell it to fix OCR errors and maintain structure.
490
+ system_prompt = (
491
+ "You are an expert OCR post-processing assistant. "
492
+ "Your task is to correct OCR errors, typos, and grammar in the provided text. "
493
+ "Maintain the original line breaks and layout strictly. "
494
+ "Do not add any conversational text. Output ONLY the corrected text."
495
+ )
496
  user_prompt = f"Correct the following OCR text:\n\n{text}"
497
 
498
  messages = [
 
500
  {"role": "user", "content": user_prompt}
501
  ]
502
 
503
+ # Apply chat template
504
  text_input = tokenizer.apply_chat_template(
505
  messages,
506
  tokenize=False,
507
  add_generation_prompt=True
508
  )
509
 
510
+ # Tokenize
511
  model_inputs = tokenizer([text_input], return_tensors="pt").to("cpu")
512
 
513
  # 2. Run Inference
514
+ # Greedy decoding (do_sample=False) is faster and prevents "creative" hallucinations.
515
+ try:
516
+ generated_ids = llm_model.generate(
517
+ model_inputs.input_ids,
518
+ max_new_tokens=1024,
519
+ temperature=0.1,
520
+ do_sample=False
521
+ )
522
+
523
+ # 3. Decode Output
524
+ # Strip input tokens to get only the new text
525
+ generated_ids = [
526
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
527
+ ]
528
+ response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
529
+ return response
530
+
531
+ except Exception as e:
532
+ print(f"Inference Error: {e}")
533
+ return text # Fallback to original if AI fails
534
 
535
  # ------------------------------------------------------
536
  # 3. Processing Pipeline
 
541
  if input_image is None:
542
  return None, "No image uploaded", None, None
543
 
544
+ # Temp file for robust loading
545
  with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
546
  input_image.save(tmp.name)
547
  tmp_path = tmp.name
 
552
  raw_text = result.render()
553
 
554
  # 2. Run AI Correction
555
+ # The 1.5B model is fast enough to handle the full page context at once.
556
  corrected_text = smart_correction(raw_text)
557
 
558
  # 3. Visualization
 
582
  # ------------------------------------------------------
583
  # 4. Gradio Interface
584
  # ------------------------------------------------------
585
+ with gr.Blocks(title="AI OCR with Qwen 1.5B") as demo:
586
+ gr.Markdown("## 📄 Robust AI OCR")
587
+ gr.Markdown(f"Using **DocTR** for text extraction and **{MODEL_ID}** for intelligent grammar correction.")
588
 
589
  with gr.Row():
590
  input_img = gr.Image(type="pil", label="Upload Document")
 
597
 
598
  with gr.Row():
599
  out_raw = gr.Textbox(label="Raw OCR Output", lines=10)
600
+ out_corrected = gr.Textbox(label="🤖 AI Corrected (Qwen 1.5B)", lines=10)
601
 
602
  with gr.Row():
603
  out_json = gr.JSON(label="JSON Data")