# app.py
import gradio as gr
import spaces
import torch
from PIL import Image
from vlm_inference import (
load_vlm_model,
vlm_infer_stream,
image_processor,
)
# =====================================================
# Load VLM on CPU (ZeroGPU)
# =====================================================
# The model is loaded once at import time on CPU; ZeroGPU Spaces only grant
# a GPU inside @spaces.GPU-decorated calls, so chat_fn moves it to CUDA
# per-request and back afterwards.
print("[DEBUG] Loading VLM model on CPU...")
model = load_vlm_model()
model.eval()  # inference mode: disables dropout / batch-norm updates
print("[DEBUG] VLM model loaded.")
# =====================================================
# Message parser (follows the multimodal=True contract)
# =====================================================
def parse_message(message: dict):
    """Split a Gradio multimodal message into ``(text, image_or_None)``.

    ``message`` follows the ``multimodal=True`` contract:
    ``{"text": str, "files": list}`` — files holds the uploaded image(s);
    only the first file is used (single-image demo).
    """
    print("[DEBUG] parse_message called")
    print("[DEBUG] message type:", type(message))
    print("[DEBUG] message content:", message)

    text = message.get("text", "")
    files = message.get("files", [])
    print("[DEBUG] parsed text:", repr(text))
    print("[DEBUG] parsed files:", files)

    # First attachment if any; the demo ignores extra files.
    first_image = next(iter(files), None)
    print("[DEBUG] parsed image:", first_image)
    return text, first_image
# =====================================================
# GPU inference (single-turn, VLM only)
# =====================================================
@spaces.GPU
def chat_fn(message, history, temperature, top_p, top_k):
    """Single-turn streaming VLM inference for gr.ChatInterface.

    Args:
        message: multimodal dict ``{"text": str, "files": list}``.
        history: chat history from ChatInterface (unused — single-turn demo).
        temperature / top_p / top_k: sampling controls from the sliders;
            top_p/top_k values of 0 disable the respective filter.

    Yields:
        The cumulative generated text so far (ChatInterface replaces the
        message on each yield, producing a streaming effect).
    """
    text, image = parse_message(message)
    if image is None:
        yield "Image input is required."
        return
    device = "cuda"
    model_gpu = model.to(device)
    try:
        # Gradio may hand the file over as a path string rather than a
        # PIL image; open it lazily in that case.
        if isinstance(image, str):
            image = Image.open(image)
        image_tensor = image_processor(
            images=image.convert("RGB"),
            return_tensors="pt",
        )["pixel_values"].to(device)
        prompt = f"<user>\n{text}<assistant>\n"
        print("[DEBUG] prompt:", prompt)
        # Important: accumulate chunks and yield the running total —
        # ChatInterface expects the full message each time, not deltas.
        output = ""
        for chunk in vlm_infer_stream(
            model=model_gpu,
            image_tensor=image_tensor,
            prompt=prompt,
            max_new_tokens=256,
            temperature=temperature,
            top_p=top_p if top_p > 0 else None,
            top_k=top_k if top_k > 0 else None,
        ):
            output += chunk
            yield output
    finally:
        # Always return the model to CPU and release VRAM, even if
        # preprocessing or generation raised — otherwise a failed request
        # leaves the model stranded on the ZeroGPU device.
        model_gpu.to("cpu")
        torch.cuda.empty_cache()
# =====================================================
# UI (ChatInterface, multimodal)
# =====================================================
print("[DEBUG] Building Gradio UI")
demo = gr.ChatInterface(
    fn=chat_fn,
    multimodal=True,  # text box accepts image attachments (parsed in chat_fn)
    title="EveryonesGPT Vision Instruct. Single-turn English Only demo (CLIP ViT-L/14)",
    description=(
        # Adjacent string literals concatenate with NO separator; explicit
        # "\n" keeps the sentences and the URL from running together.
        "You must include an image.\n"
        "Download an example image.\n"
        "https://raw.githubusercontent.com/HayatoHongo/nanoGPTVision/main/hodomoe_cat.png\n"
        "### Github Repo: https://github.com/HayatoHongo/nanoGPT-Vision.git\n"
        "### **⚠️ The first message takes around 1 minute.**"
    ),
    # Forwarded positionally to chat_fn as (temperature, top_p, top_k).
    additional_inputs=[
        gr.Slider(0.1, 2.0, value=0.2, step=0.05, label="Temperature"),
        gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-p"),
        gr.Slider(0, 200, value=0, step=1, label="Top-k"),  # 0 disables top-k
    ],
)
print("[DEBUG] Launching Gradio app")
demo.launch()