import os
import tempfile

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, HunYuanVLForConditionalGeneration

# ============================================================
# HunyuanOCR - Image Text Extraction
# ============================================================

MODEL_ID = "tencent/HunyuanOCR"

model = None
processor = None


def clean_repeated_substrings(text):
    """Trim degenerate trailing repetition from very long generations.

    If the text ends with ten or more consecutive copies of the same
    substring, keep only one copy; otherwise return the text unchanged.
    Short outputs (< 8000 characters) are returned as-is.
    """
    n = len(text)
    if n < 8000:
        return text
    for length in range(2, n // 10 + 1):
        candidate = text[-length:]
        count = 0
        i = n - length
        while i >= 0 and text[i:i + length] == candidate:
            count += 1
            i -= length
        if count >= 10:
            # Drop all but one copy of the repeated tail.
            return text[:n - length * (count - 1)]
    return text


def load_model():
    """Lazily load the HunyuanOCR model and processor (CPU, float32)."""
    global model, processor
    if model is not None:
        return
    token = os.getenv("HF_TOKEN", None)
    print("Loading HunyuanOCR ...")
    processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=False, token=token)
    model = HunYuanVLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        attn_implementation="eager",
        device_map=None,
        low_cpu_mem_usage=True,
        token=token,
    ).float()  # cast to float32 for CPU inference
    model.eval()
    print("HunyuanOCR loaded.")


def ocr_process(image):
    """Run OCR on a PIL image and return the detected text with coordinates."""
    if image is None:
        return "Please upload an image."
    load_model()

    # Write the image to a temporary JPEG; convert to RGB first so that
    # RGBA/palette images (e.g. PNG uploads) can be saved as JPEG.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
        image.convert("RGB").save(tmp.name, format="JPEG")
        img_path = tmp.name

    try:
        messages = [
            {"role": "system", "content": ""},
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": img_path},
                    # Prompt (Chinese): "Detect and recognize the text in the
                    # image, and output the text coordinates in a formatted way."
                    {"type": "text", "text": "检测并识别图片中的文字,将文本坐标格式化输出。"},
                ],
            },
        ]
        text_prompt = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_input = Image.open(img_path)
        inputs = processor(
            text=[text_prompt],
            images=[image_input],
            padding=True,
            return_tensors="pt",
        )

        # The processor outputs bfloat16 tensors, but the model is float32.
        # BatchFeature doesn't support in-place modification well, so rebuild
        # the inputs as a plain dict with float32 tensors.
        clean_inputs = {
            k: (
                v.to(torch.float32)
                if isinstance(v, torch.Tensor) and v.dtype == torch.bfloat16
                else v
            )
            for k, v in inputs.items()
        }

        with torch.no_grad():
            generated_ids = model.generate(
                **clean_inputs, max_new_tokens=16384, do_sample=False
            )

        # Strip the prompt tokens so only the newly generated text is decoded.
        input_ids = clean_inputs["input_ids"]
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids, generated_ids)
        ]
        output_text = clean_repeated_substrings(
            processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )[0]
        )
        return output_text
    finally:
        if os.path.exists(img_path):
            os.remove(img_path)


# ============================================================
# Gradio Interface
# ============================================================

with gr.Blocks(title="HunyuanOCR") as demo:
    gr.Markdown(
        """
# HunyuanOCR - Text Extraction

Upload an image and the model will detect and extract all text with coordinates.
"""
    )
    image_input = gr.Image(type="pil", label="Upload Image")
    ocr_output = gr.Textbox(label="Extracted Text", lines=15)
    ocr_btn = gr.Button("Extract Text", variant="primary")

    # OCR runs when the button is clicked and also whenever a new image is uploaded.
    ocr_btn.click(ocr_process, image_input, ocr_output)
    image_input.change(ocr_process, image_input, ocr_output)


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0")
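
# ------------------------------------------------------------
# Optional variant (not used above): loading the model on a GPU.
# A minimal sketch, assuming a CUDA device is available; the app above
# deliberately keeps the model in float32 on the CPU.
#
#   model = HunYuanVLForConditionalGeneration.from_pretrained(
#       MODEL_ID,
#       torch_dtype=torch.bfloat16,   # keep the checkpoint's native precision
#       device_map="auto",            # place weights on the available GPU(s)
#       token=token,
#   )
#
# With a GPU-resident bfloat16 model, the bfloat16 -> float32 conversion of
# the processor outputs in ocr_process() would no longer be needed; the
# inputs would instead be moved to the model's device before generate().
# ------------------------------------------------------------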