"""Gradio app that OCRs handwritten documents with a CPU build of GOT-OCR2.0.

Images taller than ``SLICE_HEIGHT_THRESHOLD`` pixels are cut into three
overlapping horizontal bands that are recognised one at a time; the UI text
describes the goal as maintaining resolution for messy handwriting. Smaller
images are recognised in a single pass.
"""

import contextlib
import os
import tempfile
import time
from typing import Optional

import gradio as gr
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

# ------------------------------------------------------
# 1. Load the Model (CPU Optimized)
# ------------------------------------------------------
MODEL_ID = "srimanth-d/GOT_CPU"

# Images taller than this many pixels are sliced. The original author's note
# gives 1024 as the model's native resolution.
SLICE_HEIGHT_THRESHOLD = 1024

# (label, top edge, bottom edge) for each band, as fractions of image height.
# Neighbouring bands overlap by 10% so that a line of text sitting on a cut
# is fully contained in at least one band.
SLICE_BANDS = (
    ("Top", 0.0, 0.40),
    ("Middle", 0.30, 0.70),
    ("Bottom", 0.60, 1.0),
)

# Pillow image modes that the JPEG encoder can write without conversion.
_JPEG_WRITABLE_MODES = frozenset({"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"})

print(f"⏳ Loading {MODEL_ID}...")

# NOTE: trust_remote_code=True executes Python code shipped in the model
# repository; only use it with a repository you trust.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

model = AutoModel.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    device_map='cpu',
    use_safetensors=True,
    pad_token_id=tokenizer.eos_token_id,
)
# Inference only: evaluation mode, weights cast to float32.
model = model.eval().float()
print("✅ Model Loaded!")


# ------------------------------------------------------
# 2. Slicing Logic (The Fix)
# ------------------------------------------------------
def process_slice(img_slice: Image.Image, slice_index: int) -> str:
    """Run OCR on one image (or image slice) and return the recognised text.

    The model's ``chat`` call is given a file path, so the image is written
    to a temporary JPEG that is always removed afterwards.

    Args:
        img_slice: The PIL image to recognise.
        slice_index: Index used in the temp-file suffix and in error text.

    Returns:
        The value returned by ``model.chat``, or a bracketed error message
        if saving the image or running the model raised an exception.
    """
    # Reserve a unique path, then close the handle before writing to it.
    suffix = f"_{slice_index}.jpg"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        slice_path = tmp.name

    try:
        # JPEG cannot store alpha or palette images (e.g. RGBA, P, LA);
        # convert those so the save does not fail.
        if img_slice.mode not in _JPEG_WRITABLE_MODES:
            img_slice = img_slice.convert("RGB")
        img_slice.save(slice_path)

        # ocr_type='ocr' is the plain-text mode, noted by the author as the
        # fastest one.
        return model.chat(tokenizer, slice_path, ocr_type='ocr')
    except Exception as e:
        # Deliberate best-effort: one failing slice must not abort the page.
        return f"[Error in slice {slice_index}: {e}]"
    finally:
        with contextlib.suppress(FileNotFoundError):
            os.remove(slice_path)


def _crop_bands(image: Image.Image) -> list:
    """Return ``(label, cropped_image)`` pairs for every band in SLICE_BANDS."""
    width, height = image.size
    return [
        (label, image.crop((0, int(height * top), width, int(height * bottom))))
        for label, top, bottom in SLICE_BANDS
    ]


def run_sliced_ocr(input_image: Optional[Image.Image]) -> str:
    """OCR an uploaded image, slicing it into bands first when it is tall.

    Args:
        input_image: The PIL image supplied by the Gradio image component,
            or ``None`` when nothing was uploaded.

    Returns:
        The recognised text followed by a timing footer. For sliced images
        each band's text is preceded by a ``--- [<Label> Section] ---``
        header. Returns ``"No image provided."`` when ``input_image`` is
        ``None``.
    """
    if input_image is None:
        return "No image provided."

    start_time = time.time()
    w, h = input_image.size

    # The decision depends on height alone: any image taller than the
    # threshold is sliced, whatever its width.
    if h > SLICE_HEIGHT_THRESHOLD:
        print(f"--- Slicing Image ({w}x{h}) ---")
        bands = _crop_bands(input_image)
        total = len(bands)

        sections = []
        for i, (label, band) in enumerate(bands):
            print(f"Processing slice {i+1}/{total}...")
            text = process_slice(band, i)
            sections.append(f"\n--- [{label} Section] ---\n{text}")
        full_text = "".join(sections)
    else:
        # Small image: a single pass is enough.
        print("--- Processing Full Image ---")
        full_text = process_slice(input_image, 0)

    elapsed = time.time() - start_time
    return f"{full_text}\n\n--- ⏱️ Total Time: {elapsed:.2f}s ---"


# ------------------------------------------------------
# 3. Gradio Interface
# ------------------------------------------------------
with gr.Blocks(title="High-Res Handwriting OCR") as demo:
    gr.Markdown("## ✍️ Sliced Handwriting OCR")
    gr.Markdown("Splits the image into 3 chunks to maintain resolution for messy handwriting.")

    # NOTE(review): the source's indentation was lost; the image and the
    # textbox are assumed to sit side by side in one row, with the button
    # below it. Confirm against the intended layout.
    with gr.Row():
        input_img = gr.Image(type="pil", label="Upload Document")
        out_text = gr.Textbox(label="Extracted Text", lines=20)

    btn = gr.Button("Run Sliced OCR", variant="primary")
    btn.click(fn=run_sliced_ocr, inputs=input_img, outputs=out_text)

if __name__ == "__main__":
    demo.launch()