# docTR / app.py — Hugging Face Space by iammraat
# Last update: commit 79ff71f (verified)
# import gradio as gr
# import numpy as np
# import cv2
# import traceback
# import tempfile
# import os
# from PIL import Image
# from doctr.io import DocumentFile
# from doctr.models import ocr_predictor
# # 1. Load the model globally
# print("Loading DocTR model...")
# try:
# # Using a lighter model 'fast_base' to prevent memory crashes on free tier
# # You can switch back to 'db_resnet50' if you have a GPU or more RAM
# model = ocr_predictor(det_arch='fast_base', reco_arch='crnn_vgg16_bn', pretrained=True)
# except Exception as e:
# print(f"Model Load Error: {e}")
# raise e
# def run_ocr(input_image):
# tmp_path = None
# try:
# if input_image is None:
# return None, "No image uploaded", None
# # 2. ROBUST FIX: Save image to a temporary file first
# # This forces DocTR to read it as a file, which always works.
# with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
# input_image.save(tmp.name)
# tmp_path = tmp.name
# # 3. Run OCR on the temporary file path
# doc = DocumentFile.from_images(tmp_path)
# result = model(doc)
# # 4. Visualization Prep
# # Convert PIL to numpy for drawing boxes (OpenCV uses BGR, PIL uses RGB)
# image_np = np.array(input_image)
# viz_image = image_np.copy()
# full_text = result.render()
# # 5. Draw Boxes
# for page in result.pages:
# for block in page.blocks:
# for line in block.lines:
# for word in line.words:
# h, w = viz_image.shape[:2]
# (x_min, y_min), (x_max, y_max) = word.geometry
# x1, y1 = int(x_min * w), int(y_min * h)
# x2, y2 = int(x_max * w), int(y_max * h)
# # Draw Green Box
# cv2.rectangle(viz_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
# return viz_image, full_text, result.export()
# except Exception as e:
# error_log = traceback.format_exc()
# return None, f"❌ ERROR LOG:\n\n{error_log}", {"error": str(e)}
# finally:
# # Cleanup the temp file
# if tmp_path and os.path.exists(tmp_path):
# os.remove(tmp_path)
# # Gradio UI
# with gr.Blocks(title="DocTR OCR Demo") as demo:
# gr.Markdown("## 📄 DocTR OCR (Robust Mode)")
# with gr.Row():
# input_img = gr.Image(type="pil", label="Upload Document")
# with gr.Row():
# btn = gr.Button("Run OCR", variant="primary")
# with gr.Row():
# out_img = gr.Image(label="Detections")
# out_text = gr.Textbox(label="Extracted Text", lines=10)
# out_json = gr.JSON(label="JSON Output")
# btn.click(fn=run_ocr, inputs=input_img, outputs=[out_img, out_text, out_json])
# if __name__ == "__main__":
# demo.launch()
# import gradio as gr
# import numpy as np
# import cv2
# import traceback
# import tempfile
# import os
# from PIL import Image
# from doctr.io import DocumentFile
# from doctr.models import ocr_predictor
# from transformers import pipeline
# # ------------------------------------------------------
# # 1. Load Models Globally
# # ------------------------------------------------------
# print("⏳ Loading models...")
# # A. Load DocTR (OCR)
# try:
# # 'fast_base' is lightweight for CPU
# ocr_model = ocr_predictor(det_arch='fast_base', reco_arch='crnn_vgg16_bn', pretrained=True)
# print("✅ DocTR loaded.")
# except Exception as e:
# print(f"❌ DocTR Load Error: {e}")
# raise e
# # B. Load Corrector (Small Language Model)
# try:
# # 'google/flan-t5-small' is ~250MB, well under the 1GB limit.
# # We use a text2text-generation pipeline.
# corrector = pipeline(
# "text2text-generation",
# model="google/flan-t5-small",
# device=-1 # -1 forces CPU
# )
# print("✅ Correction model (Flan-T5-Small) loaded.")
# except Exception as e:
# print(f"❌ Corrector Load Error: {e}")
# corrector = None
# # ------------------------------------------------------
# # 2. Correction Logic
# # ------------------------------------------------------
# def smart_correction(text):
# if not text or not text.strip() or corrector is None:
# return text
# # DocTR returns text with newlines. LLMs often prefer line-by-line or chunked input
# # if the context isn't massive. For a small model, processing line-by-line is safer.
# lines = text.split('\n')
# corrected_lines = []
# print("--- Starting Correction ---")
# for line in lines:
# if len(line.strip()) < 3: # Skip empty/tiny lines
# corrected_lines.append(line)
# continue
# try:
# # Prompt engineering for Flan-T5
# prompt = f"Fix grammar and OCR errors: {line}"
# # max_length ensures it doesn't ramble.
# result = corrector(prompt, max_length=128)
# fixed_text = result[0]['generated_text']
# # Fallback: if model returns empty, keep original
# corrected_lines.append(fixed_text if fixed_text else line)
# except Exception as e:
# print(f"Correction failed for line '{line}': {e}")
# corrected_lines.append(line)
# return "\n".join(corrected_lines)
# # ------------------------------------------------------
# # 3. Main Processing Function
# # ------------------------------------------------------
# def run_ocr(input_image):
# tmp_path = None
# try:
# if input_image is None:
# return None, "No image uploaded", None, None
# # -- Save temp file for DocTR robustness --
# with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
# input_image.save(tmp.name)
# tmp_path = tmp.name
# # -- Run OCR --
# doc = DocumentFile.from_images(tmp_path)
# result = ocr_model(doc)
# # -- Raw Text --
# raw_text = result.render()
# # -- Correction Step --
# corrected_text = smart_correction(raw_text)
# # -- Visualization --
# image_np = np.array(input_image)
# viz_image = image_np.copy()
# for page in result.pages:
# for block in page.blocks:
# for line in block.lines:
# for word in line.words:
# h, w = viz_image.shape[:2]
# (x_min, y_min), (x_max, y_max) = word.geometry
# x1, y1 = int(x_min * w), int(y_min * h)
# x2, y2 = int(x_max * w), int(y_max * h)
# cv2.rectangle(viz_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
# return viz_image, raw_text, corrected_text, result.export()
# except Exception as e:
# error_log = traceback.format_exc()
# return None, f"Error: {e}", f"Error Log:\n{error_log}", {"error": str(e)}
# finally:
# if tmp_path and os.path.exists(tmp_path):
# os.remove(tmp_path)
# # ------------------------------------------------------
# # 4. Gradio UI
# # ------------------------------------------------------
# with gr.Blocks(title="DocTR OCR + Correction") as demo:
# gr.Markdown("## 📄 AI OCR with Grammar Correction")
# gr.Markdown("Using `DocTR` for extraction and `Flan-T5-Small` for correction.")
# with gr.Row():
# input_img = gr.Image(type="pil", label="Upload Document")
# with gr.Row():
# btn = gr.Button("Run Extraction & Correction", variant="primary")
# with gr.Row():
# out_img = gr.Image(label="Detections")
# with gr.Row():
# out_raw = gr.Textbox(label="Raw OCR Text", lines=8, placeholder="Raw output appears here...")
# out_corrected = gr.Textbox(label="✨ Corrected Text", lines=8, placeholder="AI corrected output appears here...")
# with gr.Row():
# out_json = gr.JSON(label="Full JSON Data")
# btn.click(
# fn=run_ocr,
# inputs=input_img,
# outputs=[out_img, out_raw, out_corrected, out_json]
# )
# if __name__ == "__main__":
# demo.launch()
# import gradio as gr
# import numpy as np
# import cv2
# import traceback
# import tempfile
# import os
# import torch
# from doctr.io import DocumentFile
# from doctr.models import ocr_predictor
# from transformers import AutoModelForCausalLM, AutoTokenizer
# # ------------------------------------------------------
# # 1. Configuration & Global Loading
# # ------------------------------------------------------
# print("⏳ Loading models...")
# # A. Load DocTR (OCR)
# try:
# ocr_model = ocr_predictor(det_arch='fast_base', reco_arch='crnn_vgg16_bn', pretrained=True)
# print("✅ DocTR loaded.")
# except Exception as e:
# print(f"❌ DocTR Load Error: {e}")
# raise e
# # B. Load LLM (Qwen2.5-7B-Instruct)
# # With 50GB RAM, we can load this comfortably.
# # If it is too slow, change MODEL_ID to "Qwen/Qwen2.5-3B-Instruct" or "Qwen/Qwen2.5-1.5B-Instruct"
# MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
# try:
# print(f"⬇️ Downloading & Loading {MODEL_ID}...")
# tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# llm_model = AutoModelForCausalLM.from_pretrained(
# MODEL_ID,
# torch_dtype="auto",
# device_map="cpu" # Uses your 50GB System RAM
# )
# print(f"✅ {MODEL_ID} loaded successfully.")
# except Exception as e:
# print(f"❌ LLM Load Error: {e}")
# llm_model = None
# tokenizer = None
# # ------------------------------------------------------
# # 2. Correction Logic (The "Smart" Fix)
# # ------------------------------------------------------
# def smart_correction(text):
# if not text or not llm_model:
# return text
# print("--- Starting AI Correction ---")
# # 1. Construct the Prompt
# # We ask the model to act as a text editor.
# system_prompt = "You are a helpful assistant that corrects OCR text. Fix typos, capitalization, and grammar. Maintain the original line structure. Do not add any conversational text like 'Here is the corrected text'."
# user_prompt = f"Correct the following OCR text:\n\n{text}"
# messages = [
# {"role": "system", "content": system_prompt},
# {"role": "user", "content": user_prompt}
# ]
# text_input = tokenizer.apply_chat_template(
# messages,
# tokenize=False,
# add_generation_prompt=True
# )
# model_inputs = tokenizer([text_input], return_tensors="pt").to("cpu")
# # 2. Run Inference
# # max_new_tokens limits the output length to avoid infinite loops
# generated_ids = llm_model.generate(
# model_inputs.input_ids,
# max_new_tokens=1024,
# temperature=0.1, # Low temp for factual/consistent results
# do_sample=False # Greedy decoding is faster and more deterministic
# )
# # 3. Decode Output
# # We strip the input tokens to get only the new (corrected) text
# generated_ids = [
# output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
# ]
# response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
# return response
# # ------------------------------------------------------
# # 3. Processing Pipeline
# # ------------------------------------------------------
# def run_ocr(input_image):
# tmp_path = None
# try:
# if input_image is None:
# return None, "No image uploaded", None, None
# # Robust Temp File Handling
# with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
# input_image.save(tmp.name)
# tmp_path = tmp.name
# # 1. Run OCR
# doc = DocumentFile.from_images(tmp_path)
# result = ocr_model(doc)
# raw_text = result.render()
# # 2. Run AI Correction
# # We pass the WHOLE text block at once. Context helps the AI.
# corrected_text = smart_correction(raw_text)
# # 3. Visualization
# image_np = np.array(input_image)
# viz_image = image_np.copy()
# for page in result.pages:
# for block in page.blocks:
# for line in block.lines:
# for word in line.words:
# h, w = viz_image.shape[:2]
# (x_min, y_min), (x_max, y_max) = word.geometry
# x1, y1 = int(x_min * w), int(y_min * h)
# x2, y2 = int(x_max * w), int(y_max * h)
# cv2.rectangle(viz_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
# return viz_image, raw_text, corrected_text, result.export()
# except Exception as e:
# error_log = traceback.format_exc()
# return None, f"Error: {e}", f"Logs:\n{error_log}", {"error": str(e)}
# finally:
# if tmp_path and os.path.exists(tmp_path):
# os.remove(tmp_path)
# # ------------------------------------------------------
# # 4. Gradio Interface
# # ------------------------------------------------------
# with gr.Blocks(title="Next-Gen OCR") as demo:
# gr.Markdown("## 📄 Next-Gen AI OCR")
# gr.Markdown(f"Using **DocTR** for extraction and **{MODEL_ID}** for smart correction.")
# with gr.Row():
# input_img = gr.Image(type="pil", label="Upload Document")
# with gr.Row():
# btn = gr.Button("Run Extraction & Smart Correction", variant="primary")
# with gr.Row():
# out_img = gr.Image(label="Detections")
# with gr.Row():
# out_raw = gr.Textbox(label="Raw OCR Output", lines=10)
# out_corrected = gr.Textbox(label="🤖 AI Corrected (Qwen 7B)", lines=10)
# with gr.Row():
# out_json = gr.JSON(label="JSON Data")
# btn.click(fn=run_ocr, inputs=input_img, outputs=[out_img, out_raw, out_corrected, out_json])
# if __name__ == "__main__":
# demo.launch()
import gradio as gr
import numpy as np
import cv2
import traceback
import tempfile
import os
import torch
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from transformers import AutoModelForCausalLM, AutoTokenizer
# ------------------------------------------------------
# 1. Configuration & Global Loading
# ------------------------------------------------------
# Both models are loaded once at import time so every Gradio request
# reuses them. A DocTR load failure is fatal (the app is useless without
# OCR); an LLM load failure is tolerated — correction degrades to a
# no-op (see smart_correction's llm_model guard).
print("⏳ Loading models...")
# A. Load DocTR (OCR). 'fast_base' detector + 'crnn_vgg16_bn' recognizer
# is a lightweight CPU-friendly combination.
try:
    ocr_model = ocr_predictor(det_arch='fast_base', reco_arch='crnn_vgg16_bn', pretrained=True)
    print("✅ DocTR loaded.")
