# HayatoHongoEveryonesAI — "Update app.py" (commit 518e2dc, verified)
# app.py
import gradio as gr
import spaces
import torch
from PIL import Image
from vlm_inference import (
load_vlm_model,
vlm_infer_stream,
image_processor,
)
# =====================================================
# Load VLM on CPU (ZeroGPU)
# =====================================================
# ZeroGPU pattern: the model is loaded on CPU at import time; chat_fn
# moves it to CUDA only for the duration of each request.
print("[DEBUG] Loading VLM model on CPU...")
model = load_vlm_model()  # weights provided by vlm_inference.load_vlm_model
model.eval()  # inference only — disables dropout / batch-norm updates
print("[DEBUG] VLM model loaded.")
# =====================================================
# message parser (follows the multimodal=True message format)
# =====================================================
def parse_message(message: dict):
    """Extract the text prompt and the first attached image from a
    Gradio multimodal chat message.

    The payload has the shape ``{"text": str, "files": list}``, where
    ``files`` holds the attachments (PIL images or file paths).

    Returns:
        tuple: ``(text, image)`` — ``image`` is the first attachment,
        or ``None`` when nothing was attached.
    """
    print("[DEBUG] parse_message called")
    print("[DEBUG] message type:", type(message))
    print("[DEBUG] message content:", message)
    prompt_text = message.get("text", "")
    attachments = message.get("files", [])
    print("[DEBUG] parsed text:", repr(prompt_text))
    print("[DEBUG] parsed files:", attachments)
    first_image = attachments[0] if attachments else None
    print("[DEBUG] parsed image:", first_image)
    return prompt_text, first_image
# =====================================================
# GPU inference (single-turn, VLM only)
# =====================================================
@spaces.GPU
def chat_fn(message, history, temperature, top_p, top_k):
    """Single-turn streaming VLM chat handler.

    Moves the model to the GPU for one request, streams the generated
    text as a growing string (the form ``gr.ChatInterface`` expects),
    and always returns the model to the CPU afterwards.

    Parameters
    ----------
    message : dict
        Gradio multimodal payload; must contain at least one image.
    history : list
        Chat history (unused — this demo is single-turn).
    temperature : float
        Sampling temperature.
    top_p, top_k : float / int
        Sampling filters; a value of 0 disables the corresponding filter.

    Yields
    ------
    str
        The accumulated response text so far.
    """
    text, image = parse_message(message)
    if image is None:
        yield "Image input is required."
        return
    device = "cuda"
    model_gpu = model.to(device)
    try:
        # Gradio may hand over a file path rather than a PIL image.
        if isinstance(image, str):
            image = Image.open(image)
        image_tensor = image_processor(
            images=image.convert("RGB"),
            return_tensors="pt"
        )["pixel_values"].to(device)
        prompt = f"<user>\n{text}<assistant>\n"
        print("[DEBUG] prompt:", prompt)
        # Important: yield the *accumulated* text so the UI shows the
        # whole response as it streams.
        output = ""
        for chunk in vlm_infer_stream(
            model=model_gpu,
            image_tensor=image_tensor,
            prompt=prompt,
            max_new_tokens=256,
            temperature=temperature,
            top_p=top_p if top_p > 0 else None,
            top_k=top_k if top_k > 0 else None,
        ):
            output += chunk
            yield output
    finally:
        # Release the GPU even if generation raises or the client
        # disconnects mid-stream (the abandoned generator is closed,
        # which still runs this finally block).
        model_gpu.to("cpu")
        torch.cuda.empty_cache()
# =====================================================
# UI (ChatInterface, multimodal)
# =====================================================
print("[DEBUG] Building Gradio UI")
# Multimodal ChatInterface: each message is a dict of text + attached
# files, which chat_fn parses before running single-turn VLM inference.
demo = gr.ChatInterface(
    fn=chat_fn,
    multimodal=True,
    title="EveryonesGPT Vision Instruct. Single-turn English Only demo (CLIP ViT-L/14)",
    description=(
        # Adjacent string literals concatenate with no separator, so
        # explicit "\n" keeps the sentences from running together.
        "You must include an image.\n"
        "Download an example image.\n"
        "https://raw.githubusercontent.com/HayatoHongo/nanoGPTVision/main/hodomoe_cat.png\n"
        "### Github Repo: https://github.com/HayatoHongo/nanoGPT-Vision.git\n"
        "### **⚠️ The first message takes around 1 minute.**"
    ),
    additional_inputs=[
        # Defaults: low temperature, nucleus sampling at 0.9, top-k off (0).
        gr.Slider(0.1, 2.0, value=0.2, step=0.05, label="Temperature"),
        gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-p"),
        gr.Slider(0, 200, value=0, step=1, label="Top-k"),
    ],
)
print("[DEBUG] Launching Gradio app")
demo.launch()