|
|
|
|
|
import gradio as gr |
|
|
import spaces |
|
|
import torch |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
from vlm_inference import ( |
|
|
load_vlm_model, |
|
|
vlm_infer_stream, |
|
|
image_processor, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the vision-language model once at module import time. It is kept on
# CPU between requests; chat_fn moves it to the GPU per call and back again.
print("[DEBUG] Loading VLM model on CPU...")


model = load_vlm_model()


# Switch to inference mode (disables dropout / batch-norm updates).
# NOTE(review): assumes load_vlm_model returns a torch nn.Module — it is
# later passed .to(device) in chat_fn, which is consistent with that.
model.eval()


print("[DEBUG] VLM model loaded.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_message(message: dict):
    """Split a Gradio multimodal message into its text and first image.

    Expected shape::

        message = {
            "text": str,
            "files": list,  # holds PIL.Image objects (or temp-file paths)
        }

    Returns:
        (text, image) where image is the first attached file, or None
        when nothing was attached. Missing keys fall back to ""/no image.
    """
    print("[DEBUG] parse_message called")
    print("[DEBUG] message type:", type(message))
    print("[DEBUG] message content:", message)

    prompt_text = message.get("text", "")
    attachments = message.get("files", [])

    print("[DEBUG] parsed text:", repr(prompt_text))
    print("[DEBUG] parsed files:", attachments)

    # Single-turn demo: only the first attachment is ever used.
    first_image = attachments[0] if attachments else None
    print("[DEBUG] parsed image:", first_image)

    return prompt_text, first_image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@spaces.GPU
def chat_fn(message, history, temperature, top_p, top_k):
    """Stream a VLM answer for one multimodal chat message.

    Gradio ChatInterface streaming callback: yields the growing response
    string as tokens arrive. An attached image is required; without one a
    single error message is yielded and the call ends.

    Args:
        message: dict with "text" and "files" keys (see parse_message).
        history: prior chat turns (unused — single-turn demo).
        temperature: sampling temperature forwarded to vlm_infer_stream.
        top_p: nucleus-sampling threshold; values <= 0 disable it (None).
        top_k: top-k cutoff; values <= 0 disable it (None).

    Yields:
        The accumulated response text after each streamed chunk.
    """
    text, image = parse_message(message)

    if image is None:
        yield "Image input is required."
        return

    device = "cuda"
    model_gpu = model.to(device)

    # try/finally so the model is returned to CPU and VRAM is released even
    # if preprocessing/inference raises or the client disconnects mid-stream
    # (which closes this generator early).
    try:
        # Gradio may hand back a temp-file path instead of a PIL image.
        if isinstance(image, str):
            image = Image.open(image)

        image_tensor = image_processor(
            images=image.convert("RGB"),
            return_tensors="pt",
        )["pixel_values"].to(device)

        prompt = f"<user>\n{text}<assistant>\n"
        print("[DEBUG] prompt:", prompt)

        output = ""
        for chunk in vlm_infer_stream(
            model=model_gpu,
            image_tensor=image_tensor,
            prompt=prompt,
            max_new_tokens=256,
            temperature=temperature,
            top_p=top_p if top_p > 0 else None,
            top_k=top_k if top_k > 0 else None,
        ):
            output += chunk
            yield output
    finally:
        model_gpu.to("cpu")
        torch.cuda.empty_cache()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("[DEBUG] Building Gradio UI")

# Single-turn multimodal chat UI; the sliders feed chat_fn's sampling knobs
# via additional_inputs (temperature, top_p, top_k — in that order).
demo = gr.ChatInterface(
    fn=chat_fn,
    multimodal=True,
    title="EveryonesGPT Vision Instruct. Single-turn English Only demo (CLIP ViT-L/14)",
    description=(
        # Fix: the first two literals previously lacked a trailing "\n",
        # so the concatenated description rendered as
        # "...an image.Download an example image.https://..." run together.
        "You must include an image.\n"
        "Download an example image.\n"
        "https://raw.githubusercontent.com/HayatoHongo/nanoGPTVision/main/hodomoe_cat.png\n"
        "### Github Repo: https://github.com/HayatoHongo/nanoGPT-Vision.git\n"
        "### **⚠️ The first message takes around 1 minute.**"
    ),
    additional_inputs=[
        gr.Slider(0.1, 2.0, value=0.2, step=0.05, label="Temperature"),
        gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-p"),
        gr.Slider(0, 200, value=0, step=1, label="Top-k"),
    ],
)

print("[DEBUG] Launching Gradio app")
demo.launch()
|
|
|