"""Gradio web app that runs OCR on an uploaded image with tencent/HunyuanOCR.

The model and processor are loaded once at import time and shared by every
request handled by ``process_image``.
"""

import logging
import os

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq

logger = logging.getLogger(__name__)

# Opt out of Gradio's usage analytics (this flag does not affect demos).
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'

# Prompt used when the user leaves the prompt box empty; also shown as the
# textbox placeholder so both stay in sync.
DEFAULT_PROMPT = "检测并识别图片中的文字,将文本坐标格式化输出。"

# Upper bound on the number of tokens generated per request.
MAX_NEW_TOKENS = 16384


def clean_repeated_substrings(text: str, *, min_length: int = 8000,
                              min_repeats: int = 10) -> str:
    """Collapse a substring that is repeated back-to-back at the end of *text*.

    Texts shorter than ``min_length`` characters are returned unchanged.
    Otherwise, for each candidate length starting at 2, the tail of the text
    is checked for consecutive copies of its last ``length`` characters. The
    first (shortest) candidate that occurs at least ``min_repeats`` times in
    a row is collapsed so that exactly one copy remains.

    Args:
        text: The decoded model output.
        min_length: Texts shorter than this are never modified.
        min_repeats: Minimum number of consecutive copies required before
            the tail is collapsed. Must be at least 1.

    Returns:
        ``text`` with the trailing repetition reduced to a single copy, or
        ``text`` itself when no such repetition is found.

    Raises:
        ValueError: If ``min_repeats`` is smaller than 1.
    """
    if min_repeats < 1:
        raise ValueError("min_repeats must be >= 1")

    n = len(text)
    if n < min_length:
        return text

    # A unit repeated min_repeats times cannot be longer than n // min_repeats.
    for length in range(2, n // min_repeats + 1):
        candidate = text[-length:]
        count = 0
        i = n - length

        # Walk backwards from the end in steps of `length` while the slice
        # still equals the trailing candidate.
        while i >= 0 and text[i:i + length] == candidate:
            count += 1
            i -= length

        if count >= min_repeats:
            # Drop every copy except one.
            return text[:n - length * (count - 1)]

    return text


# Load model and processor globally so they are shared across requests.
model_name_or_path = "tencent/HunyuanOCR"

print("Loading model and processor...")
try:
    processor = AutoProcessor.from_pretrained(
        model_name_or_path,
        use_fast=False,
        trust_remote_code=True
    )
    model = AutoModelForVision2Seq.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )
    print("Model loaded successfully!")
except Exception as e:
    # Report the failure, then abort start-up: the app is unusable without
    # the model.
    print(f"Error loading model: {e}")
    raise


def process_image(image, prompt_text) -> str:
    """Run OCR on *image* and return the decoded model output.

    Args:
        image: A PIL image, or any array accepted by ``Image.fromarray``.
            ``None`` means nothing was uploaded.
        prompt_text: Instruction for the model. When empty, ``None`` or
            whitespace-only, ``DEFAULT_PROMPT`` is used instead.

    Returns:
        The recognized text after ``clean_repeated_substrings``, or a
        human-readable message when no image was given or processing failed.
        Exceptions are never propagated to the caller; they are logged with
        their traceback and reported as an error string.
    """
    if image is None:
        return "Please upload an image first."

    try:
        # Convert to PIL Image if needed
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        # Use custom prompt if provided, otherwise use default
        if not prompt_text or not prompt_text.strip():
            prompt_text = DEFAULT_PROMPT

        # Prepare messages
        messages = [
            {"role": "system", "content": ""},
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt_text},
                ],
            }
        ]

        # Render the chat template for a batch of one conversation.
        text = processor.apply_chat_template(
            [messages], tokenize=False, add_generation_prompt=True
        )[0]
        inputs = processor(
            text=[text],
            images=image,
            padding=True,
            return_tensors="pt",
        )

        # Generate output
        with torch.no_grad():
            device = next(model.parameters()).device
            inputs = inputs.to(device)
            generated_ids = model.generate(
                **inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=False
            )

        # NOTE(review): the `inputs.inputs` fallback is kept from the
        # original code; confirm which processors actually need it.
        if "input_ids" in inputs:
            input_ids = inputs.input_ids
        else:
            input_ids = inputs.inputs

        # Strip the prompt tokens so only newly generated tokens are decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(input_ids, generated_ids)
        ]
        output_texts = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )

        # Clean and return result
        return clean_repeated_substrings(output_texts[0])

    except Exception as e:
        # UI boundary: keep the app responsive, but record the traceback so
        # the failure is diagnosable on the server side.
        logger.exception("OCR processing failed")
        return f"Error processing image: {e}"


# Create Gradio interface
with gr.Blocks(title="HunyuanOCR Web App") as demo:
    gr.Markdown("# 🔍 HunyuanOCR - Text Detection & Recognition")
    gr.Markdown("Upload an image to detect and recognize text with coordinates.")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                label="Upload Image",
                type="pil"
            )
            prompt_input = gr.Textbox(
                label="Custom Prompt (Optional)",
                placeholder=DEFAULT_PROMPT,
                lines=3
            )
            process_btn = gr.Button("Process Image", variant="primary")

        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="OCR Results",
                lines=20,
                placeholder="Results will appear here..."
            )

    # Usage tips, built line by line so each item renders as its own bullet.
    gr.Markdown("### Usage Tips:")
    gr.Markdown("\n".join([
        "- Upload an image containing text",
        "- Optionally customize the prompt for different OCR tasks",
        "- Click 'Process Image' to get results",
        "- Default prompt detects and recognizes text with formatted coordinates",
    ]))

    # Connect button to processing function
    process_btn.click(
        fn=process_image,
        inputs=[image_input, prompt_input],
        outputs=output_text
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()