File size: 3,226 Bytes
81845c2
7657160
81845c2
 
5e541e3
 
f6e0b49
 
6b1d95e
f6e0b49
 
 
 
 
ff05d0c
f6e0b49
ff05d0c
cf3c4b2
 
ff05d0c
81845c2
f6e0b49
 
ff05d0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6e0b49
81845c2
aa68cd8
ff05d0c
 
f6e0b49
aa68cd8
 
81845c2
6b1d95e
 
 
5e541e3
aa68cd8
5e541e3
 
6b1d95e
 
 
 
f5e1816
89d297d
877afe0
6b1d95e
4eec729
 
 
aa68cd8
 
 
 
 
 
 
 
 
4eec729
 
aa68cd8
 
 
 
f6e0b49
 
4eec729
f6e0b49
ff05d0c
f6e0b49
ff05d0c
 
 
 
 
a495d63
07b501a
 
 
 
518e2dc
07b501a
 
ff05d0c
5834da2
ff05d0c
 
 
 
81845c2
ff05d0c
f6e0b49
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# app.py
import gradio as gr
import spaces
import torch
from PIL import Image


from vlm_inference import (
    load_vlm_model,
    vlm_infer_stream,
    image_processor,
)

# =====================================================
# Load VLM on CPU (ZeroGPU)
# =====================================================
# The model is built once, at import time, and shared by every request.
# chat_fn moves it to "cuda" inside its @spaces.GPU-decorated call and back to
# "cpu" afterwards. NOTE(review): this follows the ZeroGPU pattern, where a GPU
# is only attached during GPU-decorated calls; confirm load_vlm_model really
# returns the model on CPU.
print("[DEBUG] Loading VLM model on CPU...")
model = load_vlm_model()
# Inference only: switch to eval mode. This assumes load_vlm_model returns a
# torch.nn.Module (TODO confirm in vlm_inference).
model.eval()
print("[DEBUG] VLM model loaded.")


# =====================================================
# Message parser (follows the multimodal=True message format)
# =====================================================
def parse_message(message: dict):
    """Split a ChatInterface multimodal message into ``(text, image)``.

    With ``multimodal=True`` Gradio passes ``message`` as a dict shaped like::

        {"text": str, "files": list}

    Entries of ``files`` are expected to be file paths (str), which
    ``chat_fn`` opens with PIL. Dict entries carrying a ``"path"`` key (the
    FileData form some Gradio versions use) are unwrapped to that path.

    Args:
        message: The multimodal message dict. Missing or ``None`` values for
            ``"text"`` / ``"files"`` are treated as empty.

    Returns:
        tuple: ``(text, image)``.
            ``text`` is always a str; it is ``""`` when absent.
            ``image`` is the first attached file, as a path or whatever object
            Gradio supplied, or ``None`` when nothing was attached.
    """
    print("[DEBUG] parse_message called")
    print("[DEBUG] message type:", type(message))
    print("[DEBUG] message content:", message)

    # `or` also normalises explicit None values, not just missing keys. This
    # keeps a literal "None" out of the prompt built in chat_fn.
    text = message.get("text") or ""
    files = message.get("files") or []

    print("[DEBUG] parsed text:", repr(text))
    print("[DEBUG] parsed files:", files)

    # Only a single image is supported; extra attachments are ignored.
    image = files[0] if files else None
    if isinstance(image, dict):
        # FileData-style entry: the downstream code expects the path string.
        image = image.get("path", image)
    print("[DEBUG] parsed image:", image)

    return text, image


# =====================================================
# GPU inference (single-turn, VLM only)
# =====================================================
@spaces.GPU
def chat_fn(message, history, temperature, top_p, top_k):
    """Stream a single-turn VLM answer about the attached image.

    Args:
        message: ChatInterface multimodal message dict (``{"text", "files"}``),
            split by ``parse_message``.
        history: Chat history supplied by Gradio. It is ignored because the
            demo is single-turn.
        temperature: Sampling temperature forwarded to ``vlm_infer_stream``.
        top_p: Nucleus-sampling threshold. Values <= 0 are passed as ``None``
            (disabled).
        top_k: Top-k cutoff. Values <= 0 are passed as ``None`` (disabled).

    Yields:
        The cumulative text generated so far. If no image was attached, it
        yields the single message "Image input is required." instead.
    """
    text, image = parse_message(message)

    if image is None:
        yield "Image input is required."
        return

    # Decode the image before the model is moved to CUDA, so an unreadable
    # upload cannot leave the model on the GPU. Image.open keeps its file open
    # lazily; the `with` block closes it once convert() has read the pixels.
    if isinstance(image, str):
        with Image.open(image) as opened:
            rgb_image = opened.convert("RGB")
    else:
        rgb_image = image.convert("RGB")

    device = "cuda"
    model_gpu = model.to(device)
    try:
        image_tensor = image_processor(
            images=rgb_image,
            return_tensors="pt"
        )["pixel_values"].to(device)

        prompt = f"<user>\n{text}<assistant>\n"
        print("[DEBUG] prompt:", prompt)

        # Important: yield the accumulated text, not the individual chunk.
        # ChatInterface displays the latest yielded value as the whole reply.
        output = ""
        for chunk in vlm_infer_stream(
            model=model_gpu,
            image_tensor=image_tensor,
            prompt=prompt,
            max_new_tokens=256,
            temperature=temperature,
            top_p=top_p if top_p > 0 else None,
            top_k=top_k if top_k > 0 else None,
        ):
            output += chunk
            yield output
    finally:
        # This runs on normal completion, on errors, and when Gradio closes the
        # generator early (e.g. the user presses stop). The model therefore
        # always returns to the CPU and cached GPU memory is released.
        model_gpu.to("cpu")
        torch.cuda.empty_cache()




# =====================================================
# UI (ChatInterface, multimodal)
# =====================================================
print("[DEBUG] Building Gradio UI")

# With multimodal=True, ChatInterface sends chat_fn {"text", "files"} dicts.
# The sliders in additional_inputs are passed to chat_fn after
# (message, history), in order: temperature, top_p, top_k.
demo = gr.ChatInterface(
    fn=chat_fn,
    multimodal=True,
    title="EveryonesGPT Vision Instruct. Single-turn English Only demo (CLIP ViT-L/14)",
    description=(
        # Adjacent string literals are joined with nothing in between, so every
        # fragment carries its own separator. Without them the sentences and
        # the URL run together into one line.
        "You must include an image.\n\n"
        "Download an example image: "
        "https://raw.githubusercontent.com/HayatoHongo/nanoGPTVision/main/hodomoe_cat.png\n\n"
        "### Github Repo: https://github.com/HayatoHongo/nanoGPT-Vision.git\n"
        "### **⚠️ The first message takes around 1 minute.**"
    ),
    additional_inputs=[
        gr.Slider(0.1, 2.0, value=0.2, step=0.05, label="Temperature"),
        # chat_fn maps top_p <= 0 and top_k <= 0 to None (filter disabled).
        # The Top-k default of 0 therefore turns top-k filtering off.
        gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-p"),
        gr.Slider(0, 200, value=0, step=1, label="Top-k"),
    ],
)

print("[DEBUG] Launching Gradio app")
demo.launch()