|
|
import os |
|
|
import time |
|
|
from threading import Thread |
|
|
import gradio as gr |
|
|
import spaces |
|
|
from PIL import Image |
|
|
import torch |
|
|
from transformers import ( |
|
|
AutoProcessor, |
|
|
AutoModelForImageTextToText, |
|
|
Qwen2_5_VLForConditionalGeneration, |
|
|
TextIteratorStreamer, |
|
|
) |
|
|
# Registry of selectable OCR models.
# Maps the display name shown in the UI to a (Hugging Face repo id, model
# class) pair; the class is the one whose `from_pretrained` loads the repo.
MODEL_PATHS = {
    "Model 3 (structured handwritting)": (
        "Emeritus-21/Finetuned-full-HTR-model",
        AutoModelForImageTextToText,
    ),
}

# Default cap on generated tokens per request (also the UI slider default).
MAX_NEW_TOKENS_DEFAULT = 512

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Caches of ready-to-use processors and models, keyed by the display names in
# MODEL_PATHS. An entry exists only for models that loaded successfully.
_loaded_processors = {}
_loaded_models = {}

# NOTE(review): the leading "π" looks like a mis-decoded emoji whose remaining
# bytes were lost; it is kept verbatim because the original glyph is unknown.
print("π Preloading models into GPU/CPU memory...")

# Use half precision only on CUDA. On a CPU-only host fall back to float32,
# since float16 inference on CPU is slow and not supported by every kernel.
_load_dtype = torch.float16 if device == "cuda" else torch.float32

