File size: 5,645 Bytes
c6110b4
279f604
a709033
 
89f55ad
24c4395
89f55ad
 
 
346da8b
89f55ad
b2016f5
 
 
c6110b4
b2016f5
c6110b4
b2016f5
24c4395
346da8b
 
 
 
 
 
 
 
 
 
 
 
b2016f5
 
 
 
89f55ad
 
a709033
b2016f5
 
 
 
 
a709033
89f55ad
 
b2016f5
 
 
 
a709033
b2016f5
 
 
a709033
346da8b
 
 
c6110b4
 
 
 
 
 
 
 
 
 
 
89f55ad
f3b369a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a709033
b2016f5
346da8b
b2016f5
24c4395
 
 
 
a709033
c6110b4
 
 
 
 
 
 
 
 
 
 
 
 
24c4395
b2016f5
24c4395
 
 
 
 
 
89f55ad
c6110b4
89f55ad
f3b369a
 
89f55ad
24c4395
a709033
c6110b4
24c4395
 
f3b369a
24c4395
c6110b4
 
 
 
 
 
 
 
 
f3b369a
c6110b4
 
 
 
24c4395
89f55ad
346da8b
89f55ad
 
c6110b4
346da8b
c6110b4
24c4395
 
 
c6110b4
 
 
 
 
24c4395
346da8b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# app.py
import os
import sys
import types
import importlib.machinery
from PIL import Image
import gradio as gr

# ===============================
# Helper to create package-like dummy modules
# ===============================
def _mk_pkg(name: str):
    m = types.ModuleType(name)
    spec = importlib.machinery.ModuleSpec(name, loader=None, is_package=True)
    spec.submodule_search_locations = []
    m.__spec__ = spec
    m.__path__ = []
    return m

# ===============================
# Disable GPU-only/optional paths
# ===============================
# Defaults only — setdefault never overrides values already in the environment.
_GPU_OPT_OUT_DEFAULTS = {
    "FLASH_ATTENTION": "0",
    "XFORMERS_DISABLED": "1",
    "ACCELERATE_USE_DEVICE_MAP": "0",
    "DISABLE_TRITON": "1",  # avoid triton kernels
}
for _env_key, _env_val in _GPU_OPT_OUT_DEFAULTS.items():
    os.environ.setdefault(_env_key, _env_val)
# os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")  # uncomment to force CPU

# ===============================
# flash_attn stubs (package + submodules)
# ===============================
def _dummy_func(*args, **kwargs):
    """Stand-in for flash_attn kernels; always raises since the real lib is absent."""
    raise RuntimeError("flash_attn is not available in this environment.")

def _mk_flat_module(qualname: str):
    # Plain (non-package) module with a loader-less spec.
    mod = types.ModuleType(qualname)
    mod.__spec__ = importlib.machinery.ModuleSpec(qualname, loader=None)
    return mod

flash_attn_pkg = _mk_pkg("flash_attn")

flash_attn_interface = _mk_flat_module("flash_attn.flash_attn_interface")
flash_attn_interface.flash_attn_unpadded_qkvpacked_func = _dummy_func
flash_attn_interface.flash_attn_varlen_qkvpacked_func = _dummy_func

flash_attn_bert_padding = _mk_flat_module("flash_attn.bert_padding")
flash_attn_bert_padding.pad_input = _dummy_func
flash_attn_bert_padding.unpad_input = _dummy_func

# Register package and submodules so `import flash_attn...` resolves to stubs.
sys.modules["flash_attn"] = flash_attn_pkg
sys.modules["flash_attn.flash_attn_interface"] = flash_attn_interface
sys.modules["flash_attn.bert_padding"] = flash_attn_bert_padding

# ===============================
# ps3 stub (optional vision tower)
# ===============================
ps3_pkg = _mk_pkg("ps3")
# Empty placeholder classes exposed under the public PS3* names; the stub
# classes keep the original leading-underscore __name__s.
for _ps3_attr in ("PS3Config", "PS3VisionConfig", "PS3ImageProcessor", "PS3VisionModel"):
    setattr(ps3_pkg, _ps3_attr, type(f"_{_ps3_attr}", (), {}))
sys.modules["ps3"] = ps3_pkg

