Serving HunyuanOCR behind an API is too slow
I serve the model with FastAPI + Transformers and each request takes about 10+ seconds, even though I have a GPU. How can I speed this up? My code is:
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from transformers import AutoProcessor, HunYuanVLForConditionalGeneration
from PIL import Image
import torch
import io
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
app = FastAPI(title="HunYuan OCR API")
def clean_repeated_substrings(text):
    """Trim runaway trailing repetition from generated text.

    Checks every suffix width from 2 up to ``len(text) // 10``: if the
    suffix of that width is repeated back-to-back ten or more times at
    the end of *text*, all but one copy of it is dropped. Inputs shorter
    than 8000 characters are returned unchanged.
    """
    size = len(text)
    if size < 8000:
        return text
    for width in range(2, size // 10 + 1):
        tail = text[-width:]
        repeats = 0
        pos = size - width
        # Walk backwards in steps of `width` while the tail keeps matching.
        while pos >= 0 and text[pos:pos + width] == tail:
            repeats += 1
            pos -= width
        if repeats >= 10:
            # Keep exactly one copy of the repeated tail.
            return text[:size - width * (repeats - 1)]
    return text
# ---------- Load model & processor once (global) ----------
# Local path to the HunyuanOCR checkpoint; loaded once at import time so
# every request reuses the same weights.
model_name = "/models/HunyuanOCR"
# NOTE(review): use_fast=False forces the slow Python-side processor —
# presumably kept for output fidelity; confirm, since the fast processor
# noticeably cuts per-request preprocessing time.
processor = AutoProcessor.from_pretrained(model_name, use_fast=False)
model = HunYuanVLForConditionalGeneration.from_pretrained(
    model_name,
    # NOTE(review): "eager" attention disables fused kernels (e.g. SDPA /
    # flash attention) and is likely a large part of the ~10 s latency —
    # confirm whether this model supports a faster attn_implementation.
    attn_implementation="eager",
    dtype=torch.bfloat16,
    device_map="auto"  # place weights automatically on the visible GPU(s)
)
# Device of the first model parameter; per-request inputs are moved here.
device = next(model.parameters()).device
# -------------------- API Route --------------------
@app.post("/ocr")
async def ocr_api(file: UploadFile = File(...)):
    """Run OCR on an uploaded image and return the recognized text.

    Accepts a multipart image upload, feeds it through the globally loaded
    HunyuanOCR model, and returns ``{"text": ...}`` on success or
    ``{"error": ...}`` with HTTP 500 on any failure.
    """
    try:
        # Read the raw upload and decode it into a PIL image.
        img_bytes = await file.read()
        # Normalize to RGB so palette/RGBA/CMYK uploads don't trip the
        # image processor.
        image = Image.open(io.BytesIO(img_bytes)).convert("RGB")

        # Chat-style prompt: empty system turn + user turn carrying the
        # image placeholder and the OCR instruction.
        messages1 = [
            {"role": "system", "content": ""},
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": "uploaded_image"},
                    {"type": "text", "text": "检测并识别图片中的文字,将文本坐标格式化输出。"},
                ],
            },
        ]
        messages = [messages1]
        texts = [
            processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
            for msg in messages
        ]

        # Tokenize the prompt, preprocess the image, and move the tensors
        # to the model's device.
        inputs = processor(
            text=texts,
            images=image,
            padding=True,
            return_tensors="pt",
        ).to(device)

        # Greedy decode; no_grad avoids building autograd state.
        # NOTE(review): model.generate blocks the event loop inside this
        # async handler — offload it to a threadpool (run_in_executor) if
        # the server must stay responsive under concurrent requests.
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs, max_new_tokens=1024, do_sample=False
            )

        # Strip the prompt tokens so only newly generated text remains.
        input_ids = inputs.input_ids
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids, generated_ids)
        ]
        decoded = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )
        # BUG FIX: batch_decode returns a list of strings; the repetition
        # cleaner expects a single string, so take the single batch entry
        # (previously the list itself was cleaned and returned).
        output_text = clean_repeated_substrings(decoded[0])
        return JSONResponse({"text": output_text})
    except Exception as e:
        # Top-level boundary: surface any failure as a JSON 500.
        return JSONResponse({"error": str(e)}, status_code=500)
Hi — if you'd rather not host it yourself, HunyuanOCR is also available through a hosted inference provider; see the playground here: https://console.gmicloud.ai/playground/llm/hunyuanocr/3de77397-542f-49d5-830b-4a6c73811f88