# Eagerly load every registered model at import time so the first request
# does not pay the download/initialisation cost.
for name, (repo_id, cls) in MODEL_PATHS.items():
    try:
        print(f"Loading {name} ...")
        processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
        model = cls.from_pretrained(
            repo_id,
            trust_remote_code=True,
            torch_dtype=_load_dtype,
        ).to(device).eval()
        _loaded_processors[name] = processor
        _loaded_models[name] = model
        print(f"✅ {name} ready.")
    except Exception as e:
        # Best effort: a model that fails to load is reported and skipped so
        # the remaining models (and the UI) can still come up.
        print(f"⚠️ Failed to load {name}: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def warmup():
    """Run a one-token, text-only generation on the first registered model.

    Returns:
        A status string: a "ready" message naming the model on success, or
        "Warmup skipped: <error>" if any step raised.
    """
    try:
        # The first key of MODEL_PATHS is treated as the default model.
        first_model_name = list(MODEL_PATHS)[0]
        warm_processor = _loaded_processors[first_model_name]
        warm_model = _loaded_models[first_model_name]

        # Minimal text-only chat turn; no image is supplied for the warmup.
        conversation = [
            {
                "role": "user",
                "content": [{"type": "text", "text": "Warmup."}],
            }
        ]
        rendered_prompt = warm_processor.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
        )
        encoded = warm_processor(
            text=[rendered_prompt],
            images=None,
            return_tensors="pt",
        )
        encoded = encoded.to(device)

        with torch.inference_mode():
            warm_model.generate(**encoded, max_new_tokens=1)
    except Exception as e:
        # Warmup is optional: report the reason instead of propagating.
        return f"Warmup skipped: {e}"
    else:
        return f"GPU warm and {first_model_name} ready."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ocr_image(image: Image.Image, model_choice: str, query: str = None,
              max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT,
              temperature: float = 0.1, top_p: float = 1.0, top_k: int = 0, repetition_penalty: float = 1.0):
    """Transcribe a handwritten image, streaming the text as it is generated.

    This is a generator: each yielded value is the full transcription
    accumulated so far, so a UI can simply replace its output on every step.

    Args:
        image: The image to transcribe. ``None`` yields a prompt to upload one.
        model_choice: Key into the preloaded model caches; an unknown key
            yields an "Invalid model" message.
        query: Optional custom prompt. Blank or ``None`` selects the built-in
            exact-transcription prompt.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Forwarded to ``generate``.
        top_p: Forwarded to ``generate``.
        top_k: Forwarded to ``generate``.
        repetition_penalty: Forwarded to ``generate``.

    Yields:
        The accumulated transcription text, or a single user-facing message
        when the image is missing or the model name is unknown.

    Raises:
        RuntimeError: If generation fails in the background thread.
    """
    if image is None:
        yield "Please upload an image."
        return

    if model_choice not in _loaded_models:
        yield f"Invalid model: {model_choice}"
        return

    processor = _loaded_processors[model_choice]
    model = _loaded_models[model_choice]

    if query and query.strip():
        prompt = query.strip()
    else:
        prompt = (
            "You are a professional Handwritten OCR system.\n"
            "TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
            "- Preserve original structure and line breaks.\n"
            "- Keep spacing, bullet points, numbering, and indentation.\n"
            "- Render tables as Markdown tables if present.\n"
            "- Do NOT autocorrect spelling or grammar.\n"
            "- Do NOT merge lines.\n"
            "Return RAW transcription only."
        )

    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt},
        ],
    }]

    chat_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[chat_prompt], images=[image], return_tensors="pt").to(device)

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)

    # NOTE(review): do_sample=False requests greedy decoding, so temperature /
    # top_p / top_k are presumably ignored by generate — confirm whether
    # sampling was intended before wiring these sliders to do_sample.
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        # UI sliders may deliver floats; these two options must be integers.
        max_new_tokens=int(max_new_tokens),
        do_sample=False,
        temperature=temperature,
        top_p=top_p,
        top_k=int(top_k),
        repetition_penalty=repetition_penalty,
    )

    worker_errors = []

    def _run_generation():
        # Runs in the background thread. If generate() raises, the streamer
        # would never receive its stop signal and the consumer loop below
        # would block forever, so record the error and end the stream.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            worker_errors.append(exc)
            streamer.end()

    thread = Thread(target=_run_generation)
    thread.start()

    buffer = ""
    for new_text in streamer:
        # Strip the chat end marker in case it survives special-token skipping.
        new_text = new_text.replace("<|im_end|>", "")
        buffer += new_text
        time.sleep(0.01)
        yield buffer

    # The stream has ended, so the worker is finished (or about to be).
    thread.join()
    if worker_errors:
        raise RuntimeError(f"Generation failed: {worker_errors[0]}") from worker_errors[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: model selector on top, then a single tab holding the prompt and
# image inputs, advanced generation options, action buttons and the output.
with gr.Blocks() as demo:
    gr.Markdown("## wilson Handwritten OCR ")

    # The first registered model is preselected.
    model_choice = gr.Radio(
        choices=list(MODEL_PATHS.keys()),
        value=list(MODEL_PATHS.keys())[0],
        label="Select OCR Model",
    )

    with gr.Tab("πΌ Image Inference"):
        query_input = gr.Textbox(
            label="Custom Prompt (optional)",
            placeholder="Leave empty for RAW structured output",
        )
        image_input = gr.Image(
            type="pil",
            label="Upload Handwritten Image",
        )

        # Generation parameters forwarded to ocr_image.
        with gr.Accordion("βοΈ Advanced Options", open=False):
            max_new_tokens = gr.Slider(
                1, 2048,
                value=MAX_NEW_TOKENS_DEFAULT,
                step=1,
                label="Max new tokens",
            )
            temperature = gr.Slider(
                0.1, 2.0,
                value=0.1,
                step=0.05,
                label="Temperature",
            )
            top_p = gr.Slider(
                0.05, 1.0,
                value=1.0,
                step=0.05,
                label="Top-p (nucleus)",
            )
            top_k = gr.Slider(
                0, 1000,
                value=0,
                step=1,
                label="Top-k",
            )
            repetition_penalty = gr.Slider(
                0.8, 2.0,
                value=1.0,
                step=0.05,
                label="Repetition penalty",
            )

        with gr.Row():
            extract_btn = gr.Button("π€ Extract RAW Text", variant="primary")
            clear_btn = gr.Button("π§Ή Clear")

        raw_output = gr.Textbox(
            label="π RAW Structured Output (exact as written)",
            lines=18,
            show_copy_button=True,
        )

        # Input order must match the positional parameters of ocr_image.
        extract_btn.click(
            fn=ocr_image,
            inputs=[
                image_input,
                model_choice,
                query_input,
                max_new_tokens,
                temperature,
                top_p,
                top_k,
                repetition_penalty,
            ],
            outputs=[raw_output],
            api_name="ocr_image",
        )

        # Reset the output text, the uploaded image and the custom prompt.
        clear_btn.click(
            fn=lambda: ("", None, ""),
            outputs=[raw_output, image_input, query_input],
        )
|
|
|
|
|
if __name__ == "__main__":
    # Enable request queuing (at most 50 pending), then start the server.
    queued_app = demo.queue(max_size=50)
    queued_app.launch(
        share=True,
        ssr_mode=False,
        show_error=True,
    )
|
|
|