# testOcr1 / app.py
# Last commit by BlackSpire: "Update app.py" (2723f76, verified)
import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq
from PIL import Image
import torch
import os
# Opt out of Gradio's usage-analytics reporting.
# NOTE(review): this runs after `import gradio` above. Gradio presumably reads
# the variable when a Blocks app is created, so this still takes effect, but
# setting it before the import would be safer -- confirm for the pinned version.
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
def clean_repeated_substrings(text, *, min_length=8000, min_repeats=10):
    """Collapse a degenerate repeated tail in long model output.

    Looks for the shortest suffix ``s`` of ``text`` (between 2 and
    ``len(text) // 10`` characters long) that occurs at least ``min_repeats``
    times back-to-back at the very end of ``text``. If one is found, the run
    is reduced to a single copy of ``s``.

    Args:
        text: The string to clean.
        min_length: Texts shorter than this are returned unchanged, so the
            scan only runs on long outputs. Defaults to 8000.
        min_repeats: Minimum number of consecutive copies of the suffix
            (counting the final copy) needed before the tail is collapsed.
            Defaults to 10.

    Returns:
        ``text`` with the repeated tail reduced to one copy. ``text`` is
        returned unchanged if it is shorter than ``min_length`` or if no
        suffix repeats often enough.

    Example:
        >>> clean_repeated_substrings("q" + "ab" * 20, min_length=0)
        'qab'
    """
    n = len(text)
    if n < min_length:
        return text
    # Try the shortest candidate period first so the smallest repeating unit wins.
    for length in range(2, n // 10 + 1):
        candidate = text[-length:]
        # Walk backwards one period at a time, counting identical copies.
        # The first comparison (at n - length) is the candidate itself,
        # so count is always at least 1.
        count = 0
        i = n - length
        while i >= 0 and text[i:i + length] == candidate:
            count += 1
            i -= length
        if count >= min_repeats:
            # Cut off count - 1 copies, keeping exactly one copy of the unit.
            return text[:n - length * (count - 1)]
    return text
# Load the processor and model once, at import time; every request reuses them.
model_name_or_path = "tencent/HunyuanOCR"


def _load_ocr_components(repo_id):
    """Fetch the HunyuanOCR processor and model identified by ``repo_id``.

    The model is loaded in bfloat16 with ``device_map="auto"``.
    ``trust_remote_code=True`` lets Transformers run Python code shipped in
    the model repository.

    Returns:
        A ``(processor, model)`` tuple.
    """
    ocr_processor = AutoProcessor.from_pretrained(
        repo_id,
        use_fast=False,
        trust_remote_code=True,
    )
    ocr_model = AutoModelForVision2Seq.from_pretrained(
        repo_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    return ocr_processor, ocr_model


print("Loading model and processor...")
try:
    processor, model = _load_ocr_components(model_name_or_path)
    print("Model loaded successfully!")
except Exception as load_error:
    # Report the reason on the console, then abort startup by re-raising:
    # the app has nothing to serve without the model.
    print(f"Error loading model: {load_error}")
    raise
# Default OCR instruction (Chinese). In English: "Detect and recognize the
# text in the image, and output the text coordinates in a formatted way."
_DEFAULT_PROMPT = "检测并识别图片中的文字,将文本坐标格式化输出。"


def process_image(image, prompt_text):
    """Run HunyuanOCR on one image and return the decoded text.

    Args:
        image: The uploaded image. A ``PIL.Image.Image`` is used as-is. Any
            other value is passed to ``Image.fromarray`` (presumably a NumPy
            array -- confirm for non-Gradio callers).
        prompt_text: Instruction sent along with the image. When ``None``,
            empty, or whitespace-only, the module's default prompt is used.

    Returns:
        On success, the model's output after ``clean_repeated_substrings``.
        Otherwise, a human-readable message (no image supplied, or an error
        occurred). Errors are returned as a string rather than raised, so the
        Gradio output textbox can display them.
    """
    if image is None:
        return "Please upload an image first."

    try:
        # The Gradio input uses type="pil"; other inputs are converted.
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        # Blank or whitespace-only prompts fall back to the default instruction.
        if not prompt_text or not prompt_text.strip():
            prompt_text = _DEFAULT_PROMPT

        # A single conversation: an empty system turn, then a user turn
        # carrying the image and the instruction.
        messages = [
            {"role": "system", "content": ""},
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt_text},
                ],
            },
        ]

        # apply_chat_template is given a batch of one conversation here, so
        # take the single rendered prompt string from the result.
        text = processor.apply_chat_template(
            [messages], tokenize=False, add_generation_prompt=True
        )[0]
        inputs = processor(
            text=[text],
            images=image,
            padding=True,
            return_tensors="pt",
        )

        # Put the inputs on the device of the model's first parameter
        # (placement was chosen by device_map="auto" at load time).
        device = next(model.parameters()).device
        inputs = inputs.to(device)

        # Inference only: disable autograd. do_sample=False -> no sampling.
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs, max_new_tokens=16384, do_sample=False
            )

        # Each generated sequence is assumed to begin with its prompt tokens
        # (the slicing below relies on that); drop them so only new text is
        # decoded. NOTE(review): the `inputs.inputs` fallback is presumably for
        # processor variants without an "input_ids" key -- confirm it is needed.
        input_ids = inputs.input_ids if "input_ids" in inputs else inputs.inputs
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids, generated_ids)
        ]
        output_texts = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        # Collapse any runaway repeated tail before returning the result.
        return clean_repeated_substrings(output_texts[0])
    except Exception as e:
        # UI boundary: report the failure in the output textbox instead of
        # failing the request, but print the full traceback so the cause is
        # not lost from the server log.
        import traceback

        traceback.print_exc()
        return f"Error processing image: {e}"
# Create Gradio interface: inputs (image + optional prompt) in the left
# column, OCR output in the right column.
with gr.Blocks(title="HunyuanOCR Web App") as demo:
    gr.Markdown("# 🔍 HunyuanOCR - Text Detection & Recognition")
    gr.Markdown("Upload an image to detect and recognize text with coordinates.")
    with gr.Row():
        # Left column: inputs.
        with gr.Column(scale=1):
            # type="pil" makes Gradio pass the upload to the callback as a PIL image.
            image_input = gr.Image(
                label="Upload Image",
                type="pil"
            )
            # Optional. When left blank, process_image falls back to its
            # default prompt, which the placeholder shows. In English: "Detect
            # and recognize the text in the image, and output the text
            # coordinates in a formatted way."
            prompt_input = gr.Textbox(
                label="Custom Prompt (Optional)",
                placeholder="检测并识别图片中的文字,将文本坐标格式化输出。",
                lines=3
            )
            process_btn = gr.Button("Process Image", variant="primary")
        # Right column: OCR output.
        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="OCR Results",
                lines=20,
                placeholder="Results will appear here..."
            )
    # Usage tips (static Markdown text; no runnable examples are registered).
    gr.Markdown("### Usage Tips:")
    gr.Markdown("""
- Upload an image containing text
- Optionally customize the prompt for different OCR tasks
- Click 'Process Image' to get results
- Default prompt detects and recognizes text with formatted coordinates
""")
    # Clicking the button calls process_image(image, prompt) and writes the
    # returned string into the results textbox.
    process_btn.click(
        fn=process_image,
        inputs=[image_input, prompt_input],
        outputs=output_text
    )
def main():
    """Start the Gradio server for the HunyuanOCR demo."""
    demo.launch()


# Launch only when run as a script, not when imported.
if __name__ == "__main__":
    main()