File size: 5,069 Bytes
c6110b4
279f604
a709033
 
89f55ad
24c4395
89f55ad
 
 
346da8b
89f55ad
b2016f5
 
 
c6110b4
b2016f5
c6110b4
b2016f5
24c4395
346da8b
 
 
 
 
 
 
 
 
 
 
 
b2016f5
 
 
 
89f55ad
 
a709033
b2016f5
 
 
 
 
a709033
89f55ad
 
b2016f5
 
 
 
a709033
b2016f5
 
 
a709033
346da8b
 
 
c6110b4
 
 
 
 
 
 
 
 
 
 
89f55ad
346da8b
 
 
89f55ad
346da8b
 
 
 
 
 
 
a709033
b2016f5
346da8b
b2016f5
24c4395
 
 
 
a709033
c6110b4
 
 
 
 
 
 
 
 
 
 
 
 
24c4395
b2016f5
24c4395
 
 
 
 
 
89f55ad
c6110b4
89f55ad
 
24c4395
a709033
c6110b4
24c4395
 
 
 
c6110b4
 
 
 
 
 
 
 
 
346da8b
c6110b4
 
 
 
24c4395
89f55ad
346da8b
89f55ad
 
c6110b4
346da8b
c6110b4
24c4395
 
 
c6110b4
 
 
 
 
24c4395
346da8b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# app.py
import os
import sys
import types
import importlib.machinery
from PIL import Image
import gradio as gr

# ===============================
# Helper to create package-like dummy modules
# ===============================
def _mk_pkg(name: str):
    m = types.ModuleType(name)
    spec = importlib.machinery.ModuleSpec(name, loader=None, is_package=True)
    spec.submodule_search_locations = []
    m.__spec__ = spec
    m.__path__ = []
    return m

# ===============================
# Disable GPU-only/optional paths
# ===============================
# Only defaults: pre-existing environment values are respected (setdefault).
_CPU_SAFE_DEFAULTS = {
    "FLASH_ATTENTION": "0",
    "XFORMERS_DISABLED": "1",
    "ACCELERATE_USE_DEVICE_MAP": "0",
    "DISABLE_TRITON": "1",  # avoid triton kernels
}
for _env_name, _env_value in _CPU_SAFE_DEFAULTS.items():
    os.environ.setdefault(_env_name, _env_value)
# os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")  # uncomment to force CPU

# ===============================
# flash_attn stubs (package + submodules)
# ===============================
# Imports of flash_attn succeed, but any real call fails loudly.
def _dummy_func(*args, **kwargs):
    raise RuntimeError("flash_attn is not available in this environment.")

# Package stub: spec marked as a package with an empty search path so
# submodule imports resolve via the sys.modules entries installed below.
flash_attn_pkg = types.ModuleType("flash_attn")
_fa_spec = importlib.machinery.ModuleSpec("flash_attn", loader=None, is_package=True)
_fa_spec.submodule_search_locations = []
flash_attn_pkg.__spec__ = _fa_spec
flash_attn_pkg.__path__ = []

flash_attn_interface = types.ModuleType("flash_attn.flash_attn_interface")
flash_attn_interface.__spec__ = importlib.machinery.ModuleSpec(
    "flash_attn.flash_attn_interface", loader=None
)
flash_attn_interface.flash_attn_unpadded_qkvpacked_func = _dummy_func
flash_attn_interface.flash_attn_varlen_qkvpacked_func = _dummy_func

flash_attn_bert_padding = types.ModuleType("flash_attn.bert_padding")
flash_attn_bert_padding.__spec__ = importlib.machinery.ModuleSpec(
    "flash_attn.bert_padding", loader=None
)
flash_attn_bert_padding.pad_input = _dummy_func
flash_attn_bert_padding.unpad_input = _dummy_func

sys.modules.update({
    "flash_attn": flash_attn_pkg,
    "flash_attn.flash_attn_interface": flash_attn_interface,
    "flash_attn.bert_padding": flash_attn_bert_padding,
})

# ===============================
# ps3 stub (optional vision tower)
# ===============================
ps3_pkg = _mk_pkg("ps3")

class _PS3Config: pass
class _PS3VisionConfig: pass
class _PS3ImageProcessor: pass
class _PS3VisionModel: pass

# Expose the empty stand-in classes under the public names importers use.
for _public_name, _stub_cls in (
    ("PS3Config", _PS3Config),
    ("PS3VisionConfig", _PS3VisionConfig),
    ("PS3ImageProcessor", _PS3ImageProcessor),
    ("PS3VisionModel", _PS3VisionModel),
):
    setattr(ps3_pkg, _public_name, _stub_cls)