# ===============================
# Quantization stubs to avoid Triton/Torch custom kernels
# VILA sometimes imports:
#  - from .FloatPointQuantizeTriton import *
#  - from FloatPointQuantizeTriton import *
#  - from FloatPointQuantizeTorch import *
# Provide both names (absolute and package-qualified) with no-op funcs.
# ===============================
def _mk_fpq_module(mod_name: str):
    """Build a stub quantization module whose block_* APIs pass inputs through."""
    def _passthrough(x, *extra, **kw):
        return x
    mod = types.ModuleType(mod_name)
    for api_name in ("block_cut", "block_quant", "block_reshape"):
        setattr(mod, api_name, _passthrough)
    return mod

# Register each stub under both its absolute name and its llava.model-qualified
# name, sharing one module object per backend.
for _fpq_name in ("FloatPointQuantizeTorch", "FloatPointQuantizeTriton"):
    _fpq_stub = _mk_fpq_module(_fpq_name)
    sys.modules[_fpq_name] = _fpq_stub
    sys.modules[f"llava.model.{_fpq_name}"] = _fpq_stub

# ===============================
# Load VILA
# ===============================
from llava.model.builder import load_pretrained_model
from llava.constants import DEFAULT_IMAGE_TOKEN

# Hugging Face model id for the VILA 1.5 3B checkpoint.
MODEL_PATH = "Efficient-Large-Model/VILA1.5-3b"

try:
    # Returns (tokenizer, model, image_processor, context_len).
    # model_name="" — presumably the builder derives the name from the
    # path; TODO confirm against llava.model.builder.
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        MODEL_PATH, model_name="", model_base=None
    )
except Exception as e:
    # On load failure, show the error in a minimal Gradio page, then
    # re-raise so the process still fails with the original traceback.
    ERR = f"Failed to load model '{MODEL_PATH}': {e}"
    def _boot_error_ui():
        # NOTE(review): demo.launch() typically blocks in script mode, which
        # would delay the `raise` below until the server stops — confirm.
        with gr.Blocks(title="VILA 1.5 3B – Error") as demo:
            gr.Markdown("### ❌ Model failed to load")
            gr.Markdown(ERR)
        demo.launch()
    _boot_error_ui()
    raise

# Fallback chat template if missing
if getattr(tokenizer, "chat_template", None) is None:
    # Minimal "ROLE: content" Jinja template ending with an "ASSISTANT:" cue;
    # applied only when the tokenizer ships without its own chat template.
    tokenizer.chat_template = (
        "{% for message in messages %}{{ message['role'] | upper }}: "
        "{{ message['content'] }}\n{% endfor %}ASSISTANT:"
    )

# ===============================
# Inference
# ===============================
from PIL import Image as _PILImage

def vila_infer(image, prompt):
    """Run VILA on one image + text prompt and return the model's reply.

    image: numpy array from the Gradio image widget, or None.
    prompt: user text; a default description request is substituted when blank.
    """
    # Guard clauses for missing inputs.
    if image is None:
        return "Please upload an image."
    if not (prompt and str(prompt).strip()):
        prompt = "Please describe the image."

    rgb_image = _PILImage.fromarray(image).convert("RGB")

    # Single human turn carrying the image followed by the text prompt.
    conversation = [{
        "from": "human",
        "value": [
            {"type": "image", "value": rgb_image},
            {"type": "text", "value": prompt},
        ],
    }]

    try:
        reply = model.generate_content(
            prompt=conversation,
            generation_config=None  # default decoding
        )
        return str(reply).strip()
    except Exception as e:
        return f"❌ Inference error: {e}"

# ===============================
# UI
# ===============================
# Minimal Gradio front-end: one image input, one prompt box, one output box.
with gr.Blocks(title="VILA 1.5 3B (HF Space)") as demo:
    # Fixed mojibake in the headings: the original bytes ("πŸ–ΌοΈ", "β€”")
    # are the UTF-8 encodings of 🖼️ and — mis-decoded as Latin-1.
    gr.Markdown("## 🖼️ VILA-1.5-3B — Image Description Demo")
    gr.Markdown("Upload an image and press **Run**.")

    with gr.Row():
        img = gr.Image(type="numpy", label="Image", height=320)
        prompt = gr.Textbox(label="Prompt", value="Please describe the image", lines=2)

    run_btn = gr.Button("Run")
    out = gr.Textbox(label="Output", lines=10)

    # Wire the button to the inference function defined above.
    run_btn.click(vila_infer, [img, prompt], out)

# launch() starts the server (blocking in script mode).
demo.launch()