import gc
import os
import time
from argparse import ArgumentParser

import gradio as gr
import spaces
import torch
from transformers import AutoProcessor, HunYuanVLForConditionalGeneration

from qwen_vl_utils import process_vision_info

# -------------------------------------------------------------------
# Global config
# -------------------------------------------------------------------
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"


def _get_args():
    parser = ArgumentParser()
    parser.add_argument(
        "-c",
        "--checkpoint-path",
        type=str,
        default="tencent/HunyuanOCR",
        help="Checkpoint name or path",
    )
    parser.add_argument("--cpu-only", action="store_true")
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--inbrowser", action="store_true")
    return parser.parse_args()


def _load_model_processor(args):
    print("[INFO] Loading model")
    print("[INFO] CUDA available:", torch.cuda.is_available())

    # Honor --cpu-only even when CUDA is present.
    device_map = "cpu" if args.cpu_only else "auto"

    model = HunYuanVLForConditionalGeneration.from_pretrained(
        args.checkpoint_path,
        attn_implementation="eager",
        torch_dtype=torch.bfloat16,
        device_map=device_map,
    )
    if hasattr(model, "gradient_checkpointing_disable"):
        model.gradient_checkpointing_disable()
        print("[INFO] Gradient checkpointing disabled")
    model.eval()

    processor = AutoProcessor.from_pretrained(
        args.checkpoint_path, use_fast=False, trust_remote_code=True
    )
    print("[INFO] Model device:", next(model.parameters()).device)
    return model, processor


def _parse_text(text: str) -> str:
    """Normalize model output to a plain string."""
    if not isinstance(text, str):
        text = str(text)
    # NOTE: the original replace() calls stripped special-token strings that
    # were lost in extraction; strip surrounding whitespace as a conservative
    # default instead of the no-op replace("", "").
    return text.strip()


def clean_repeated_substrings(text: str) -> str:
    """Truncate degenerate repetition at the end of long outputs.

    If the text ends with the same short substring repeated 10+ times
    (a common greedy-decoding failure mode), keep only one copy of it.
    """
    n = len(text)
    if n < 2000:
        return text
    for length in range(2, n // 10 + 1):
        candidate = text[-length:]
        count = 0
        i = n - length
        # Walk backwards, counting consecutive copies of the candidate.
        while i >= 0 and text[i : i + length] == candidate:
            count += 1
            i -= length
        if count >= 10:
            # Drop all but one copy of the repeated tail.
            return text[: n - length * (count - 1)]
    return text


def _gc():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def build_hunyuan_messages_from_history(history, image_path, latest_user_text):
    """
    history: list of [user_text, assistant_text] pairs from ChatInterface
    image_path: current uploaded image file path (or None)
    latest_user_text: current user message (str)

    Returns: list[{"role": ..., "content": [...]}] for HunYuan
    """
    messages = []

    # 1) Past turns (text only; the image is attached only to the current turn)
    for user, assistant in history:
        messages.append(
            {
                "role": "user",
                "content": [{"type": "text", "text": user}],
            }
        )
        messages.append(
            {
                "role": "assistant",
                "content": [{"type": "text", "text": assistant}],
            }
        )

    # 2) Current user turn (image + text)
    content = []
    if image_path:
        content.append(
            {
                "type": "image",
                "image": os.path.abspath(image_path),
            }
        )
    if latest_user_text:
        content.append({"type": "text", "text": latest_user_text})
    if content:
        messages.append({"role": "user", "content": content})

    return messages


def main():
    args = _get_args()
    model, processor = _load_model_processor(args)

    # -------------------------
    # Core model call
    # -------------------------
    @spaces.GPU(duration=120)
    def call_local_model(hy_messages):
        start_time = time.time()

        # HunYuan expects list[list[message]]
        convs = [hy_messages]
        texts = [
            processor.apply_chat_template(
                c, tokenize=False, add_generation_prompt=True
            )
            for c in convs
        ]
        image_inputs, video_inputs = process_vision_info(convs)
        inputs = processor(
            text=texts,
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        # Move inputs to wherever the model actually landed, so --cpu-only
        # does not leave model and inputs on different devices.
        device = next(model.parameters()).device
        inputs = inputs.to(device)

        max_new_tokens = 512  # keep this smaller on CPU

        with torch.no_grad():
            # Warm-up forward pass; the result is discarded.
            if device.type == "cuda":
                with torch.autocast("cuda", dtype=torch.bfloat16):
                    _ = model(**inputs, use_cache=False)
            else:
                _ = model(**inputs, use_cache=False)

            generated_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,  # greedy decoding; temperature is ignored
            )

        # Strip the prompt tokens, keeping only newly generated ones.
        input_ids = inputs.input_ids
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(input_ids, generated_ids)
        ]
        output_texts = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        text = clean_repeated_substrings(_parse_text(output_texts[0]))

        print(f"[DEBUG] Total generation time: {time.time() - start_time:.2f}s")
        _gc()
        return text

    # -------------------------
    # Chat handler for ChatInterface
    # -------------------------
    def ocr_chat(message, history, image_path):
        """
        message: current user text (str)
        history: list[[user, assistant], ...]
        image_path: filepath from Image component
        """
        message = (message or "").strip()
        if not message and not image_path:
            return "Please upload an image and/or type a question."

        hy_messages = build_hunyuan_messages_from_history(
            history or [], image_path, message
        )
        answer = call_local_model(hy_messages)
        return answer

    # -------------------------
    # UI: ChatInterface + image
    # -------------------------
    with gr.Blocks() as demo:
        gr.Markdown("# HunyuanOCR\nUpload an image and ask OCR questions.")
        gr.ChatInterface(
            fn=ocr_chat,
            additional_inputs=[
                gr.Image(label="Upload Image", type="filepath"),
            ],
            title="Hunyuan OCR",
            description="Ask questions about the uploaded document",
        )

    # Pass through the CLI flags that were previously parsed but ignored.
    demo.launch(share=args.share, inbrowser=args.inbrowser)


if __name__ == "__main__":
    main()
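# Example invocation (script filename assumed; all flags are defined in
# _get_args above):
#
#   python app.py --checkpoint-path tencent/HunyuanOCR --share --inbrowser
#
# On a machine without a usable GPU, add --cpu-only; generation will be
# slow, so consider lowering max_new_tokens as noted in call_local_model.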