Update app.py
Browse files
app.py
CHANGED
|
@@ -256,6 +256,184 @@
|
|
| 256 |
|
| 257 |
|
| 258 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
import gradio as gr
|
| 260 |
import numpy as np
|
| 261 |
import cv2
|
|
@@ -280,9 +458,8 @@ except Exception as e:
|
|
| 280 |
print(f"❌ DocTR Load Error: {e}")
|
| 281 |
raise e
|
| 282 |
|
| 283 |
-
# B. Load LLM (Qwen2.5-
|
| 284 |
-
#
|
| 285 |
-
# If it is too slow, change MODEL_ID to "Qwen/Qwen2.5-3B-Instruct" or "Qwen/Qwen2.5-1.5B-Instruct"
|
| 286 |
MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
|
| 287 |
|
| 288 |
try:
|
|
@@ -291,7 +468,7 @@ try:
|
|
| 291 |
llm_model = AutoModelForCausalLM.from_pretrained(
|
| 292 |
MODEL_ID,
|
| 293 |
torch_dtype="auto",
|
| 294 |
-
device_map="cpu" #
|
| 295 |
)
|
| 296 |
print(f"✅ {MODEL_ID} loaded successfully.")
|
| 297 |
except Exception as e:
|
|
@@ -300,7 +477,7 @@ except Exception as e:
|
|
| 300 |
tokenizer = None
|
| 301 |
|
| 302 |
# ------------------------------------------------------
|
| 303 |
-
# 2. Correction Logic (
|
| 304 |
# ------------------------------------------------------
|
| 305 |
def smart_correction(text):
|
| 306 |
if not text or not llm_model:
|
|
@@ -309,8 +486,13 @@ def smart_correction(text):
|
|
| 309 |
print("--- Starting AI Correction ---")
|
| 310 |
|
| 311 |
# 1. Construct the Prompt
|
| 312 |
-
# We
|
| 313 |
-
system_prompt =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
user_prompt = f"Correct the following OCR text:\n\n{text}"
|
| 315 |
|
| 316 |
messages = [
|
|
@@ -318,31 +500,37 @@ def smart_correction(text):
|
|
| 318 |
{"role": "user", "content": user_prompt}
|
| 319 |
]
|
| 320 |
|
|
|
|
| 321 |
text_input = tokenizer.apply_chat_template(
|
| 322 |
messages,
|
| 323 |
tokenize=False,
|
| 324 |
add_generation_prompt=True
|
| 325 |
)
|
| 326 |
|
|
|
|
| 327 |
model_inputs = tokenizer([text_input], return_tensors="pt").to("cpu")
|
| 328 |
|
| 329 |
# 2. Run Inference
|
| 330 |
-
#
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
|
| 347 |
# ------------------------------------------------------
|
| 348 |
# 3. Processing Pipeline
|
|
@@ -353,7 +541,7 @@ def run_ocr(input_image):
|
|
| 353 |
if input_image is None:
|
| 354 |
return None, "No image uploaded", None, None
|
| 355 |
|
| 356 |
-
#
|
| 357 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
|
| 358 |
input_image.save(tmp.name)
|
| 359 |
tmp_path = tmp.name
|
|
@@ -364,7 +552,7 @@ def run_ocr(input_image):
|
|
| 364 |
raw_text = result.render()
|
| 365 |
|
| 366 |
# 2. Run AI Correction
|
| 367 |
-
#
|
| 368 |
corrected_text = smart_correction(raw_text)
|
| 369 |
|
| 370 |
# 3. Visualization
|
|
@@ -394,9 +582,9 @@ def run_ocr(input_image):
|
|
| 394 |
# ------------------------------------------------------
|
| 395 |
# 4. Gradio Interface
|
| 396 |
# ------------------------------------------------------
|
| 397 |
-
with gr.Blocks(title="
|
| 398 |
-
gr.Markdown("## 📄
|
| 399 |
-
gr.Markdown(f"Using **DocTR** for extraction and **{MODEL_ID}** for
|
| 400 |
|
| 401 |
with gr.Row():
|
| 402 |
input_img = gr.Image(type="pil", label="Upload Document")
|
|
@@ -409,7 +597,7 @@ with gr.Blocks(title="Next-Gen OCR") as demo:
|
|
| 409 |
|
| 410 |
with gr.Row():
|
| 411 |
out_raw = gr.Textbox(label="Raw OCR Output", lines=10)
|
| 412 |
-
out_corrected = gr.Textbox(label="🤖 AI Corrected (Qwen
|
| 413 |
|
| 414 |
with gr.Row():
|
| 415 |
out_json = gr.JSON(label="JSON Data")
|
|
|
|
| 256 |
|
| 257 |
|
| 258 |
|
| 259 |
+
# import gradio as gr
|
| 260 |
+
# import numpy as np
|
| 261 |
+
# import cv2
|
| 262 |
+
# import traceback
|
| 263 |
+
# import tempfile
|
| 264 |
+
# import os
|
| 265 |
+
# import torch
|
| 266 |
+
# from doctr.io import DocumentFile
|
| 267 |
+
# from doctr.models import ocr_predictor
|
| 268 |
+
# from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 269 |
+
|
| 270 |
+
# # ------------------------------------------------------
|
| 271 |
+
# # 1. Configuration & Global Loading
|
| 272 |
+
# # ------------------------------------------------------
|
| 273 |
+
# print("⏳ Loading models...")
|
| 274 |
+
|
| 275 |
+
# # A. Load DocTR (OCR)
|
| 276 |
+
# try:
|
| 277 |
+
# ocr_model = ocr_predictor(det_arch='fast_base', reco_arch='crnn_vgg16_bn', pretrained=True)
|
| 278 |
+
# print("✅ DocTR loaded.")
|
| 279 |
+
# except Exception as e:
|
| 280 |
+
# print(f"❌ DocTR Load Error: {e}")
|
| 281 |
+
# raise e
|
| 282 |
+
|
| 283 |
+
# # B. Load LLM (Qwen2.5-7B-Instruct)
|
| 284 |
+
# # With 50GB RAM, we can load this comfortably.
|
| 285 |
+
# # If it is too slow, change MODEL_ID to "Qwen/Qwen2.5-3B-Instruct" or "Qwen/Qwen2.5-1.5B-Instruct"
|
| 286 |
+
# MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
|
| 287 |
+
|
| 288 |
+
# try:
|
| 289 |
+
# print(f"⬇️ Downloading & Loading {MODEL_ID}...")
|
| 290 |
+
# tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
| 291 |
+
# llm_model = AutoModelForCausalLM.from_pretrained(
|
| 292 |
+
# MODEL_ID,
|
| 293 |
+
# torch_dtype="auto",
|
| 294 |
+
# device_map="cpu" # Uses your 50GB System RAM
|
| 295 |
+
# )
|
| 296 |
+
# print(f"✅ {MODEL_ID} loaded successfully.")
|
| 297 |
+
# except Exception as e:
|
| 298 |
+
# print(f"❌ LLM Load Error: {e}")
|
| 299 |
+
# llm_model = None
|
| 300 |
+
# tokenizer = None
|
| 301 |
+
|
| 302 |
+
# # ------------------------------------------------------
|
| 303 |
+
# # 2. Correction Logic (The "Smart" Fix)
|
| 304 |
+
# # ------------------------------------------------------
|
| 305 |
+
# def smart_correction(text):
|
| 306 |
+
# if not text or not llm_model:
|
| 307 |
+
# return text
|
| 308 |
+
|
| 309 |
+
# print("--- Starting AI Correction ---")
|
| 310 |
+
|
| 311 |
+
# # 1. Construct the Prompt
|
| 312 |
+
# # We ask the model to act as a text editor.
|
| 313 |
+
# system_prompt = "You are a helpful assistant that corrects OCR text. Fix typos, capitalization, and grammar. Maintain the original line structure. Do not add any conversational text like 'Here is the corrected text'."
|
| 314 |
+
# user_prompt = f"Correct the following OCR text:\n\n{text}"
|
| 315 |
+
|
| 316 |
+
# messages = [
|
| 317 |
+
# {"role": "system", "content": system_prompt},
|
| 318 |
+
# {"role": "user", "content": user_prompt}
|
| 319 |
+
# ]
|
| 320 |
+
|
| 321 |
+
# text_input = tokenizer.apply_chat_template(
|
| 322 |
+
# messages,
|
| 323 |
+
# tokenize=False,
|
| 324 |
+
# add_generation_prompt=True
|
| 325 |
+
# )
|
| 326 |
+
|
| 327 |
+
# model_inputs = tokenizer([text_input], return_tensors="pt").to("cpu")
|
| 328 |
+
|
| 329 |
+
# # 2. Run Inference
|
| 330 |
+
# # max_new_tokens limits the output length to avoid infinite loops
|
| 331 |
+
# generated_ids = llm_model.generate(
|
| 332 |
+
# model_inputs.input_ids,
|
| 333 |
+
# max_new_tokens=1024,
|
| 334 |
+
# temperature=0.1, # Low temp for factual/consistent results
|
| 335 |
+
# do_sample=False # Greedy decoding is faster and more deterministic
|
| 336 |
+
# )
|
| 337 |
+
|
| 338 |
+
# # 3. Decode Output
|
| 339 |
+
# # We strip the input tokens to get only the new (corrected) text
|
| 340 |
+
# generated_ids = [
|
| 341 |
+
# output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
| 342 |
+
# ]
|
| 343 |
+
# response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
| 344 |
+
|
| 345 |
+
# return response
|
| 346 |
+
|
| 347 |
+
# # ------------------------------------------------------
|
| 348 |
+
# # 3. Processing Pipeline
|
| 349 |
+
# # ------------------------------------------------------
|
| 350 |
+
# def run_ocr(input_image):
|
| 351 |
+
# tmp_path = None
|
| 352 |
+
# try:
|
| 353 |
+
# if input_image is None:
|
| 354 |
+
# return None, "No image uploaded", None, None
|
| 355 |
+
|
| 356 |
+
# # Robust Temp File Handling
|
| 357 |
+
# with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
|
| 358 |
+
# input_image.save(tmp.name)
|
| 359 |
+
# tmp_path = tmp.name
|
| 360 |
+
|
| 361 |
+
# # 1. Run OCR
|
| 362 |
+
# doc = DocumentFile.from_images(tmp_path)
|
| 363 |
+
# result = ocr_model(doc)
|
| 364 |
+
# raw_text = result.render()
|
| 365 |
+
|
| 366 |
+
# # 2. Run AI Correction
|
| 367 |
+
# # We pass the WHOLE text block at once. Context helps the AI.
|
| 368 |
+
# corrected_text = smart_correction(raw_text)
|
| 369 |
+
|
| 370 |
+
# # 3. Visualization
|
| 371 |
+
# image_np = np.array(input_image)
|
| 372 |
+
# viz_image = image_np.copy()
|
| 373 |
+
|
| 374 |
+
# for page in result.pages:
|
| 375 |
+
# for block in page.blocks:
|
| 376 |
+
# for line in block.lines:
|
| 377 |
+
# for word in line.words:
|
| 378 |
+
# h, w = viz_image.shape[:2]
|
| 379 |
+
# (x_min, y_min), (x_max, y_max) = word.geometry
|
| 380 |
+
# x1, y1 = int(x_min * w), int(y_min * h)
|
| 381 |
+
# x2, y2 = int(x_max * w), int(y_max * h)
|
| 382 |
+
# cv2.rectangle(viz_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
| 383 |
+
|
| 384 |
+
# return viz_image, raw_text, corrected_text, result.export()
|
| 385 |
+
|
| 386 |
+
# except Exception as e:
|
| 387 |
+
# error_log = traceback.format_exc()
|
| 388 |
+
# return None, f"Error: {e}", f"Logs:\n{error_log}", {"error": str(e)}
|
| 389 |
+
|
| 390 |
+
# finally:
|
| 391 |
+
# if tmp_path and os.path.exists(tmp_path):
|
| 392 |
+
# os.remove(tmp_path)
|
| 393 |
+
|
| 394 |
+
# # ------------------------------------------------------
|
| 395 |
+
# # 4. Gradio Interface
|
| 396 |
+
# # ------------------------------------------------------
|
| 397 |
+
# with gr.Blocks(title="Next-Gen OCR") as demo:
|
| 398 |
+
# gr.Markdown("## 📄 Next-Gen AI OCR")
|
| 399 |
+
# gr.Markdown(f"Using **DocTR** for extraction and **{MODEL_ID}** for smart correction.")
|
| 400 |
+
|
| 401 |
+
# with gr.Row():
|
| 402 |
+
# input_img = gr.Image(type="pil", label="Upload Document")
|
| 403 |
+
|
| 404 |
+
# with gr.Row():
|
| 405 |
+
# btn = gr.Button("Run Extraction & Smart Correction", variant="primary")
|
| 406 |
+
|
| 407 |
+
# with gr.Row():
|
| 408 |
+
# out_img = gr.Image(label="Detections")
|
| 409 |
+
|
| 410 |
+
# with gr.Row():
|
| 411 |
+
# out_raw = gr.Textbox(label="Raw OCR Output", lines=10)
|
| 412 |
+
# out_corrected = gr.Textbox(label="🤖 AI Corrected (Qwen 7B)", lines=10)
|
| 413 |
+
|
| 414 |
+
# with gr.Row():
|
| 415 |
+
# out_json = gr.JSON(label="JSON Data")
|
| 416 |
+
|
| 417 |
+
# btn.click(fn=run_ocr, inputs=input_img, outputs=[out_img, out_raw, out_corrected, out_json])
|
| 418 |
+
|
| 419 |
+
# if __name__ == "__main__":
|
| 420 |
+
# demo.launch()
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
|
| 437 |
import gradio as gr
|
| 438 |
import numpy as np
|
| 439 |
import cv2
|
|
|
|
| 458 |
print(f"❌ DocTR Load Error: {e}")
|
| 459 |
raise e
|
| 460 |
|
| 461 |
+
# B. Load LLM (Qwen2.5-1.5B-Instruct)
|
| 462 |
+
# NOTE: MODEL_ID below selects the 1.5B variant. The 3B variant also fits easily in 18GB RAM (takes ~6GB), allowing space for OS + OCR.
|
|
|
|
| 463 |
MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
|
| 464 |
|
| 465 |
try:
|
|
|
|
| 468 |
llm_model = AutoModelForCausalLM.from_pretrained(
|
| 469 |
MODEL_ID,
|
| 470 |
torch_dtype="auto",
|
| 471 |
+
device_map="cpu" # Efficiently uses RAM
|
| 472 |
)
|
| 473 |
print(f"✅ {MODEL_ID} loaded successfully.")
|
| 474 |
except Exception as e:
|
|
|
|
| 477 |
tokenizer = None
|
| 478 |
|
| 479 |
# ------------------------------------------------------
|
| 480 |
+
# 2. Correction Logic (Context-Aware)
|
| 481 |
# ------------------------------------------------------
|
| 482 |
def smart_correction(text):
|
| 483 |
if not text or not llm_model:
|
|
|
|
| 486 |
print("--- Starting AI Correction ---")
|
| 487 |
|
| 488 |
# 1. Construct the Prompt
|
| 489 |
+
# We explicitly tell it to fix OCR errors and maintain structure.
|
| 490 |
+
system_prompt = (
|
| 491 |
+
"You are an expert OCR post-processing assistant. "
|
| 492 |
+
"Your task is to correct OCR errors, typos, and grammar in the provided text. "
|
| 493 |
+
"Maintain the original line breaks and layout strictly. "
|
| 494 |
+
"Do not add any conversational text. Output ONLY the corrected text."
|
| 495 |
+
)
|
| 496 |
user_prompt = f"Correct the following OCR text:\n\n{text}"
|
| 497 |
|
| 498 |
messages = [
|
|
|
|
| 500 |
{"role": "user", "content": user_prompt}
|
| 501 |
]
|
| 502 |
|
| 503 |
+
# Apply chat template
|
| 504 |
text_input = tokenizer.apply_chat_template(
|
| 505 |
messages,
|
| 506 |
tokenize=False,
|
| 507 |
add_generation_prompt=True
|
| 508 |
)
|
| 509 |
|
| 510 |
+
# Tokenize
|
| 511 |
model_inputs = tokenizer([text_input], return_tensors="pt").to("cpu")
|
| 512 |
|
| 513 |
# 2. Run Inference
|
| 514 |
+
# Greedy decoding (do_sample=False) is faster and prevents "creative" hallucinations; note that temperature has no effect when sampling is disabled.
|
| 515 |
+
try:
|
| 516 |
+
generated_ids = llm_model.generate(
|
| 517 |
+
model_inputs.input_ids,
|
| 518 |
+
max_new_tokens=1024,
|
| 519 |
+
temperature=0.1,
|
| 520 |
+
do_sample=False
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
# 3. Decode Output
|
| 524 |
+
# Strip input tokens to get only the new text
|
| 525 |
+
generated_ids = [
|
| 526 |
+
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
| 527 |
+
]
|
| 528 |
+
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
| 529 |
+
return response
|
| 530 |
+
|
| 531 |
+
except Exception as e:
|
| 532 |
+
print(f"Inference Error: {e}")
|
| 533 |
+
return text # Fallback to original if AI fails
|
| 534 |
|
| 535 |
# ------------------------------------------------------
|
| 536 |
# 3. Processing Pipeline
|
|
|
|
| 541 |
if input_image is None:
|
| 542 |
return None, "No image uploaded", None, None
|
| 543 |
|
| 544 |
+
# Temp file for robust loading
|
| 545 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
|
| 546 |
input_image.save(tmp.name)
|
| 547 |
tmp_path = tmp.name
|
|
|
|
| 552 |
raw_text = result.render()
|
| 553 |
|
| 554 |
# 2. Run AI Correction
|
| 555 |
+
# The configured model (see MODEL_ID) is fast enough to handle the full page context at once.
|
| 556 |
corrected_text = smart_correction(raw_text)
|
| 557 |
|
| 558 |
# 3. Visualization
|
|
|
|
| 582 |
# ------------------------------------------------------
|
| 583 |
# 4. Gradio Interface
|
| 584 |
# ------------------------------------------------------
|
| 585 |
+
with gr.Blocks(title="AI OCR with Qwen 3B") as demo:
|
| 586 |
+
gr.Markdown("## 📄 Robust AI OCR")
|
| 587 |
+
gr.Markdown(f"Using **DocTR** for text extraction and **{MODEL_ID}** for intelligent grammar correction.")
|
| 588 |
|
| 589 |
with gr.Row():
|
| 590 |
input_img = gr.Image(type="pil", label="Upload Document")
|
|
|
|
| 597 |
|
| 598 |
with gr.Row():
|
| 599 |
out_raw = gr.Textbox(label="Raw OCR Output", lines=10)
|
| 600 |
+
out_corrected = gr.Textbox(label="🤖 AI Corrected (Qwen 3B)", lines=10)
|
| 601 |
|
| 602 |
with gr.Row():
|
| 603 |
out_json = gr.JSON(label="JSON Data")
|