# NOTE(review): Three complete earlier iterations of this app (plain DocTR,
# DocTR + Flan-T5-Small line-by-line correction, and DocTR + Qwen whole-page
# correction) were kept here as commented-out code. They have been removed —
# recover them from version control if needed; dead code should not live in
# the source file.
import gradio as gr
import numpy as np
import cv2
import traceback
import tempfile
import os
import torch
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from transformers import AutoModelForCausalLM, AutoTokenizer

# ------------------------------------------------------
# 1. Configuration & Global Loading
# ------------------------------------------------------
print("⏳ Loading models...")

# A. Load DocTR (OCR). 'fast_base' detection keeps the memory footprint small
# enough for CPU-only hosts; swap in 'db_resnet50' if a GPU is available.
try:
    ocr_model = ocr_predictor(det_arch='fast_base', reco_arch='crnn_vgg16_bn', pretrained=True)
    print("✅ DocTR loaded.")
except Exception as e:
    print(f"❌ DocTR Load Error: {e}")
    raise

# B. Load the LLM used for OCR post-correction.
# Qwen2.5-1.5B-Instruct needs only a few GB of RAM on CPU; substitute a larger
# Qwen2.5 checkpoint (3B/7B) if more memory is available.
MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
# Short display name (e.g. "Qwen2.5-1.5B-Instruct") used in UI labels so the
# interface always agrees with the model actually loaded.
MODEL_NAME = MODEL_ID.split("/")[-1]
try:
    print(f"⬇️ Downloading & Loading {MODEL_ID}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    llm_model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype="auto",
        device_map="cpu",  # run entirely from system RAM; no GPU required
    )
    print(f"✅ {MODEL_ID} loaded successfully.")
except Exception as e:
    # Non-fatal: the app still works without the corrector —
    # smart_correction() falls back to returning the raw OCR text.
    print(f"❌ LLM Load Error: {e}")
    llm_model = None
    tokenizer = None


# ------------------------------------------------------
# 2. Correction Logic (Context-Aware)
# ------------------------------------------------------
def smart_correction(text):
    """Correct OCR errors in *text* with the LLM.

    The whole page is sent in one prompt so the model has full context.
    Returns *text* unchanged when it is empty, when the LLM failed to load,
    or when inference raises.

    Args:
        text: Raw OCR output (may contain newlines).

    Returns:
        The corrected text, or the original *text* on any failure.
    """
    if not text or llm_model is None:
        return text

    print("--- Starting AI Correction ---")

    # Instruct the model to behave as a pure text editor: no chit-chat,
    # preserve the original line structure.
    system_prompt = (
        "You are an expert OCR post-processing assistant. "
        "Your task is to correct OCR errors, typos, and grammar in the provided text. "
        "Maintain the original line breaks and layout strictly. "
        "Do not add any conversational text. Output ONLY the corrected text."
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Correct the following OCR text:\n\n{text}"},
    ]

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([prompt], return_tensors="pt").to("cpu")

    try:
        # Greedy decoding (do_sample=False) is deterministic and avoids
        # "creative" rewrites; max_new_tokens caps runaway generations.
        # (temperature is intentionally NOT set — it is ignored under greedy
        # decoding and only produces a transformers warning.)
        generated_ids = llm_model.generate(
            model_inputs.input_ids,
            attention_mask=model_inputs.attention_mask,
            max_new_tokens=1024,
            do_sample=False,
        )
        # Strip the prompt tokens so only the newly generated text remains.
        generated_ids = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    except Exception as e:
        print(f"Inference Error: {e}")
        return text  # fall back to the uncorrected OCR output


# ------------------------------------------------------
# 3. Processing Pipeline
# ------------------------------------------------------
def run_ocr(input_image):
    """Run OCR on *input_image*, then LLM correction, and draw detections.

    Args:
        input_image: PIL.Image uploaded via Gradio, or None.

    Returns:
        Tuple of (annotated image ndarray, raw OCR text, corrected text,
        DocTR export dict). On error, returns (None, short error, traceback
        log, {"error": ...}) so every Gradio output component is filled.
    """
    tmp_path = None
    try:
        if input_image is None:
            return None, "No image uploaded", None, None

        # Write to a temp file: DocTR's file-path loader is its most robust
        # entry point. Convert to RGB first — JPEG cannot store an alpha
        # channel, so RGBA/P-mode uploads would otherwise crash the save.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
            input_image.convert("RGB").save(tmp.name)
            tmp_path = tmp.name

        # 1. OCR
        doc = DocumentFile.from_images(tmp_path)
        result = ocr_model(doc)
        raw_text = result.render()

        # 2. AI correction on the full page at once (context helps the LLM).
        corrected_text = smart_correction(raw_text)

        # 3. Visualization: draw a green box around every detected word.
        # DocTR geometries are relative (0..1), so scale by the image size
        # (hoisted out of the loops — it is invariant per image).
        viz_image = np.array(input_image)
        h, w = viz_image.shape[:2]
        for page in result.pages:
            for block in page.blocks:
                for line in block.lines:
                    for word in line.words:
                        (x_min, y_min), (x_max, y_max) = word.geometry
                        top_left = (int(x_min * w), int(y_min * h))
                        bottom_right = (int(x_max * w), int(y_max * h))
                        cv2.rectangle(viz_image, top_left, bottom_right, (0, 255, 0), 2)

        return viz_image, raw_text, corrected_text, result.export()
    except Exception as e:
        error_log = traceback.format_exc()
        return None, f"Error: {e}", f"Logs:\n{error_log}", {"error": str(e)}
    finally:
        # Always clean up the temp file, even on error.
        if tmp_path and os.path.exists(tmp_path):
            os.remove(tmp_path)


# ------------------------------------------------------
# 4. Gradio Interface
# ------------------------------------------------------
with gr.Blocks(title=f"AI OCR with {MODEL_NAME}") as demo:
    gr.Markdown("## 📄 Robust AI OCR")
    gr.Markdown(
        f"Using **DocTR** for text extraction and **{MODEL_ID}** "
        "for intelligent grammar correction."
    )
    with gr.Row():
        input_img = gr.Image(type="pil", label="Upload Document")
    with gr.Row():
        btn = gr.Button("Run Extraction & Smart Correction", variant="primary")
    with gr.Row():
        out_img = gr.Image(label="Detections")
    with gr.Row():
        out_raw = gr.Textbox(label="Raw OCR Output", lines=10)
        # Label derived from MODEL_ID so it never drifts from the loaded model.
        out_corrected = gr.Textbox(label=f"🤖 AI Corrected ({MODEL_NAME})", lines=10)
    with gr.Row():
        out_json = gr.JSON(label="JSON Data")

    btn.click(fn=run_ocr, inputs=input_img, outputs=[out_img, out_raw, out_corrected, out_json])

if __name__ == "__main__":
    demo.launch()