except Exception as e:
    print(f"❌ DocTR Load Error: {e}")
    raise e
# B. Load LLM (Qwen2.5-1.5B-Instruct) used for OCR post-correction.
# The 1.5B model fits comfortably in RAM alongside the OS and OCR model.
MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
try:
    print(f"⬇️ Downloading & Loading {MODEL_ID}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    llm_model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype="auto",
        device_map="cpu"  # Efficiently uses RAM
    )
    print(f"✅ {MODEL_ID} loaded successfully.")
except Exception as e:
    # Non-fatal: the app still runs, returning uncorrected OCR text.
    print(f"❌ LLM Load Error: {e}")
    llm_model = None
    tokenizer = None
# ------------------------------------------------------
# 2. Correction Logic (Context-Aware)
# ------------------------------------------------------
def smart_correction(text):
    """Correct OCR errors in *text* with the globally loaded Qwen LLM.

    Args:
        text: Raw OCR output (newline-separated lines), or None/empty.

    Returns:
        The corrected text, or the input unchanged when there is nothing
        to correct, the LLM failed to load at startup, or inference fails.
    """
    # Guard: empty/whitespace-only input, or the LLM is unavailable.
    # (The whitespace check prevents pointless — and slow — LLM calls.)
    if not text or not text.strip() or not llm_model:
        return text
    print("--- Starting AI Correction ---")
    # 1. Construct the Prompt
    # We explicitly tell it to fix OCR errors and maintain structure.
    system_prompt = (
        "You are an expert OCR post-processing assistant. "
        "Your task is to correct OCR errors, typos, and grammar in the provided text. "
        "Maintain the original line breaks and layout strictly. "
        "Do not add any conversational text. Output ONLY the corrected text."
    )
    user_prompt = f"Correct the following OCR text:\n\n{text}"
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt}
    ]
    # Apply chat template
    text_input = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    # Tokenize
    model_inputs = tokenizer([text_input], return_tensors="pt").to("cpu")
    # 2. Run Inference
    try:
        generated_ids = llm_model.generate(
            model_inputs.input_ids,
            # Pass the attention mask explicitly: omitting it makes
            # transformers warn and can misattend with padded batches.
            attention_mask=model_inputs.attention_mask,
            max_new_tokens=1024,
            # Greedy decoding (do_sample=False) is deterministic and
            # fast. 'temperature' is ignored under greedy decoding and
            # triggers a transformers warning, so it is deliberately
            # not passed.
            do_sample=False
        )
        # 3. Decode Output
        # Strip the prompt tokens so only the newly generated
        # (corrected) text remains.
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return response
    except Exception as e:
        print(f"Inference Error: {e}")
        return text  # Fallback to original if AI fails
