OppaAI's picture
Update app.py
9fc69b3 verified
raw
history blame
2.35 kB
import gradio as gr
import base64
from PIL import Image
import io
import json
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor
# ------------------------------------------------------------
# 1. Load the vision-language model (Qwen2-VL-7B-Instruct)
# ------------------------------------------------------------
# Hugging Face Hub id of the checkpoint (the officially recommended VL name).
model_name = "Qwen/Qwen2-VL-7B-Instruct"

# The processor and the model are both loaded from the same checkpoint id.
processor = AutoProcessor.from_pretrained(model_name)

# Load the weights as float16 with low_cpu_mem_usage enabled, then move the
# model onto the CUDA device as a separate step.
model = AutoModelForVision2Seq.from_pretrained(
    model_name, torch_dtype=torch.float16, low_cpu_mem_usage=True
)
model = model.to("cuda")
# ------------------------------------------------------------
# 2. Main Process Function
# ------------------------------------------------------------
def process(payload):
    """Describe a base64-encoded image with the vision-language model.

    Args:
        payload: Request body as a dict. ``image_b64`` (base64-encoded image
            bytes) is required; ``robot_id`` is optional and echoed back.

    Returns:
        A ``(image, reply)`` pair, one value per Gradio output component.
        On success ``image`` is the decoded RGB PIL image and ``reply`` is a
        dict with ``received``, ``robot_id``, ``size`` and ``vllm_analysis``.
        On any failure ``image`` is ``None`` and ``reply`` is
        ``{"error": <message>}``; exceptions are never propagated.
    """
    try:
        # Reject malformed requests with a readable message instead of
        # leaking a bare KeyError/TypeError string to the client.
        if not isinstance(payload, dict) or "image_b64" not in payload:
            return None, {"error": "payload must be a dict containing 'image_b64'"}

        # Decode the incoming image.
        data = payload
        img_bytes = base64.b64decode(data["image_b64"])
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")

        # ------------------------------------------------------------
        # 3. Vision-Language model inference
        # ------------------------------------------------------------
        prompt = "Describe what you see in this image in detail."
        # The chat template inserts the image placeholder tokens that the
        # processor expands to match the image features; passing the raw
        # prompt leaves the image unreferenced in the token sequence.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        chat_text = processor.apply_chat_template(
            messages, add_generation_prompt=True
        )
        inputs = processor(
            images=[img], text=[chat_text], return_tensors="pt"
        ).to("cuda", torch.float16)

        output_ids = model.generate(
            **inputs,
            max_new_tokens=200,
            temperature=0.2,
        )
        # generate() returns prompt + completion; drop the prompt tokens so
        # the reply contains only the model's answer.
        prompt_len = inputs["input_ids"].shape[1]
        response_text = processor.batch_decode(
            output_ids[:, prompt_len:], skip_special_tokens=True
        )[0]

        # ------------------------------------------------------------
        # 4. Return results to Jetson
        # ------------------------------------------------------------
        reply = {
            "received": True,
            "robot_id": data.get("robot_id"),
            "size": img.size,
            "vllm_analysis": response_text,
        }
        # Two values, matching the (image preview, JSON reply) outputs and
        # the shape already used by the error path below.
        return img, reply
    except Exception as e:
        # Top-level boundary: report the failure to the caller as JSON.
        return None, {"error": str(e)}
# ------------------------------------------------------------
# 5. Gradio UI
# ------------------------------------------------------------
# One JSON input carrying the request payload.
payload_input = gr.JSON(label="Input Payload (Dict format)")

# Two outputs: an image preview and the JSON reply.
preview_output = gr.Image(type="pil", label="Image Preview")
reply_output = gr.JSON(label="Reply to Jetson")

# Wire `process` to the components and register it under the "predict" API name.
demo = gr.Interface(
    fn=process,
    inputs=payload_input,
    outputs=[preview_output, reply_output],
    api_name="predict",
)

demo.launch()