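"""Gradio chat demo for Tencent HunyuanOCR.

Loads the model and processor once at startup, then answers OCR questions
about an uploaded image through a gr.ChatInterface. The spaces.GPU decorator
targets Hugging Face ZeroGPU Spaces. Example local invocation (assuming the
file is saved as app.py):

    python app.py --checkpoint-path tencent/HunyuanOCR --share
"""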
import os
import gc
import time

import torch
import gradio as gr
import spaces

from argparse import ArgumentParser
from transformers import AutoProcessor, HunYuanVLForConditionalGeneration
from qwen_vl_utils import process_vision_info

# Silence tokenizer fork warnings and transformers advisory warnings.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"


def _get_args():
    parser = ArgumentParser()
    parser.add_argument(
        "-c",
        "--checkpoint-path",
        type=str,
        default="tencent/HunyuanOCR",
        help="Checkpoint name or path",
    )
    parser.add_argument("--cpu-only", action="store_true", help="Run on CPU only")
    parser.add_argument("--share", action="store_true", help="Create a public Gradio link")
    parser.add_argument("--inbrowser", action="store_true", help="Open the demo in a browser tab")
    return parser.parse_args()


def _load_model_processor(args):
    print("[INFO] Loading model")
    print("[INFO] CUDA available:", torch.cuda.is_available())

    # Honor --cpu-only; otherwise let accelerate place the weights automatically.
    model = HunYuanVLForConditionalGeneration.from_pretrained(
        args.checkpoint_path,
        attn_implementation="eager",
        torch_dtype=torch.bfloat16,
        device_map="cpu" if args.cpu_only else "auto",
    )

    # Gradient checkpointing only helps training; disable it so inference
    # does not pay the activation-recomputation cost.
    if hasattr(model, "gradient_checkpointing_disable"):
        model.gradient_checkpointing_disable()
        print("[INFO] Gradient checkpointing disabled")

    model.eval()
    processor = AutoProcessor.from_pretrained(
        args.checkpoint_path, use_fast=False, trust_remote_code=True
    )

    print("[INFO] Model device:", next(model.parameters()).device)
    return model, processor


def _parse_text(text: str) -> str:
    # Strip the model's <trans>...</trans> wrapper tags from the output.
    if not isinstance(text, str):
        text = str(text)
    return text.replace("<trans>", "").replace("</trans>", "")


def clean_repeated_substrings(text: str) -> str:
    """Trim runaway repetition at the end of a long generation.

    If the text ends with the same substring repeated 10+ times back-to-back,
    keep a single copy of it and drop the rest. Short outputs are left alone.
    """
    n = len(text)
    if n < 2000:
        return text
    for length in range(2, n // 10 + 1):
        candidate = text[-length:]
        count = 0
        i = n - length
        # Count consecutive occurrences of the candidate at the end of the text.
        while i >= 0 and text[i : i + length] == candidate:
            count += 1
            i -= length
        if count >= 10:
            # Keep one copy of the repeated unit; drop the other count - 1.
            return text[: n - length * (count - 1)]
    return text


def _gc():
    # Release Python garbage and return cached CUDA blocks to the allocator.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def build_hunyuan_messages_from_history(history, image_path, latest_user_text):
    """
    history: list of [user_text, assistant_text] pairs from ChatInterface
    image_path: current uploaded image file path (or None)
    latest_user_text: current user message (str)
    Returns: list[{"role": ..., "content": [...]}] for HunYuan
    """
    messages = []

    # Replay earlier turns as text-only messages; the image is attached
    # only to the newest user turn below.
    for user, assistant in history:
        messages.append(
            {"role": "user", "content": [{"type": "text", "text": user}]}
        )
        messages.append(
            {"role": "assistant", "content": [{"type": "text", "text": assistant}]}
        )

    content = []
    if image_path:
        content.append({"type": "image", "image": os.path.abspath(image_path)})
    if latest_user_text:
        content.append({"type": "text", "text": latest_user_text})

    if content:
        messages.append({"role": "user", "content": content})

    return messages


def main():
    args = _get_args()
    model, processor = _load_model_processor(args)

    @spaces.GPU(duration=120)
    def call_local_model(hy_messages):
        start_time = time.time()

        # Batch of one conversation.
        convs = [hy_messages]
        texts = [
            processor.apply_chat_template(
                c, tokenize=False, add_generation_prompt=True
            )
            for c in convs
        ]

        # process_vision_info (from qwen_vl_utils) collects the image/video
        # inputs referenced by the messages; the message schema is compatible.
        image_inputs, video_inputs = process_vision_info(convs)

        inputs = processor(
            text=texts,
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        device = "cuda" if torch.cuda.is_available() else "cpu"
        inputs = inputs.to(device)

        max_new_tokens = 512
        with torch.no_grad():
            # Warm-up forward pass (no KV cache) before the real generation call.
            if device == "cuda":
                with torch.autocast("cuda", dtype=torch.bfloat16):
                    _ = model(**inputs, use_cache=False)
            else:
                _ = model(**inputs, use_cache=False)

        # Greedy decoding; with do_sample=False any temperature setting
        # would be ignored, so none is passed.
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

        # Drop the prompt tokens so only newly generated text is decoded.
        input_ids = inputs.input_ids
        generated_ids_trimmed = [
            out_ids[len(in_ids) :] for in_ids, out_ids in zip(input_ids, generated_ids)
        ]
        output_texts = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        text = clean_repeated_substrings(_parse_text(output_texts[0]))
        print(f"[DEBUG] Total generation time: {time.time() - start_time:.2f}s")
        _gc()  # Free transient tensors between requests.
        return text

    def ocr_chat(message, history, image_path):
        """
        message: current user text (str)
        history: list[[user, assistant], ...]
        image_path: filepath from Image component
        """
        message = (message or "").strip()

        if not message and not image_path:
            return "Please upload an image and/or type a question."

        hy_messages = build_hunyuan_messages_from_history(
            history or [], image_path, message
        )
        return call_local_model(hy_messages)

    with gr.Blocks() as demo:
        gr.Markdown("# HunyuanOCR\nUpload an image and ask OCR questions.")

        gr.ChatInterface(
            fn=ocr_chat,
            additional_inputs=[
                gr.Image(label="Upload Image", type="filepath"),
            ],
            title="Hunyuan OCR",
            description="Ask questions about the uploaded document",
        )

    # Pass the CLI flags through so --share and --inbrowser take effect.
    demo.launch(share=args.share, inbrowser=args.inbrowser)


if __name__ == "__main__":
    main()