sys.modules["ps3"] = ps3_pkg

# ===============================
# Quantization stub to avoid Triton path
# VILA falls back to "from FloatPointQuantizeTorch import *" if Triton import fails.
# Provide a tiny no-op module so imports succeed.
# ===============================
fpqt = types.ModuleType("FloatPointQuantizeTorch")

def _id(value, *args, **kwargs):
    # Identity passthrough: quantization becomes a no-op.
    return value

# names used by llava.model.qfunction
for _fn_name in ("block_cut", "block_quant", "block_reshape"):
    setattr(fpqt, _fn_name, _id)

sys.modules["FloatPointQuantizeTorch"] = fpqt

# ===============================
# Load VILA
# ===============================
# These imports depend on the flash_attn / ps3 / FloatPointQuantizeTorch
# stubs registered above; without them llava's import chain would fail
# on hosts that lack those optional GPU packages.
from llava.model.builder import load_pretrained_model
from llava.constants import DEFAULT_IMAGE_TOKEN  # NOTE(review): appears unused in this file — confirm before removing

# Hugging Face Hub id of the checkpoint to serve.
MODEL_PATH = "Efficient-Large-Model/VILA1.5-3b"

# The model is loaded at import time so a bad checkpoint fails fast.
# NOTE(review): model_name="" presumably lets the builder infer the name
# from MODEL_PATH — confirm against llava's load_pretrained_model API.
try:
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        MODEL_PATH, model_name="", model_base=None
    )
except Exception as e:
    ERR = f"Failed to load model '{MODEL_PATH}': {e}"
    # On failure, serve a minimal error page instead of a blank crash.
    # launch() blocks until the server is shut down; only then does the
    # `raise` below re-raise the original exception to the process.
    def _boot_error_ui():
        with gr.Blocks(title="VILA 1.5 3B – Error") as demo:
            gr.Markdown("### ❌ Model failed to load")
            gr.Markdown(ERR)
        demo.launch()
    _boot_error_ui()
    raise

# Fallback chat template if missing
# Minimal Jinja template: renders each message as "ROLE: content" on its
# own line and ends with "ASSISTANT:" so generation continues the reply.
if getattr(tokenizer, "chat_template", None) is None:
    tokenizer.chat_template = (
        "{% for message in messages %}{{ message['role'] | upper }}: "
        "{{ message['content'] }}\n{% endfor %}ASSISTANT:"
    )

# ===============================
# Inference
# ===============================
def vila_infer(image, prompt):
    """Run VILA on one image + prompt and return the reply text.

    Args:
        image: numpy array from the gr.Image component, or None.
        prompt: user text; blank/None falls back to a default question.

    Returns:
        The model's stripped answer, or a human-readable error message
        (this function never raises).
    """
    if image is None:
        return "Please upload an image."

    # Substitute the default question when the prompt is empty/whitespace.
    query = prompt if prompt and str(prompt).strip() else "Please describe the image."

    rgb = Image.fromarray(image).convert("RGB")

    try:
        answer = model.generate_content(
            prompt=[
                {
                    "from": "human",
                    "value": [
                        {"type": "image", "value": rgb},
                        {"type": "text", "value": query},
                    ],
                }
            ],
            generation_config=None,
        )
        return str(answer).strip()
    except Exception as exc:
        return f"❌ Inference error: {exc}"

# ===============================
# UI
# ===============================
# Single-page Gradio app: image + prompt in, model reply out.
# Fix: the heading Markdown contained mojibake ("πŸ–ΌοΈ" / "β€”" — UTF-8
# emoji and em-dash bytes decoded through a wrong codepage); restored to
# the intended "🖼️" and "—".
with gr.Blocks(title="VILA 1.5 3B (HF Space)") as demo:
    gr.Markdown("## 🖼️ VILA-1.5-3B — Image Description Demo")
    gr.Markdown("Upload an image and press **Run**.")

    with gr.Row():
        img = gr.Image(type="numpy", label="Image", height=320)
        prompt = gr.Textbox(label="Prompt", value="Please describe the image", lines=2)

    run_btn = gr.Button("Run")
    out = gr.Textbox(label="Output", lines=10)

    # Wire the button to the inference function defined above.
    run_btn.click(vila_infer, [img, prompt], out)

demo.launch()