# ------------------------------------------------------
# 3. Processing Pipeline
# ------------------------------------------------------
def run_ocr(input_image):
    """Run DocTR OCR on a PIL image, then AI-correct the extracted text.

    Args:
        input_image: A PIL.Image from the Gradio component, or None.

    Returns:
        A 4-tuple (annotated image ndarray, raw text, corrected text,
        DocTR export dict). On failure, (None, error message, error log,
        {"error": ...}); on missing input, (None, message, None, None).
    """
    tmp_path = None
    try:
        if input_image is None:
            return None, "No image uploaded", None, None
        # Save to a temp file: DocTR's file-based loader is its most
        # robust input path. JPEG cannot carry an alpha channel, so
        # RGBA/paletted uploads (e.g. transparent PNGs) must be
        # converted to RGB first or PIL's JPEG encoder raises.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
            input_image.convert("RGB").save(tmp.name)
            tmp_path = tmp.name
        # 1. Run OCR
        doc = DocumentFile.from_images(tmp_path)
        result = ocr_model(doc)
        raw_text = result.render()
        # 2. Run AI Correction on the whole page at once — full-page
        # context helps the LLM.
        corrected_text = smart_correction(raw_text)
        # 3. Visualization: draw a green box around each detected word.
        # DocTR geometries are normalized to [0, 1], so scale by the
        # image size (hoisted out of the loop — invariant per word).
        viz_image = np.array(input_image).copy()
        h, w = viz_image.shape[:2]
        for page in result.pages:
            for block in page.blocks:
                for line in block.lines:
                    for word in line.words:
                        (x_min, y_min), (x_max, y_max) = word.geometry
                        x1, y1 = int(x_min * w), int(y_min * h)
                        x2, y2 = int(x_max * w), int(y_max * h)
                        cv2.rectangle(viz_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
        return viz_image, raw_text, corrected_text, result.export()
    except Exception as e:
        error_log = traceback.format_exc()
        return None, f"Error: {e}", f"Logs:\n{error_log}", {"error": str(e)}
    finally:
        # Always remove the temp file, even when OCR or correction fails.
        if tmp_path and os.path.exists(tmp_path):
            os.remove(tmp_path)
# ------------------------------------------------------
# 4. Gradio Interface
# ------------------------------------------------------
# Layout: upload row -> action button -> annotated image ->
# side-by-side raw/corrected text -> full JSON export. The single
# button maps run_ocr's 4-tuple onto the four output components in order.
with gr.Blocks(title="AI OCR with Qwen 3B") as demo:
    gr.Markdown("## 📄 Robust AI OCR")
    gr.Markdown(f"Using **DocTR** for text extraction and **{MODEL_ID}** for intelligent grammar correction.")
    with gr.Row():
        input_img = gr.Image(type="pil", label="Upload Document")
    with gr.Row():
        btn = gr.Button("Run Extraction & Smart Correction", variant="primary")
    with gr.Row():
        out_img = gr.Image(label="Detections")
    with gr.Row():
        out_raw = gr.Textbox(label="Raw OCR Output", lines=10)
        out_corrected = gr.Textbox(label="🤖 AI Corrected (Qwen 3B)", lines=10)
    with gr.Row():
        out_json = gr.JSON(label="JSON Data")
    # Output order must match run_ocr's return tuple:
    # (viz_image, raw_text, corrected_text, export_dict).
    btn.click(fn=run_ocr, inputs=input_img, outputs=[out_img, out_raw, out_corrected, out_json])
if __name__ == "__main__":
    demo.launch()