| # import gradio as gr | |
| # import numpy as np | |
| # import cv2 | |
| # import traceback | |
| # import tempfile | |
| # import os | |
| # from PIL import Image | |
| # from doctr.io import DocumentFile | |
| # from doctr.models import ocr_predictor | |
| # # 1. Load the model globally | |
| # print("Loading DocTR model...") | |
| # try: | |
| # # Using a lighter model 'fast_base' to prevent memory crashes on free tier | |
| # # You can switch back to 'db_resnet50' if you have a GPU or more RAM | |
| # model = ocr_predictor(det_arch='fast_base', reco_arch='crnn_vgg16_bn', pretrained=True) | |
| # except Exception as e: | |
| # print(f"Model Load Error: {e}") | |
| # raise e | |
| # def run_ocr(input_image): | |
| # tmp_path = None | |
| # try: | |
| # if input_image is None: | |
| # return None, "No image uploaded", None | |
| # # 2. ROBUST FIX: Save image to a temporary file first | |
| # # This forces DocTR to read it as a file, which always works. | |
| # with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp: | |
| # input_image.save(tmp.name) | |
| # tmp_path = tmp.name | |
| # # 3. Run OCR on the temporary file path | |
| # doc = DocumentFile.from_images(tmp_path) | |
| # result = model(doc) | |
| # # 4. Visualization Prep | |
| # # Convert PIL to numpy for drawing boxes (OpenCV uses BGR, PIL uses RGB) | |
| # image_np = np.array(input_image) | |
| # viz_image = image_np.copy() | |
| # full_text = result.render() | |
| # # 5. Draw Boxes | |
| # for page in result.pages: | |
| # for block in page.blocks: | |
| # for line in block.lines: | |
| # for word in line.words: | |
| # h, w = viz_image.shape[:2] | |
| # (x_min, y_min), (x_max, y_max) = word.geometry | |
| # x1, y1 = int(x_min * w), int(y_min * h) | |
| # x2, y2 = int(x_max * w), int(y_max * h) | |
| # # Draw Green Box | |
| # cv2.rectangle(viz_image, (x1, y1), (x2, y2), (0, 255, 0), 2) | |
| # return viz_image, full_text, result.export() | |
| # except Exception as e: | |
| # error_log = traceback.format_exc() | |
| # return None, f"❌ ERROR LOG:\n\n{error_log}", {"error": str(e)} | |
| # finally: | |
| # # Cleanup the temp file | |
| # if tmp_path and os.path.exists(tmp_path): | |
| # os.remove(tmp_path) | |
| # # Gradio UI | |
| # with gr.Blocks(title="DocTR OCR Demo") as demo: | |
| # gr.Markdown("## 📄 DocTR OCR (Robust Mode)") | |
| # with gr.Row(): | |
| # input_img = gr.Image(type="pil", label="Upload Document") | |
| # with gr.Row(): | |
| # btn = gr.Button("Run OCR", variant="primary") | |
| # with gr.Row(): | |
| # out_img = gr.Image(label="Detections") | |
| # out_text = gr.Textbox(label="Extracted Text", lines=10) | |
| # out_json = gr.JSON(label="JSON Output") | |
| # btn.click(fn=run_ocr, inputs=input_img, outputs=[out_img, out_text, out_json]) | |
| # if __name__ == "__main__": | |
| # demo.launch() | |
| # import gradio as gr | |
| # import numpy as np | |
| # import cv2 | |
| # import traceback | |
| # import tempfile | |
| # import os | |
| # from PIL import Image | |
| # from doctr.io import DocumentFile | |
| # from doctr.models import ocr_predictor | |
| # from transformers import pipeline | |
| # # ------------------------------------------------------ | |
| # # 1. Load Models Globally | |
| # # ------------------------------------------------------ | |
| # print("⏳ Loading models...") | |
| # # A. Load DocTR (OCR) | |
| # try: | |
| # # 'fast_base' is lightweight for CPU | |
| # ocr_model = ocr_predictor(det_arch='fast_base', reco_arch='crnn_vgg16_bn', pretrained=True) | |
| # print("✅ DocTR loaded.") | |
| # except Exception as e: | |
| # print(f"❌ DocTR Load Error: {e}") | |
| # raise e | |
| # # B. Load Corrector (Small Language Model) | |
| # try: | |
| # # 'google/flan-t5-small' is ~250MB, well under the 1GB limit. | |
| # # We use a text2text-generation pipeline. | |
| # corrector = pipeline( | |
| # "text2text-generation", | |
| # model="google/flan-t5-small", | |
| # device=-1 # -1 forces CPU | |
| # ) | |
| # print("✅ Correction model (Flan-T5-Small) loaded.") | |
| # except Exception as e: | |
| # print(f"❌ Corrector Load Error: {e}") | |
| # corrector = None | |
| # # ------------------------------------------------------ | |
| # # 2. Correction Logic | |
| # # ------------------------------------------------------ | |
| # def smart_correction(text): | |
| # if not text or not text.strip() or corrector is None: | |
| # return text | |
| # # DocTR returns text with newlines. LLMs often prefer line-by-line or chunked input | |
| # # if the context isn't massive. For a small model, processing line-by-line is safer. | |
| # lines = text.split('\n') | |
| # corrected_lines = [] | |
| # print("--- Starting Correction ---") | |
| # for line in lines: | |
| # if len(line.strip()) < 3: # Skip empty/tiny lines | |
| # corrected_lines.append(line) | |
| # continue | |
| # try: | |
| # # Prompt engineering for Flan-T5 | |
| # prompt = f"Fix grammar and OCR errors: {line}" | |
| # # max_length ensures it doesn't ramble. | |
| # result = corrector(prompt, max_length=128) | |
| # fixed_text = result[0]['generated_text'] | |
| # # Fallback: if model returns empty, keep original | |
| # corrected_lines.append(fixed_text if fixed_text else line) | |
| # except Exception as e: | |
| # print(f"Correction failed for line '{line}': {e}") | |
| # corrected_lines.append(line) | |
| # return "\n".join(corrected_lines) | |
| # # ------------------------------------------------------ | |
| # # 3. Main Processing Function | |
| # # ------------------------------------------------------ | |
| # def run_ocr(input_image): | |
| # tmp_path = None | |
| # try: | |
| # if input_image is None: | |
| # return None, "No image uploaded", None, None | |
| # # -- Save temp file for DocTR robustness -- | |
| # with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp: | |
| # input_image.save(tmp.name) | |
| # tmp_path = tmp.name | |
| # # -- Run OCR -- | |
| # doc = DocumentFile.from_images(tmp_path) | |
| # result = ocr_model(doc) | |
| # # -- Raw Text -- | |
| # raw_text = result.render() | |
| # # -- Correction Step -- | |
| # corrected_text = smart_correction(raw_text) | |
| # # -- Visualization -- | |
| # image_np = np.array(input_image) | |
| # viz_image = image_np.copy() | |
| # for page in result.pages: | |
| # for block in page.blocks: | |
| # for line in block.lines: | |
| # for word in line.words: | |
| # h, w = viz_image.shape[:2] | |
| # (x_min, y_min), (x_max, y_max) = word.geometry | |
| # x1, y1 = int(x_min * w), int(y_min * h) | |
| # x2, y2 = int(x_max * w), int(y_max * h) | |
| # cv2.rectangle(viz_image, (x1, y1), (x2, y2), (0, 255, 0), 2) | |
| # return viz_image, raw_text, corrected_text, result.export() | |
| # except Exception as e: | |
| # error_log = traceback.format_exc() | |
| # return None, f"Error: {e}", f"Error Log:\n{error_log}", {"error": str(e)} | |
| # finally: | |
| # if tmp_path and os.path.exists(tmp_path): | |
| # os.remove(tmp_path) | |
| # # ------------------------------------------------------ | |
| # # 4. Gradio UI | |
| # # ------------------------------------------------------ | |
| # with gr.Blocks(title="DocTR OCR + Correction") as demo: | |
| # gr.Markdown("## 📄 AI OCR with Grammar Correction") | |
| # gr.Markdown("Using `DocTR` for extraction and `Flan-T5-Small` for correction.") | |
| # with gr.Row(): | |
| # input_img = gr.Image(type="pil", label="Upload Document") | |
| # with gr.Row(): | |
| # btn = gr.Button("Run Extraction & Correction", variant="primary") | |
| # with gr.Row(): | |
| # out_img = gr.Image(label="Detections") | |
| # with gr.Row(): | |
| # out_raw = gr.Textbox(label="Raw OCR Text", lines=8, placeholder="Raw output appears here...") | |
| # out_corrected = gr.Textbox(label="✨ Corrected Text", lines=8, placeholder="AI corrected output appears here...") | |
| # with gr.Row(): | |
| # out_json = gr.JSON(label="Full JSON Data") | |
| # btn.click( | |
| # fn=run_ocr, | |
| # inputs=input_img, | |
| # outputs=[out_img, out_raw, out_corrected, out_json] | |
| # ) | |
| # if __name__ == "__main__": | |
| # demo.launch() | |
| # import gradio as gr | |
| # import numpy as np | |
| # import cv2 | |
| # import traceback | |
| # import tempfile | |
| # import os | |
| # import torch | |
| # from doctr.io import DocumentFile | |
| # from doctr.models import ocr_predictor | |
| # from transformers import AutoModelForCausalLM, AutoTokenizer | |
| # # ------------------------------------------------------ | |
| # # 1. Configuration & Global Loading | |
| # # ------------------------------------------------------ | |
| # print("⏳ Loading models...") | |
| # # A. Load DocTR (OCR) | |
| # try: | |
| # ocr_model = ocr_predictor(det_arch='fast_base', reco_arch='crnn_vgg16_bn', pretrained=True) | |
| # print("✅ DocTR loaded.") | |
| # except Exception as e: | |
| # print(f"❌ DocTR Load Error: {e}") | |
| # raise e | |
| # # B. Load LLM (Qwen2.5-7B-Instruct) | |
| # # With 50GB RAM, we can load this comfortably. | |
| # # If it is too slow, change MODEL_ID to "Qwen/Qwen2.5-3B-Instruct" or "Qwen/Qwen2.5-1.5B-Instruct" | |
| # MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct" | |
| # try: | |
| # print(f"⬇️ Downloading & Loading {MODEL_ID}...") | |
| # tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) | |
| # llm_model = AutoModelForCausalLM.from_pretrained( | |
| # MODEL_ID, | |
| # torch_dtype="auto", | |
| # device_map="cpu" # Uses your 50GB System RAM | |
| # ) | |
| # print(f"✅ {MODEL_ID} loaded successfully.") | |
| # except Exception as e: | |
| # print(f"❌ LLM Load Error: {e}") | |
| # llm_model = None | |
| # tokenizer = None | |
| # # ------------------------------------------------------ | |
| # # 2. Correction Logic (The "Smart" Fix) | |
| # # ------------------------------------------------------ | |
| # def smart_correction(text): | |
| # if not text or not llm_model: | |
| # return text | |
| # print("--- Starting AI Correction ---") | |
| # # 1. Construct the Prompt | |
| # # We ask the model to act as a text editor. | |
| # system_prompt = "You are a helpful assistant that corrects OCR text. Fix typos, capitalization, and grammar. Maintain the original line structure. Do not add any conversational text like 'Here is the corrected text'." | |
| # user_prompt = f"Correct the following OCR text:\n\n{text}" | |
| # messages = [ | |
| # {"role": "system", "content": system_prompt}, | |
| # {"role": "user", "content": user_prompt} | |
| # ] | |
| # text_input = tokenizer.apply_chat_template( | |
| # messages, | |
| # tokenize=False, | |
| # add_generation_prompt=True | |
| # ) | |
| # model_inputs = tokenizer([text_input], return_tensors="pt").to("cpu") | |
| # # 2. Run Inference | |
| # # max_new_tokens limits the output length to avoid infinite loops | |
| # generated_ids = llm_model.generate( | |
| # model_inputs.input_ids, | |
| # max_new_tokens=1024, | |
| # temperature=0.1, # Low temp for factual/consistent results | |
| # do_sample=False # Greedy decoding is faster and more deterministic | |
| # ) | |
| # # 3. Decode Output | |
| # # We strip the input tokens to get only the new (corrected) text | |
| # generated_ids = [ | |
| # output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) | |
| # ] | |
| # response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] | |
| # return response | |
| # # ------------------------------------------------------ | |
| # # 3. Processing Pipeline | |
| # # ------------------------------------------------------ | |
| # def run_ocr(input_image): | |
| # tmp_path = None | |
| # try: | |
| # if input_image is None: | |
| # return None, "No image uploaded", None, None | |
| # # Robust Temp File Handling | |
| # with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp: | |
| # input_image.save(tmp.name) | |
| # tmp_path = tmp.name | |
| # # 1. Run OCR | |
| # doc = DocumentFile.from_images(tmp_path) | |
| # result = ocr_model(doc) | |
| # raw_text = result.render() | |
| # # 2. Run AI Correction | |
| # # We pass the WHOLE text block at once. Context helps the AI. | |
| # corrected_text = smart_correction(raw_text) | |
| # # 3. Visualization | |
| # image_np = np.array(input_image) | |
| # viz_image = image_np.copy() | |
| # for page in result.pages: | |
| # for block in page.blocks: | |
| # for line in block.lines: | |
| # for word in line.words: | |
| # h, w = viz_image.shape[:2] | |
| # (x_min, y_min), (x_max, y_max) = word.geometry | |
| # x1, y1 = int(x_min * w), int(y_min * h) | |
| # x2, y2 = int(x_max * w), int(y_max * h) | |
| # cv2.rectangle(viz_image, (x1, y1), (x2, y2), (0, 255, 0), 2) | |
| # return viz_image, raw_text, corrected_text, result.export() | |
| # except Exception as e: | |
| # error_log = traceback.format_exc() | |
| # return None, f"Error: {e}", f"Logs:\n{error_log}", {"error": str(e)} | |
| # finally: | |
| # if tmp_path and os.path.exists(tmp_path): | |
| # os.remove(tmp_path) | |
| # # ------------------------------------------------------ | |
| # # 4. Gradio Interface | |
| # # ------------------------------------------------------ | |
| # with gr.Blocks(title="Next-Gen OCR") as demo: | |
| # gr.Markdown("## 📄 Next-Gen AI OCR") | |
| # gr.Markdown(f"Using **DocTR** for extraction and **{MODEL_ID}** for smart correction.") | |
| # with gr.Row(): | |
| # input_img = gr.Image(type="pil", label="Upload Document") | |
| # with gr.Row(): | |
| # btn = gr.Button("Run Extraction & Smart Correction", variant="primary") | |
| # with gr.Row(): | |
| # out_img = gr.Image(label="Detections") | |
| # with gr.Row(): | |
| # out_raw = gr.Textbox(label="Raw OCR Output", lines=10) | |
| # out_corrected = gr.Textbox(label="🤖 AI Corrected (Qwen 7B)", lines=10) | |
| # with gr.Row(): | |
| # out_json = gr.JSON(label="JSON Data") | |
| # btn.click(fn=run_ocr, inputs=input_img, outputs=[out_img, out_raw, out_corrected, out_json]) | |
| # if __name__ == "__main__": | |
| # demo.launch() | |
| import gradio as gr | |
| import numpy as np | |
| import cv2 | |
| import traceback | |
| import tempfile | |
| import os | |
| import torch | |
| from doctr.io import DocumentFile | |
| from doctr.models import ocr_predictor | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| # ------------------------------------------------------ | |
| # 1. Configuration & Global Loading | |
| # ------------------------------------------------------ | |
| print("⏳ Loading models...") | |
| # A. Load DocTR (OCR) | |
| try: | |
| ocr_model = ocr_predictor(det_arch='fast_base', reco_arch='crnn_vgg16_bn', pretrained=True) | |
| print("✅ DocTR loaded.") | |
| except Exception as e: | |
| print(f"❌ DocTR Load Error: {e}") | |
| raise e | |
| # B. Load LLM (Qwen2.5-3B-Instruct) | |
| # 3B fits easily in 18GB RAM (takes ~6GB) allowing space for OS + OCR. | |
| MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct" | |
| try: | |
| print(f"⬇️ Downloading & Loading {MODEL_ID}...") | |
| tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) | |
| llm_model = AutoModelForCausalLM.from_pretrained( | |
| MODEL_ID, | |
| torch_dtype="auto", | |
| device_map="cpu" # Efficiently uses RAM | |
| ) | |
| print(f"✅ {MODEL_ID} loaded successfully.") | |
| except Exception as e: | |
| print(f"❌ LLM Load Error: {e}") | |
| llm_model = None | |
| tokenizer = None | |
| # ------------------------------------------------------ | |
| # 2. Correction Logic (Context-Aware) | |
| # ------------------------------------------------------ | |
def smart_correction(text):
    """Run the globally loaded LLM over OCR output to fix typos and grammar.

    Args:
        text: Raw OCR text (may be None or empty).

    Returns:
        The corrected text, or the input unchanged when there is nothing to
        correct, the LLM is unavailable, or inference fails.
    """
    # Short-circuit on falsy text BEFORE touching the model globals, and
    # defensively require both model and tokenizer (they are set to None
    # together when loading fails).
    if not text or llm_model is None or tokenizer is None:
        return text
    print("--- Starting AI Correction ---")

    # 1. Construct the prompt: instruct the model to act as a pure text
    # editor — no conversational filler, preserve the line structure.
    system_prompt = (
        "You are an expert OCR post-processing assistant. "
        "Your task is to correct OCR errors, typos, and grammar in the provided text. "
        "Maintain the original line breaks and layout strictly. "
        "Do not add any conversational text. Output ONLY the corrected text."
    )
    user_prompt = f"Correct the following OCR text:\n\n{text}"
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt}
    ]
    # Apply the model's chat template, appending the assistant-turn marker.
    text_input = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    # Tokenize on CPU (matches the device_map used at load time).
    model_inputs = tokenizer([text_input], return_tensors="pt").to("cpu")

    # 2. Run inference.
    try:
        # Greedy decoding (do_sample=False) is deterministic and avoids
        # "creative" rewrites. `temperature` is ignored under greedy decoding
        # (and triggers a transformers warning), so it is omitted. Passing
        # attention_mask explicitly silences the generate() warning and is
        # required for correct results with padded batches.
        generated_ids = llm_model.generate(
            model_inputs.input_ids,
            attention_mask=model_inputs.attention_mask,
            max_new_tokens=1024,
            do_sample=False
        )
        # 3. Decode: strip the prompt tokens so only newly generated text remains.
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return response
    except Exception as e:
        print(f"Inference Error: {e}")
        return text  # Fall back to the raw OCR text if inference fails.
