| # import gradio as gr | |
| # from transformers import AutoModel, AutoTokenizer | |
| # import torch | |
| # import tempfile | |
| # import os | |
| # import time | |
| # # ------------------------------------------------------ | |
| # # 1. Load the CPU-Patched Model | |
| # # ------------------------------------------------------ | |
| # # This is the specific repo that fixes the "Found no NVIDIA driver" error. | |
| # MODEL_ID = "srimanth-d/GOT_CPU" | |
| # print(f"⏳ Loading {MODEL_ID}...") | |
| # # Load Tokenizer | |
| # tokenizer = AutoTokenizer.from_pretrained( | |
| # MODEL_ID, | |
| # trust_remote_code=True | |
| # ) | |
| # # Load Model | |
| # # low_cpu_mem_usage=True is safe here because this repo is patched for CPU. | |
| # model = AutoModel.from_pretrained( | |
| # MODEL_ID, | |
| # trust_remote_code=True, | |
| # low_cpu_mem_usage=True, | |
| # device_map='cpu', | |
| # use_safetensors=True, | |
| # pad_token_id=tokenizer.eos_token_id | |
| # ) | |
| # model = model.eval().float() | |
| # print(f"✅ {MODEL_ID} Loaded! Ready for handwriting.") | |
| # # ------------------------------------------------------ | |
| # # 2. The OCR Logic | |
| # # ------------------------------------------------------ | |
| # def run_fast_handwriting_ocr(input_image): | |
| # if input_image is None: | |
| # return "No image provided." | |
| # start_time = time.time() | |
| # # Save temp file (Model expects a file path) | |
| # with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp: | |
| # input_image.save(tmp.name) | |
| # img_path = tmp.name | |
| # try: | |
| # # OCR_TYPE='ocr' tells the model to just read text (no formatting/latex) | |
| # # This is the fastest mode. | |
| # res = model.chat(tokenizer, img_path, ocr_type='ocr') | |
| # elapsed = time.time() - start_time | |
| # return f"{res}\n\n--- ⏱️ Time taken: {elapsed:.2f}s ---" | |
| # except Exception as e: | |
| # return f"Error: {e}" | |
| # finally: | |
| # # Cleanup | |
| # if os.path.exists(img_path): | |
| # os.remove(img_path) | |
| # # ------------------------------------------------------ | |
| # # 3. Gradio Interface | |
| # # ------------------------------------------------------ | |
| # with gr.Blocks(title="Fast Handwriting OCR") as demo: | |
| # gr.Markdown(f"## ✍️ Fast Handwriting OCR (GOT-OCR2.0)") | |
| # gr.Markdown("A specialized ~600M param model designed to read messy text quickly on CPU.") | |
| # with gr.Row(): | |
| # input_img = gr.Image(type="pil", label="Upload Handwritten Note") | |
| # with gr.Row(): | |
| # btn = gr.Button("Read Handwriting", variant="primary") | |
| # with gr.Row(): | |
| # out_text = gr.Textbox(label="Recognized Text", lines=15) | |
| # btn.click(fn=run_fast_handwriting_ocr, inputs=input_img, outputs=out_text) | |
| # if __name__ == "__main__": | |
| # demo.launch() | |
| import gradio as gr | |
| from transformers import AutoModel, AutoTokenizer | |
| import torch | |
| import tempfile | |
| import os | |
| import time | |
| from PIL import Image | |
# ------------------------------------------------------
# 1. Load the Model (CPU Optimized)
# ------------------------------------------------------
# CPU-patched GOT-OCR2.0 checkpoint — presumably chosen to avoid the
# "Found no NVIDIA driver" error of the upstream repo (see retired script
# above); verify against the model card.
MODEL_ID = "srimanth-d/GOT_CPU"
print(f"⏳ Loading {MODEL_ID}...")
# trust_remote_code=True: the GOT repo ships its own modeling/tokenizer code.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModel.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    device_map='cpu',                    # force CPU placement; no GPU assumed
    use_safetensors=True,
    pad_token_id=tokenizer.eos_token_id  # reuse EOS as the padding token
)
# Inference mode; cast all weights to float32 for CPU execution.
model = model.eval().float()
print(f"✅ Model Loaded!")
| # ------------------------------------------------------ | |
| # 2. Slicing Logic (The Fix) | |
| # ------------------------------------------------------ | |
def process_slice(img_slice, slice_index):
    """OCR a single PIL image slice, round-tripping it through a temp file.

    The model's ``chat`` API takes a file path rather than an in-memory
    image, so the slice is written to disk first and always removed again
    once inference finishes (or fails).
    """
    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=f"_{slice_index}.jpg")
    with tmp_file:
        img_slice.save(tmp_file.name)
    slice_path = tmp_file.name
    try:
        # ocr_type='ocr' requests plain text only — the model's fastest mode.
        return model.chat(tokenizer, slice_path, ocr_type='ocr')
    except Exception as e:
        # Best-effort: report the failure inline instead of aborting the page.
        return f"[Error in slice {slice_index}: {e}]"
    finally:
        if os.path.exists(slice_path):
            os.remove(slice_path)
def run_sliced_ocr(input_image):
    """OCR an uploaded page, slicing tall images into three overlapping bands.

    Images taller than 1024 px (roughly the model's native resolution) are
    cropped into top / middle / bottom bands of 40% height each, with 10%
    overlap so text falling on a cut line is still seen by a neighbouring
    band. Shorter images are processed in a single pass.
    """
    if input_image is None:
        return "No image provided."

    started = time.time()
    width, height = input_image.size

    if height > 1024:
        # Tall page: three overlapping horizontal bands.
        print(f"--- Slicing Image ({width}x{height}) ---")
        bands = [
            input_image.crop((0, 0, width, int(height * 0.40))),
            input_image.crop((0, int(height * 0.30), width, int(height * 0.70))),
            input_image.crop((0, int(height * 0.60), width, height)),
        ]
        texts = []
        for idx, band in enumerate(bands):
            print(f"Processing slice {idx+1}/3...")
            texts.append(process_slice(band, idx))
        labels = ("[Top Section]", "[Middle Section]", "[Bottom Section]")
        full_text = "".join(
            f"\n--- {label} ---\n{text}" for label, text in zip(labels, texts)
        )
    else:
        # Short image: a single OCR pass keeps full resolution anyway.
        print("--- Processing Full Image ---")
        full_text = process_slice(input_image, 0)

    elapsed = time.time() - started
    return f"{full_text}\n\n--- ⏱️ Total Time: {elapsed:.2f}s ---"
# ------------------------------------------------------
# 3. Gradio Interface
# ------------------------------------------------------
with gr.Blocks(title="High-Res Handwriting OCR") as demo:
    gr.Markdown("## ✍️ Sliced Handwriting OCR")
    gr.Markdown("Splits the image into 3 chunks to maintain resolution for messy handwriting.")
    with gr.Row():
        # Input image and result textbox sit side by side in one row.
        page_input = gr.Image(type="pil", label="Upload Document")
        text_output = gr.Textbox(label="Extracted Text", lines=20)
    run_button = gr.Button("Run Sliced OCR", variant="primary")
    run_button.click(fn=run_sliced_ocr, inputs=page_input, outputs=text_output)

if __name__ == "__main__":
    demo.launch()