got-ocr / app.py
iammraat's picture
Update app.py
70bc954 verified
# import gradio as gr
# from transformers import AutoModel, AutoTokenizer
# import torch
# import tempfile
# import os
# import time
# # ------------------------------------------------------
# # 1. Load the CPU-Patched Model
# # ------------------------------------------------------
# # This is the specific repo that fixes the "Found no NVIDIA driver" error.
# MODEL_ID = "srimanth-d/GOT_CPU"
# print(f"⏳ Loading {MODEL_ID}...")
# # Load Tokenizer
# tokenizer = AutoTokenizer.from_pretrained(
# MODEL_ID,
# trust_remote_code=True
# )
# # Load Model
# # low_cpu_mem_usage=True is safe here because this repo is patched for CPU.
# model = AutoModel.from_pretrained(
# MODEL_ID,
# trust_remote_code=True,
# low_cpu_mem_usage=True,
# device_map='cpu',
# use_safetensors=True,
# pad_token_id=tokenizer.eos_token_id
# )
# model = model.eval().float()
# print(f"✅ {MODEL_ID} Loaded! Ready for handwriting.")
# # ------------------------------------------------------
# # 2. The OCR Logic
# # ------------------------------------------------------
# def run_fast_handwriting_ocr(input_image):
# if input_image is None:
# return "No image provided."
# start_time = time.time()
# # Save temp file (Model expects a file path)
# with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
# input_image.save(tmp.name)
# img_path = tmp.name
# try:
# # OCR_TYPE='ocr' tells the model to just read text (no formatting/latex)
# # This is the fastest mode.
# res = model.chat(tokenizer, img_path, ocr_type='ocr')
# elapsed = time.time() - start_time
# return f"{res}\n\n--- ⏱️ Time taken: {elapsed:.2f}s ---"
# except Exception as e:
# return f"Error: {e}"
# finally:
# # Cleanup
# if os.path.exists(img_path):
# os.remove(img_path)
# # ------------------------------------------------------
# # 3. Gradio Interface
# # ------------------------------------------------------
# with gr.Blocks(title="Fast Handwriting OCR") as demo:
# gr.Markdown(f"## ✍️ Fast Handwriting OCR (GOT-OCR2.0)")
# gr.Markdown("A specialized ~600M param model designed to read messy text quickly on CPU.")
# with gr.Row():
# input_img = gr.Image(type="pil", label="Upload Handwritten Note")
# with gr.Row():
# btn = gr.Button("Read Handwriting", variant="primary")
# with gr.Row():
# out_text = gr.Textbox(label="Recognized Text", lines=15)
# btn.click(fn=run_fast_handwriting_ocr, inputs=input_img, outputs=out_text)
# if __name__ == "__main__":
# demo.launch()
import gradio as gr
from transformers import AutoModel, AutoTokenizer
import torch
import tempfile
import os
import time
from PIL import Image
# ------------------------------------------------------
# 1. Load the Model (CPU Optimized)
# ------------------------------------------------------
# CPU-patched GOT-OCR2.0 checkpoint; the custom modelling code shipped with
# the repo is executed, hence trust_remote_code=True below.
MODEL_ID = "srimanth-d/GOT_CPU"

print(f"⏳ Loading {MODEL_ID}...")

# The tokenizer is loaded first because its EOS id is reused as the pad id.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

_model_load_options = dict(
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    device_map='cpu',
    use_safetensors=True,
    pad_token_id=tokenizer.eos_token_id,
)
model = AutoModel.from_pretrained(MODEL_ID, **_model_load_options)

# Switch to inference mode, then cast floating-point weights to float32 for CPU.
model.eval()
model = model.float()

