import gradio as gr
import torch
import tempfile
import os
from PIL import Image
from transformers import AutoProcessor, HunYuanVLForConditionalGeneration
# ============================================================
# HunyuanOCR - Image Text Extraction
# ============================================================
MODEL_ID = "tencent/HunyuanOCR"
model = None
processor = None
def clean_repeated_substrings(text):
    """Collapse pathological trailing repetition in very long OCR output.

    Greedy decoding can fall into a loop and emit the same short substring
    over and over at the end of the text. For strings of at least 8000
    characters, look for a suffix unit of 2..len(text)//10 characters that
    occurs 10 or more times back-to-back at the very end, and keep only a
    single copy of it. Shorter unit lengths are tried first.
    """
    total = len(text)
    # Short outputs are assumed clean; the scan below is quadratic in the
    # worst case, so only run it on suspiciously long generations.
    if total < 8000:
        return text
    for unit in range(2, total // 10 + 1):
        suffix = text[total - unit:]
        repeats = 0
        # Walk backwards one unit at a time, counting consecutive copies
        # of the suffix.
        for start in range(total - unit, -1, -unit):
            if text[start:start + unit] != suffix:
                break
            repeats += 1
        if repeats >= 10:
            # Drop all but one occurrence of the repeated suffix.
            return text[:total - unit * (repeats - 1)]
    return text
def load_model():
    """Lazily instantiate the OCR model and processor exactly once.

    Uses the module-level ``model``/``processor`` globals as a cache so
    repeated calls are cheap. The model is up-cast to float32 and put in
    eval mode (CPU inference, eager attention).
    """
    global model, processor
    if model is not None:
        # Already loaded on a previous request.
        return
    hf_token = os.getenv("HF_TOKEN")
    print("Loading HunyuanOCR ...")
    processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=False, token=hf_token)
    loaded = HunYuanVLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        attn_implementation="eager",
        device_map=None,
        low_cpu_mem_usage=True,
        token=hf_token,
    )
    # Checkpoint weights may be bfloat16; run in float32 for CPU support.
    model = loaded.float()
    model.eval()
    print("HunyuanOCR loaded.")
def ocr_process(image):
    """Run OCR on an uploaded image and return the extracted text.

    Parameters
    ----------
    image : PIL.Image.Image | None
        Image from the Gradio widget; ``None`` when the widget is cleared.

    Returns
    -------
    str
        Decoded OCR output (text with coordinates per the prompt), or a
        message asking the user to upload an image.
    """
    if image is None:
        return "Please upload an image."
    load_model()
    # Bug fix: JPEG has no alpha channel, so saving an RGBA/P-mode upload
    # (e.g. a transparent PNG) to a .jpg path raises OSError. Normalize
    # to RGB before saving.
    if image.mode != "RGB":
        image = image.convert("RGB")
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
        image.save(tmp.name)
        img_path = tmp.name
    try:
        messages = [
            {
                "role": "system",
                "content": ""
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": img_path},
                    {"type": "text", "text": "检测并识别图片中的文字,将文本坐标格式化输出。"}
                ]
            }
        ]
        text_prompt = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # Bug fix: open via a context manager so PIL's lazily held file
        # handle is closed before the temp file is removed in `finally`;
        # the processor consumes the pixel data inside the `with` block.
        with Image.open(img_path) as image_input:
            inputs = processor(
                text=[text_prompt], images=[image_input],
                padding=True, return_tensors="pt"
            )
        # The processor outputs bfloat16 tensors but the model runs in
        # float32, and BatchFeature doesn't support in-place modification
        # well — rebuild as a plain dict with up-cast tensors.
        clean_inputs = {
            k: v.to(torch.float32)
            if isinstance(v, torch.Tensor) and v.dtype == torch.bfloat16
            else v
            for k, v in inputs.items()
        }
        with torch.no_grad():
            generated_ids = model.generate(
                **clean_inputs, max_new_tokens=16384, do_sample=False
            )
        # Strip the prompt tokens so only newly generated text is decoded.
        input_ids = clean_inputs["input_ids"]
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids, generated_ids)
        ]
        # Guard against degenerate repetition loops in long generations.
        return clean_repeated_substrings(
            processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )[0]
        )
    finally:
        # Always remove the temporary JPEG, even if inference failed.
        if os.path.exists(img_path):
            os.remove(img_path)
# ============================================================
# Gradio Interface
# ============================================================
# Two ways to trigger OCR: the explicit button, and automatically on any
# change to the image widget (upload, replace, clear).
with gr.Blocks(title="HunyuanOCR") as demo:
    gr.Markdown("""
    # HunyuanOCR - Text Extraction
    Upload an image and the model will detect and extract all text with coordinates.
    """)
    uploaded_image = gr.Image(type="pil", label="Upload Image")
    extracted_text = gr.Textbox(label="Extracted Text", lines=15)
    extract_button = gr.Button("Extract Text", variant="primary")
    extract_button.click(ocr_process, uploaded_image, extracted_text)
    uploaded_image.change(ocr_process, uploaded_image, extracted_text)

if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable from outside a container.
    demo.launch(server_name="0.0.0.0")