import gradio as gr
import torch
import tempfile
import os
from PIL import Image
from transformers import AutoProcessor, HunYuanVLForConditionalGeneration

# ============================================================
# HunyuanOCR - Image Text Extraction
# ============================================================
MODEL_ID = "tencent/HunyuanOCR"
model = None
processor = None

def clean_repeated_substrings(text):
    """Trim degenerate repetition at the end of generated text.

    Greedy decoding on long pages can collapse into one short substring
    repeated indefinitely; if the text ends with 10+ back-to-back copies
    of the same substring, keep a single copy and drop the rest.
    """
    n = len(text)
    if n < 8000:  # only long outputs are worth scanning
        return text
    for length in range(2, n // 10 + 1):
        candidate = text[-length:]
        count = 0
        i = n - length
        while i >= 0 and text[i:i + length] == candidate:
            count += 1
            i -= length
        if count >= 10:
            return text[:n - length * (count - 1)]
    return text
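
# A quick illustration of the heuristic (hypothetical input; the length guard
# means it only activates on texts of 8,000+ characters):
#
#   clean_repeated_substrings("header " + "ab" * 5000)
#   # -> "header ab"  (4,999 trailing copies of "ab" dropped, one kept)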

def load_model():
    """Lazily load the model and processor on first request (CPU, float32)."""
    global model, processor
    if model is not None:
        return
    token = os.getenv("HF_TOKEN", None)
    print("Loading HunyuanOCR ...")
    processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=False, token=token)
    model = HunYuanVLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        attn_implementation="eager",
        device_map=None,
        low_cpu_mem_usage=True,
        token=token,
    ).float()
    model.eval()
    print("HunyuanOCR loaded.")

def ocr_process(image):
    """Run HunyuanOCR on a PIL image and return the detected text."""
    if image is None:
        return "Please upload an image."

    load_model()

    # The chat template references images by file path, so save the upload to
    # a temp file first. Convert to RGB: JPEG has no alpha channel, so RGBA
    # uploads (e.g. transparent PNGs) would otherwise fail to save.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
        image.convert("RGB").save(tmp.name)
        img_path = tmp.name

    try:
        messages = [
            {
                "role": "system",
                "content": ""
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": img_path},
                    {"type": "text", "text": "检测并识别图片中的文字,将文本坐标格式化输出。"}
                ]
            }
        ]

        text_prompt = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_input = Image.open(img_path)
        inputs = processor(
            text=[text_prompt], images=[image_input],
            padding=True, return_tensors="pt"
        )

        # The processor outputs bfloat16 tensors, but model is float32.
        # BatchFeature doesn't support in-place modification well,
        # so rebuild as a plain dict with float32 tensors.
        clean_inputs = {}
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor) and v.dtype == torch.bfloat16:
                clean_inputs[k] = v.to(torch.float32)
            else:
                clean_inputs[k] = v

        with torch.no_grad():
            generated_ids = model.generate(**clean_inputs, max_new_tokens=16384, do_sample=False)

        # Strip the prompt tokens so only the newly generated text is decoded.
        input_ids = clean_inputs["input_ids"]
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids, generated_ids)
        ]
        output_text = clean_repeated_substrings(
            processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )[0]
        )

        return output_text
    finally:
        if os.path.exists(img_path):
            os.remove(img_path)

# ============================================================
# Gradio Interface
# ============================================================
with gr.Blocks(title="HunyuanOCR") as demo:
    gr.Markdown("""
    # HunyuanOCR - Text Extraction
    Upload an image and the model will detect and extract all text with coordinates.
    """)

    image_input = gr.Image(type="pil", label="Upload Image")
    ocr_output = gr.Textbox(label="Extracted Text", lines=15)
    ocr_btn = gr.Button("Extract Text", variant="primary")

    # Run OCR on button click, and also automatically when a new image is set.
    ocr_btn.click(ocr_process, image_input, ocr_output)
    image_input.change(ocr_process, image_input, ocr_output)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0")
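
# Once the app is running, it can also be called programmatically. A sketch
# using gradio_client (the endpoint name is assumed from the handler function
# name; confirm with client.view_api()):
#
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860/")
#   text = client.predict(handle_file("page.jpg"), api_name="/ocr_process")
#   print(text)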