import os
import time
from threading import Thread
import gradio as gr
import spaces
from PIL import Image
import torch
from transformers import (
    AutoProcessor,
    AutoModelForImageTextToText,
    Qwen2_5_VLForConditionalGeneration,
    TextIteratorStreamer,
)
MODEL_PATHS = {
"Model 3 (structured handwritting)": (
"Emeritus-21/Finetuned-full-HTR-model",
AutoModelForImageTextToText,
),
}
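# Additional checkpoints can be registered by mapping a display label to a
# (repo_id, model_class) pair; the label and repo id below are hypothetical
# placeholders, e.g.:
# MODEL_PATHS["Model 4 (print)"] = ("your-org/your-htr-model", Qwen2_5_VLForConditionalGeneration)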
MAX_NEW_TOKENS_DEFAULT = 512
device = "cuda" if torch.cuda.is_available() else "cpu"
# ---------------------------
# Preload models at startup
# ---------------------------
_loaded_processors = {}
_loaded_models = {}
print("π Preloading models into GPU/CPU memory...")
for name, (repo_id, cls) in MODEL_PATHS.items():
try:
print(f"Loading {name} ...")
processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
model = cls.from_pretrained(
repo_id,
trust_remote_code=True,
torch_dtype=torch.float16
).to(device).eval()
_loaded_processors[name] = processor
_loaded_models[name] = model
print(f"β
{name} ready.")
except Exception as e:
print(f"β οΈ Failed to load {name}: {e}")
# ---------------------------
# Warmup (GPU)
# ---------------------------
# @spaces.GPU  # re-enable when running on ZeroGPU hardware on Hugging Face Spaces
def warmup():
    try:
        default_model_choice = list(MODEL_PATHS.keys())[0]
        processor = _loaded_processors[default_model_choice]
        model = _loaded_models[default_model_choice]
        messages = [{"role": "user", "content": [{"type": "text", "text": "Warmup."}]}]
        chat_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=[chat_prompt], images=None, return_tensors="pt").to(device)
        with torch.inference_mode():
            # a single-token generate is enough to trigger kernel/weight initialization
            _ = model.generate(**inputs, max_new_tokens=1)
        return f"GPU warm and {default_model_choice} ready."
    except Exception as e:
        return f"Warmup skipped: {e}"
# ---------------------------
# OCR Function (RAW ONLY)
# ---------------------------
# @spaces.GPU  # re-enable when running on ZeroGPU hardware on Hugging Face Spaces
def ocr_image(image: Image.Image, model_choice: str, query: str = None,
              max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT,
              temperature: float = 0.1, top_p: float = 1.0, top_k: int = 0,
              repetition_penalty: float = 1.0):
    if image is None:
        yield "Please upload an image."
        return
    if model_choice not in _loaded_models:
        yield f"Invalid model: {model_choice}"
        return
    processor = _loaded_processors[model_choice]
    model = _loaded_models[model_choice]
    if query and query.strip():
        prompt = query.strip()
    else:
        prompt = (
            "You are a professional Handwritten OCR system.\n"
            "TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
            "- Preserve original structure and line breaks.\n"
            "- Keep spacing, bullet points, numbering, and indentation.\n"
            "- Render tables as Markdown tables if present.\n"
            "- Do NOT autocorrect spelling or grammar.\n"
            "- Do NOT merge lines.\n"
            "Return RAW transcription only."
        )
    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt},
        ],
    }]
    chat_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[chat_prompt], images=[image], return_tensors="pt").to(device)
    # the processor delegates decode() to its tokenizer, so it can back the streamer directly
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=False,  # greedy decoding; the sampling knobs below only apply if do_sample=True
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
    # run generate() in a background thread so partial text can be yielded as it streams in
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        new_text = new_text.replace("<|im_end|>", "")
        buffer += new_text
        time.sleep(0.01)
        yield buffer
    thread.join()
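# Minimal direct-call sketch (outside Gradio), assuming a local "page.png"
# exists; ocr_image is a generator, so drain it to get the final text:
# last = ""
# for last in ocr_image(Image.open("page.png"), list(MODEL_PATHS.keys())[0]):
#     pass
# print(last)  # full transcription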
# ---------------------------
# Gradio Interface
# ---------------------------
with gr.Blocks() as demo:
gr.Markdown("## wilson Handwritten OCR ")
    model_choice = gr.Radio(
        choices=list(MODEL_PATHS.keys()),
        value=list(MODEL_PATHS.keys())[0],
        label="Select OCR Model",
    )
with gr.Tab("πΌ Image Inference"):
query_input = gr.Textbox(label="Custom Prompt (optional)", placeholder="Leave empty for RAW structured output")
image_input = gr.Image(type="pil", label="Upload Handwritten Image")
with gr.Accordion("βοΈ Advanced Options", open=False):
max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
top_p = gr.Slider(0.05, 1.0, value=1.0, step=0.05, label="Top-p (nucleus)")
top_k = gr.Slider(0, 1000, value=0, step=1, label="Top-k")
repetition_penalty = gr.Slider(0.8, 2.0, value=1.0, step=0.05, label="Repetition penalty")
with gr.Row():
extract_btn = gr.Button("π€ Extract RAW Text", variant="primary")
clear_btn = gr.Button("π§Ή Clear")
raw_output = gr.Textbox(label="π RAW Structured Output (exact as written)", lines=18, show_copy_button=True)
        extract_btn.click(
            fn=ocr_image,
            inputs=[image_input, model_choice, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
            outputs=[raw_output],
            api_name="ocr_image",  # names the endpoint so API/gradio_client callers can target /ocr_image
        )
        clear_btn.click(
            fn=lambda: ("", None, ""),
            outputs=[raw_output, image_input, query_input],
        )
if __name__ == "__main__":
    demo.queue(max_size=50).launch(share=True, ssr_mode=False, show_error=True)
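# Remote-call sketch via gradio_client (pip install gradio_client); the URL is
# a placeholder for wherever this app is served, and the positional arguments
# mirror the `inputs=` list wired to extract_btn above:
# from gradio_client import Client, handle_file
# client = Client("http://127.0.0.1:7860/")
# result = client.predict(
#     handle_file("page.png"),              # image_input
#     "Model 3 (structured handwriting)",   # model_choice
#     "",                                   # query_input (empty -> default RAW prompt)
#     512, 0.1, 1.0, 0, 1.0,                # max_new_tokens, temperature, top_p, top_k, repetition_penalty
#     api_name="/ocr_image",
# )
# print(result)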