print("✅ Model Loaded!")
# ------------------------------------------------------
# 2. Slicing Logic (The Fix)
# ------------------------------------------------------
def process_slice(img_slice, slice_index):
    """Run OCR on a single image (or image slice) and return the text.

    The model's ``chat`` call takes a file path rather than an in-memory
    image, so the slice is written to a temporary JPEG first. The temporary
    file is always removed afterwards, whether OCR succeeds or not.

    Args:
        img_slice: PIL-style image (uses ``mode``, ``convert`` and ``save``).
        slice_index: Integer used in the temp-file suffix and in error
            messages so a failure can be traced to a specific slice.

    Returns:
        The recognized text, or an ``"[Error in slice N: ...]"`` placeholder
        if writing the slice or running the model fails. One bad slice is
        reported inline instead of aborting the whole request.
    """
    # Reserve a unique path, then close the handle right away so the image can
    # be written by name (an open NamedTemporaryFile cannot be reopened by name
    # on Windows).
    with tempfile.NamedTemporaryFile(delete=False, suffix=f"_{slice_index}.jpg") as tmp:
        slice_path = tmp.name
    try:
        # JPEG cannot store alpha/palette data (e.g. an RGBA PNG makes save()
        # raise), so normalize such modes to RGB before writing.
        if img_slice.mode not in ("RGB", "L"):
            img_slice = img_slice.convert("RGB")
        img_slice.save(slice_path)
        # ocr_type='ocr' is the plain-text mode (the fastest one).
        return model.chat(tokenizer, slice_path, ocr_type='ocr')
    except Exception as e:
        # Best-effort boundary: surface the failure in the output text.
        return f"[Error in slice {slice_index}: {e}]"
    finally:
        if os.path.exists(slice_path):
            os.remove(slice_path)
# Images taller than this many pixels are split into horizontal bands
# (1024 px is the model's native input resolution).
_SLICE_HEIGHT_THRESHOLD = 1024

# (label, top, bottom) as fractions of the image height; each band spans the
# full width. Neighbouring bands overlap by 10% of the height so that a line
# cut by one band boundary still appears whole in the adjacent band.
_SLICE_BANDS = (
    ("Top Section", 0.00, 0.40),
    ("Middle Section", 0.30, 0.70),
    ("Bottom Section", 0.60, 1.00),
)


def run_sliced_ocr(input_image):
    """OCR an uploaded image, slicing tall images into overlapping bands.

    Only the height is checked: an image taller than
    ``_SLICE_HEIGHT_THRESHOLD`` is cut into the bands in ``_SLICE_BANDS`` and
    each band is OCR'd separately, so each pass covers a smaller region and
    keeps more resolution. The per-band texts are concatenated, each under a
    ``--- [<label>] ---`` header. Any other image (including wide but short
    ones) is OCR'd in a single pass. Because the bands overlap, text near a
    band boundary may appear twice in the output.

    Args:
        input_image: PIL image from the Gradio component, or None.

    Returns:
        The recognized text followed by a total-time footer, or
        ``"No image provided."`` when ``input_image`` is None.
    """
    if input_image is None:
        return "No image provided."

    # perf_counter is monotonic, so the duration is immune to clock changes.
    start_time = time.perf_counter()
    w, h = input_image.size

    if h > _SLICE_HEIGHT_THRESHOLD:
        print(f"--- Slicing Image ({w}x{h}) ---")
        sections = []
        for i, (label, top, bottom) in enumerate(_SLICE_BANDS):
            print(f"Processing slice {i + 1}/{len(_SLICE_BANDS)}...")
            band = input_image.crop((0, int(h * top), w, int(h * bottom)))
            sections.append(f"\n--- [{label}] ---\n{process_slice(band, i)}")
        full_text = "".join(sections)
    else:
        # Small enough to process in one pass.
        print("--- Processing Full Image ---")
        full_text = process_slice(input_image, 0)

    elapsed = time.perf_counter() - start_time
    return f"{full_text}\n\n--- ⏱️ Total Time: {elapsed:.2f}s ---"
# ------------------------------------------------------
# 3. Gradio Interface
# ------------------------------------------------------
# Gradio places components in the order they are created inside each `with`
# context, so the statement order below defines the page layout.
with gr.Blocks(title="High-Res Handwriting OCR") as demo:
    gr.Markdown("## ✍️ Sliced Handwriting OCR")
    gr.Markdown("Splits the image into 3 chunks to maintain resolution for messy handwriting.")
    # Input image and text output share one row (side by side).
    with gr.Row():
        # type="pil" makes the handler receive a PIL image (None when empty).
        input_img = gr.Image(type="pil", label="Upload Document")
        out_text = gr.Textbox(label="Extracted Text", lines=20)
    btn = gr.Button("Run Sliced OCR", variant="primary")
    # On click: run the sliced OCR on the uploaded image and show its text.
    btn.click(fn=run_sliced_ocr, inputs=input_img, outputs=out_text)
# Start the server only when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()