# HunyuanEz / app.py
import os
import gc
import time
import torch
import gradio as gr
import spaces
from argparse import ArgumentParser
from transformers import AutoProcessor, HunYuanVLForConditionalGeneration
from qwen_vl_utils import process_vision_info
# -------------------------------------------------------------------
# Global config
# -------------------------------------------------------------------
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
def _get_args():
parser = ArgumentParser()
parser.add_argument(
"-c",
"--checkpoint-path",
type=str,
default="tencent/HunyuanOCR",
help="Checkpoint name or path",
)
parser.add_argument("--cpu-only", action="store_true")
parser.add_argument("--share", action="store_true")
parser.add_argument("--inbrowser", action="store_true")
return parser.parse_args()
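# Example invocation (flags as defined above):
#   python app.py -c tencent/HunyuanOCR --share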
def _load_model_processor(args):
print("[INFO] Loading model")
print("[INFO] CUDA available:", torch.cuda.is_available())
    # --cpu-only forces a float32 CPU load; otherwise shard across available
    # devices in bfloat16.
    model = HunYuanVLForConditionalGeneration.from_pretrained(
        args.checkpoint_path,
        attn_implementation="eager",
        torch_dtype=torch.float32 if args.cpu_only else torch.bfloat16,
        device_map="cpu" if args.cpu_only else "auto",
    )
if hasattr(model, "gradient_checkpointing_disable"):
model.gradient_checkpointing_disable()
print("[INFO] Gradient checkpointing disabled")
model.eval()
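    # use_fast=False keeps the slow (Python) tokenizer; trust_remote_code
    # allows the checkpoint's custom processor code to load.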
processor = AutoProcessor.from_pretrained(
args.checkpoint_path, use_fast=False, trust_remote_code=True
)
print("[INFO] Model device:", next(model.parameters()).device)
return model, processor
def _parse_text(text: str) -> str:
if not isinstance(text, str):
text = str(text)
return text.replace("<trans>", "").replace("</trans>", "")
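# Illustrative behavior: _parse_text("<trans>Total: $42.00</trans>") -> "Total: $42.00"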
def clean_repeated_substrings(text: str) -> str:
    """Collapse a degenerate repeated tail, a common greedy-decoding failure:
    if the text ends in >= 10 consecutive copies of the same substring, keep
    only one copy of it."""
    n = len(text)
    if n < 2000:
        return text
    # Try every candidate period up to one tenth of the text length
    for length in range(2, n // 10 + 1):
candidate = text[-length:]
count = 0
i = n - length
while i >= 0 and text[i : i + length] == candidate:
count += 1
i -= length
if count >= 10:
return text[: n - length * (count - 1)]
return text
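# Illustrative example (hypothetical input): a 2000-char prefix followed by
# twelve copies of "ABC" collapses to a single trailing "ABC":
#   clean_repeated_substrings("x" * 2000 + "ABC" * 12) == "x" * 2000 + "ABC"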
def _gc():
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def build_hunyuan_messages_from_history(history, image_path, latest_user_text):
    """
    history: past turns from ChatInterface, either [user_text, assistant_text]
        pairs (Gradio "tuples" format) or {"role": ..., "content": ...} dicts
        (Gradio "messages" format)
    image_path: current uploaded image file path (or None)
    latest_user_text: current user message (str)
    Returns: list[{"role": ..., "content": [...]}] for HunYuan
    """
    messages = []
    # 1) Past turns (text only; the image is attached to the current turn only)
    for turn in history:
        if isinstance(turn, dict):
            # "messages" format: one dict per role
            content = turn.get("content")
            if isinstance(content, str) and content:
                messages.append(
                    {
                        "role": turn["role"],
                        "content": [{"type": "text", "text": content}],
                    }
                )
        else:
            # "tuples" format: [user_text, assistant_text]
            user, assistant = turn
            messages.append(
                {"role": "user", "content": [{"type": "text", "text": user}]}
            )
            messages.append(
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": assistant}],
                }
            )
    # 2) Current user turn (image + text)
    content = []
    if image_path:
        content.append({"type": "image", "image": os.path.abspath(image_path)})
    if latest_user_text:
        content.append({"type": "text", "text": latest_user_text})
    if content:
        messages.append({"role": "user", "content": content})
    return messages
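# Illustrative output (hypothetical inputs): for history=[["Hi", "Hello!"]],
# image_path="/tmp/doc.png", latest_user_text="Read the title", this returns:
# [
#     {"role": "user", "content": [{"type": "text", "text": "Hi"}]},
#     {"role": "assistant", "content": [{"type": "text", "text": "Hello!"}]},
#     {"role": "user", "content": [
#         {"type": "image", "image": "/tmp/doc.png"},
#         {"type": "text", "text": "Read the title"},
#     ]},
# ]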
def main():
args = _get_args()
model, processor = _load_model_processor(args)
# -------------------------
# Core model call
# -------------------------
@spaces.GPU(duration=120)
def call_local_model(hy_messages):
        start_time = time.time()
# HunYuan expects list[list[message]]
convs = [hy_messages]
texts = [
processor.apply_chat_template(
c, tokenize=False, add_generation_prompt=True
)
for c in convs
]
image_inputs, video_inputs = process_vision_info(convs)
inputs = processor(
text=texts,
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
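        # `inputs` is a BatchFeature holding input_ids, attention_mask, and
        # the processed image tensors, all as batched PyTorch tensors.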
device = "cuda" if torch.cuda.is_available() else "cpu"
inputs = inputs.to(device)
max_new_tokens = 512 # keep this smaller on CPU
        with torch.no_grad():
            # Warm-up forward pass: its output is discarded and generate()
            # below redoes the prefill, so it only serves to prime the kernels.
            if device == "cuda":
                with torch.autocast("cuda", dtype=torch.bfloat16):
                    _ = model(**inputs, use_cache=False)
            else:
                _ = model(**inputs, use_cache=False)
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,  # greedy decoding, so no sampling knobs needed
            )
input_ids = inputs.input_ids
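        # Keep only the tokens generated after the prompt for decoding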
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(input_ids, generated_ids)
]
output_texts = processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
        text = clean_repeated_substrings(_parse_text(output_texts[0]))
        print(f"[DEBUG] Total generation time: {time.time() - start_time:.2f}s")
        _gc()  # release per-request allocations (see _gc above)
        return text
# -------------------------
# Chat handler for ChatInterface
# -------------------------
    def ocr_chat(message, history, image_path):
        """
        message: current user text (str)
        history: past turns from ChatInterface (tuples or messages format;
            see build_hunyuan_messages_from_history)
        image_path: filepath from the Image component
        """
message = (message or "").strip()
if not message and not image_path:
return "Please upload an image and/or type a question."
hy_messages = build_hunyuan_messages_from_history(
history or [], image_path, message
)
answer = call_local_model(hy_messages)
return answer
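    # Example (hypothetical file path):
    #   ocr_chat("What does the receipt total say?", [], "/tmp/receipt.png")
    # returns the model's answer as a plain string for ChatInterface to render.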
# -------------------------
# UI: ChatInterface + image
# -------------------------
with gr.Blocks() as demo:
gr.Markdown("# HunyuanOCR\nUpload an image and ask OCR questions.")
chat = gr.ChatInterface(
fn=ocr_chat,
additional_inputs=[
gr.Image(
label="Upload Image",
type="filepath"
)
],
title="Hunyuan OCR",
description="Ask questions about the uploaded document",
)
    demo.launch(share=args.share, inbrowser=args.inbrowser)
if __name__ == "__main__":
main()