# HunyuanOCR / app.py
# aal-hawa
# edit
# f328abe
import gradio as gr
import torch
import tempfile
import os
from PIL import Image
from transformers import AutoProcessor, HunYuanVLForConditionalGeneration
# ============================================================
# HunyuanOCR - Image Text Extraction
# ============================================================
# Model identifier handed to both `from_pretrained` calls in `load_model`.
MODEL_ID = "tencent/HunyuanOCR"
# Lazily initialised module-level singletons: both stay None until
# `load_model()` assigns them on first use.
model = None
processor = None
def clean_repeated_substrings(text: str, *, min_length: int = 8000, min_repeats: int = 10) -> str:
    """Collapse a repeated trailing substring down to a single copy.

    The function looks at the end of *text* for a unit of two or more
    characters that is repeated back-to-back at least *min_repeats* times.
    When such a run is found, everything except one copy of the unit is cut
    from the tail. Shorter candidate units are tried first.

    Args:
        text: The string to inspect.
        min_length: Texts shorter than this are returned untouched.
        min_repeats: Minimum number of consecutive copies of the trailing
            unit that counts as a run worth trimming. Must be at least 1.

    Returns:
        *text* with the trailing run reduced to one copy, or *text* unchanged
        when it is too short or no qualifying run exists.

    Raises:
        ValueError: If *min_repeats* is smaller than 1.
    """
    if min_repeats < 1:
        raise ValueError("min_repeats must be at least 1")

    n = len(text)
    if n < min_length:
        return text

    # A unit repeated `min_repeats` times cannot be longer than n // min_repeats.
    for length in range(2, n // min_repeats + 1):
        candidate = text[-length:]
        count = 0
        i = n - length
        # Walk backwards one unit at a time; `startswith` with an offset
        # compares in place instead of allocating a new slice per step.
        while i >= 0 and text.startswith(candidate, i):
            count += 1
            i -= length
        if count >= min_repeats:
            # Drop all but one copy of the repeated unit.
            return text[:n - length * (count - 1)]
    return text
def load_model():
    """Populate the module-level ``processor`` and ``model`` on first call.

    Later calls are no-ops once ``model`` has been set. The access token is
    read from the ``HF_TOKEN`` environment variable (``None`` when unset) and
    forwarded to both ``from_pretrained`` calls. The model is cast with
    ``.float()`` and switched to evaluation mode before the function returns.
    """
    global model, processor

    if model is None:
        hf_token = os.environ.get("HF_TOKEN")

        print("Loading HunyuanOCR ...")
        processor = AutoProcessor.from_pretrained(
            MODEL_ID,
            use_fast=False,
            token=hf_token,
        )
        pretrained = HunYuanVLForConditionalGeneration.from_pretrained(
            MODEL_ID,
            attn_implementation="eager",
            device_map=None,
            low_cpu_mem_usage=True,
            token=hf_token,
        )
        model = pretrained.float()
        model.eval()
        print("HunyuanOCR loaded.")
def _cast_bfloat16_to_float32(inputs):
    """Return a plain dict of *inputs* with every bfloat16 tensor cast to float32."""
    return {
        key: value.to(torch.float32)
        if isinstance(value, torch.Tensor) and value.dtype == torch.bfloat16
        else value
        for key, value in inputs.items()
    }


def ocr_process(image):
    """Run HunyuanOCR on an uploaded image and return the decoded text.

    Args:
        image: A PIL image (the Gradio input uses ``type="pil"``), or ``None``
            when nothing has been uploaded.

    Returns:
        The model's decoded output for the image, passed through
        ``clean_repeated_substrings``, or the message
        ``"Please upload an image."`` when *image* is ``None``.

    Raises:
        Whatever the image save, processor, or ``model.generate`` call raise;
        the temporary file is removed in every case.
    """
    if image is None:
        return "Please upload an image."
    load_model()

    # Only reserve a path here and close the handle straight away, so that
    # every later step (including the save itself) runs under the `finally`
    # that removes the file.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        img_path = tmp.name

    try:
        # Normalise to RGB and store losslessly: JPEG cannot hold modes such
        # as RGBA or P, and recompression would blur text before OCR.
        image.convert("RGB").save(img_path, format="PNG")

        # The user prompt is Chinese for: "Detect and recognise the text in
        # the image, and output the text coordinates in a formatted way."
        messages = [
            {
                "role": "system",
                "content": ""
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": img_path},
                    {"type": "text", "text": "检测并识别图片中的文字,将文本坐标格式化输出。"}
                ]
            }
        ]
        text_prompt = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Close the image file handle as soon as the processor has consumed
        # it, so the later os.remove is not racing an open descriptor.
        with Image.open(img_path) as page_image:
            inputs = processor(
                text=[text_prompt], images=[page_image],
                padding=True, return_tensors="pt"
            )

        # The processor outputs bfloat16 tensors, but the model is float32.
        # BatchFeature doesn't support in-place modification well,
        # so rebuild as a plain dict with float32 tensors.
        clean_inputs = _cast_bfloat16_to_float32(inputs)

        with torch.no_grad():
            generated_ids = model.generate(
                **clean_inputs, max_new_tokens=16384, do_sample=False
            )

        # Strip the prompt tokens so only newly generated tokens are decoded.
        input_ids = clean_inputs["input_ids"]
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids, generated_ids)
        ]
        decoded = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )
        return clean_repeated_substrings(decoded[0])
    finally:
        try:
            os.remove(img_path)
        except FileNotFoundError:
            pass
# ============================================================
# Gradio Interface
# ============================================================
with gr.Blocks(title="HunyuanOCR") as demo:
    # Heading and short usage hint. The Markdown body is kept flush-left so
    # the literal's text is exactly what the page has always rendered.
    gr.Markdown("""
# HunyuanOCR - Text Extraction
Upload an image and the model will detect and extract all text with coordinates.
""")

    # Input image, result box and trigger button.
    image_input = gr.Image(label="Upload Image", type="pil")
    ocr_output = gr.Textbox(lines=15, label="Extracted Text")
    ocr_btn = gr.Button("Extract Text", variant="primary")

    # OCR runs both on an explicit button press and whenever the image changes.
    ocr_btn.click(fn=ocr_process, inputs=image_input, outputs=ocr_output)
    image_input.change(fn=ocr_process, inputs=image_input, outputs=ocr_output)

if __name__ == "__main__":
    # Listen on all network interfaces rather than only on localhost.
    demo.launch(server_name="0.0.0.0")