File size: 4,754 Bytes
98b39f6
2723f76
 
 
 
98b39f6
2723f76
 
1c0b21e
020cf60
c1a9e47
020cf60
2723f76
020cf60
 
c1a9e47
020cf60
 
c1a9e47
020cf60
 
 
 
c1a9e47
020cf60
 
2723f76
 
 
98b39f6
2723f76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1a9e47
 
2723f76
 
c1a9e47
2723f76
 
c1a9e47
 
020cf60
c1a9e47
2723f76
c1a9e47
2723f76
 
 
c1a9e47
 
 
 
 
 
2723f76
c1a9e47
 
 
 
 
2723f76
 
 
c1a9e47
 
2723f76
 
 
c1a9e47
2723f76
 
98b39f6
2723f76
 
 
 
c1a9e47
2723f76
 
 
 
 
98b39f6
2723f76
 
 
 
 
 
98b39f6
2723f76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
020cf60
 
2723f76
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq
from PIL import Image
import torch
import os

# Opt out of Gradio's usage analytics; must be set before the UI is built.
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'

def clean_repeated_substrings(text):
    """Collapse a degenerately repeated trailing substring down to one copy.

    OCR/LLM decoding can get stuck emitting the same short string over and
    over at the end of the output.  For texts of at least 8000 characters,
    scan candidate periods from 2 up to a tenth of the length; at the first
    period whose trailing copy repeats 10+ times, keep a single copy of it
    and drop the rest.  Shorter texts are returned unchanged.
    """
    total = len(text)
    if total < 8000:
        return text

    # Shortest period wins: return on the first candidate that qualifies.
    for size in range(2, total // 10 + 1):
        tail = text[-size:]
        repeats = 1  # the tail itself is the first copy
        # Walk backwards one period at a time while full copies keep matching.
        while repeats * size + size <= total and \
                text[total - (repeats + 1) * size:total - repeats * size] == tail:
            repeats += 1
        if repeats >= 10:
            # Trim all but one occurrence of the repeated tail.
            return text[:total - size * (repeats - 1)]

    return text

# Load model and processor globally
# Hugging Face model id; weights are downloaded/cached on first run.
model_name_or_path = "tencent/HunyuanOCR"
print("Loading model and processor...")

try:
    # trust_remote_code=True: the checkpoint ships a custom processor/model
    # implementation.  use_fast=False — presumably the slow tokenizer is
    # required for this checkpoint; TODO confirm against model card.
    processor = AutoProcessor.from_pretrained(model_name_or_path, use_fast=False, trust_remote_code=True)
    model = AutoModelForVision2Seq.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.bfloat16,  # halves memory vs fp32; assumes bf16-capable hardware — TODO confirm
        device_map="auto",  # let accelerate place weights across available devices
        trust_remote_code=True
    )
    print("Model loaded successfully!")
except Exception as e:
    # Log and re-raise: the app cannot serve requests without the model.
    print(f"Error loading model: {e}")
    raise

def process_image(image, prompt_text):
    """Run HunyuanOCR on an uploaded image and return the decoded text.

    Builds a chat-style prompt around the image, generates with the global
    model/processor pair, strips the prompt tokens from the output, and
    collapses pathological trailing repetitions.  Errors are reported as a
    user-facing string rather than raised, so the Gradio UI never crashes.
    """
    if image is None:
        return "Please upload an image first."

    try:
        # Gradio may hand over a numpy array; normalise to a PIL image.
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        # Fall back to the default detect-and-recognize prompt when the
        # user left the prompt box blank.
        if not prompt_text or prompt_text.strip() == "":
            prompt_text = "检测并识别图片中的文字,将文本坐标格式化输出。"

        user_turn = {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt_text},
            ],
        }
        messages = [{"role": "system", "content": ""}, user_turn]

        # Render the chat template, then tokenize text + image together.
        chat_text = processor.apply_chat_template(
            [messages], tokenize=False, add_generation_prompt=True
        )[0]
        model_inputs = processor(
            text=[chat_text],
            images=image,
            padding=True,
            return_tensors="pt",
        )

        with torch.no_grad():
            # Move inputs to wherever device_map placed the first parameters.
            target_device = next(model.parameters()).device
            model_inputs = model_inputs.to(target_device)
            generated_ids = model.generate(
                **model_inputs, max_new_tokens=16384, do_sample=False
            )

        # Some processor variants expose token ids under a different key.
        if "input_ids" in model_inputs:
            input_ids = model_inputs.input_ids
        else:
            input_ids = model_inputs.inputs

        # Drop the echoed prompt tokens, keeping only newly generated ids.
        trimmed_ids = [
            full_ids[len(prompt_ids):]
            for prompt_ids, full_ids in zip(input_ids, generated_ids)
        ]
        decoded = processor.batch_decode(
            trimmed_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        return clean_repeated_substrings(decoded[0])

    except Exception as e:
        # Surface the failure in the results box instead of crashing the UI.
        return f"Error processing image: {str(e)}"

# Create Gradio interface
# Two-column layout: inputs (image + optional prompt) on the left,
# OCR results on the right.
with gr.Blocks(title="HunyuanOCR Web App") as demo:
    gr.Markdown("# 🔍 HunyuanOCR - Text Detection & Recognition")
    gr.Markdown("Upload an image to detect and recognize text with coordinates.")
    
    with gr.Row():
        with gr.Column(scale=1):
            # type="pil" so process_image receives a PIL.Image directly.
            image_input = gr.Image(
                label="Upload Image",
                type="pil"
            )
            # Placeholder mirrors the default prompt applied when this box
            # is left empty (see process_image).
            prompt_input = gr.Textbox(
                label="Custom Prompt (Optional)",
                placeholder="检测并识别图片中的文字,将文本坐标格式化输出。",
                lines=3
            )
            process_btn = gr.Button("Process Image", variant="primary")
            
        with gr.Column(scale=1):
            # Read-only display of the OCR output (or an error message).
            output_text = gr.Textbox(
                label="OCR Results",
                lines=20,
                placeholder="Results will appear here..."
            )
    
    # Examples
    gr.Markdown("### Usage Tips:")
    gr.Markdown("""
    - Upload an image containing text
    - Optionally customize the prompt for different OCR tasks
    - Click 'Process Image' to get results
    - Default prompt detects and recognizes text with formatted coordinates
    """)
    
    # Connect button to processing function
    process_btn.click(
        fn=process_image,
        inputs=[image_input, prompt_input],
        outputs=output_text
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()  # blocks, serving the UI on Gradio's default local port