| # ------------------------------------------------------ | |
| # 3. Processing Pipeline | |
| # ------------------------------------------------------ | |
def run_ocr(input_image):
    """Full pipeline: OCR an uploaded image, AI-correct the text, draw boxes.

    Args:
        input_image: PIL image supplied by Gradio, or None if nothing was uploaded.

    Returns:
        Tuple of (annotated image or None, raw OCR text or error message,
        corrected text or error log or None, export dict or error dict or None),
        matching the four Gradio output components.
    """
    tmp_path = None
    try:
        if input_image is None:
            return None, "No image uploaded", None, None

        # DocTR reads most reliably from a file path, so persist the upload
        # to a temp JPEG. Convert to RGB first: JPEG cannot encode RGBA or
        # palette-mode images and .save() would raise on them.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
            input_image.convert("RGB").save(tmp.name)
            tmp_path = tmp.name

        # 1. Run OCR.
        doc = DocumentFile.from_images(tmp_path)
        result = ocr_model(doc)
        raw_text = result.render()

        # 2. AI correction over the whole page at once — context helps the model.
        corrected_text = smart_correction(raw_text)

        # 3. Visualization: green box around every detected word.
        viz_image = np.array(input_image).copy()
        h, w = viz_image.shape[:2]  # loop-invariant — hoisted out of the nested loops
        for page in result.pages:
            for block in page.blocks:
                for line in block.lines:
                    for word in line.words:
                        # word.geometry holds relative coords ((x_min, y_min), (x_max, y_max)).
                        (x_min, y_min), (x_max, y_max) = word.geometry
                        top_left = (int(x_min * w), int(y_min * h))
                        bottom_right = (int(x_max * w), int(y_max * h))
                        cv2.rectangle(viz_image, top_left, bottom_right, (0, 255, 0), 2)

        return viz_image, raw_text, corrected_text, result.export()
    except Exception as e:
        error_log = traceback.format_exc()
        return None, f"Error: {e}", f"Logs:\n{error_log}", {"error": str(e)}
    finally:
        # Always remove the temp file, even when OCR or correction failed.
        if tmp_path and os.path.exists(tmp_path):
            os.remove(tmp_path)
| # ------------------------------------------------------ | |
| # 4. Gradio Interface | |
| # ------------------------------------------------------ | |
| with gr.Blocks(title="AI OCR with Qwen 3B") as demo: | |
| gr.Markdown("## 📄 Robust AI OCR") | |
| gr.Markdown(f"Using **DocTR** for text extraction and **{MODEL_ID}** for intelligent grammar correction.") | |
| with gr.Row(): | |
| input_img = gr.Image(type="pil", label="Upload Document") | |
| with gr.Row(): | |
| btn = gr.Button("Run Extraction & Smart Correction", variant="primary") | |
| with gr.Row(): | |
| out_img = gr.Image(label="Detections") | |
| with gr.Row(): | |
| out_raw = gr.Textbox(label="Raw OCR Output", lines=10) | |
| out_corrected = gr.Textbox(label="🤖 AI Corrected (Qwen 3B)", lines=10) | |
| with gr.Row(): | |
| out_json = gr.JSON(label="JSON Data") | |
| btn.click(fn=run_ocr, inputs=input_img, outputs=[out_img, out_raw, out_corrected, out_json]) | |
| if __name__ == "__main__": | |
| demo.launch() |