File size: 3,871 Bytes
279f604
a709033
 
89f55ad
24c4395
89f55ad
 
 
b2016f5
89f55ad
b2016f5
 
 
 
 
 
 
 
24c4395
b2016f5
 
 
 
 
 
89f55ad
 
a709033
b2016f5
 
 
 
 
 
a709033
b2016f5
89f55ad
 
b2016f5
 
 
 
 
a709033
b2016f5
 
 
 
a709033
89f55ad
b2016f5
89f55ad
a709033
 
 
b2016f5
 
a709033
b2016f5
 
 
24c4395
 
 
 
a709033
24c4395
 
 
 
b2016f5
24c4395
 
 
 
 
 
89f55ad
 
 
 
24c4395
a709033
24c4395
 
 
 
 
 
a709033
 
 
 
 
 
 
89f55ad
24c4395
 
 
89f55ad
 
 
 
 
24c4395
 
 
 
 
89f55ad
24c4395
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import sys
import types
import importlib.machinery
from PIL import Image
import gradio as gr

# ===============================
# Make a PACKAGE-like dummy flash_attn
# ===============================
def _mk_pkg(name: str):
    """Build an empty module object that the import machinery treats as a package.

    Two independent package markers are set: a ModuleSpec whose
    ``submodule_search_locations`` is a list, and a ``__path__`` attribute.
    Either alone satisfies most of importlib, but real packages carry both.
    """
    pkg_spec = importlib.machinery.ModuleSpec(name, loader=None, is_package=True)
    pkg_spec.submodule_search_locations = []  # list (even empty) => package
    module = types.ModuleType(name)
    module.__spec__ = pkg_spec
    module.__path__ = []  # classic package marker, checked by older code paths
    return module

# Root dummy package the submodule stubs hang off.
flash_attn_pkg = _mk_pkg("flash_attn")

def _mk_submodule(qualname: str):
    # Plain (non-package) module stub with a minimal, loader-less spec.
    sub = types.ModuleType(qualname)
    sub.__spec__ = importlib.machinery.ModuleSpec(qualname, loader=None)
    return sub

# The two flash_attn submodules that downstream imports reach for.
flash_attn_interface = _mk_submodule("flash_attn.flash_attn_interface")
flash_attn_bert_padding = _mk_submodule("flash_attn.bert_padding")

def _dummy_func(*args, **kwargs):
    """Stand-in for every flash_attn entry point.

    CPU inference must never route through flash attention; if something
    does call this, fail loudly instead of silently computing garbage.
    """
    raise RuntimeError("flash_attn is not available in this environment.")

# Attach the callables that `from flash_attn.* import ...` statements
# expect to find; any actual invocation fails loudly via _dummy_func.
for _attr in ("flash_attn_unpadded_qkvpacked_func", "flash_attn_varlen_qkvpacked_func"):
    setattr(flash_attn_interface, _attr, _dummy_func)
for _attr in ("pad_input", "unpad_input"):
    setattr(flash_attn_bert_padding, _attr, _dummy_func)

# Publish the stubs so the import system resolves them without touching disk.
sys.modules.update({
    "flash_attn": flash_attn_pkg,
    "flash_attn.flash_attn_interface": flash_attn_interface,
    "flash_attn.bert_padding": flash_attn_bert_padding,
})

# ===============================
# Runtime env (CPU-friendly)
# ===============================
# setdefault keeps any values the hosting environment already exported.
for _key, _value in (
    ("FLASH_ATTENTION", "0"),
    ("XFORMERS_DISABLED", "1"),
    ("ACCELERATE_USE_DEVICE_MAP", "0"),
):
    os.environ.setdefault(_key, _value)
# Uncomment to force CPU even if a GPU is present:
# os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")

# ===============================
# VILA imports & load
# ===============================
# NOTE(review): these project imports must run *after* the flash_attn stubs
# are registered in sys.modules above — llava presumably imports flash_attn
# at module import time on some code paths; confirm before reordering.
from llava.model.builder import load_pretrained_model
from llava.constants import DEFAULT_IMAGE_TOKEN

# Hugging Face Hub id of the checkpoint to serve.
MODEL_PATH = "Efficient-Large-Model/VILA1.5-3b"

# Loads tokenizer, model weights, image preprocessor, and context length
# in one call; downloads the checkpoint on first run.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    MODEL_PATH, model_name="", model_base=None
)

# Fallback chat template if missing
# (some tokenizers ship without one; a simple "ROLE: content" template is
# installed so conversation serialization does not crash).
if getattr(tokenizer, "chat_template", None) is None:
    tokenizer.chat_template = (
        "{% for message in messages %}{{ message['role'] | upper }}: "
        "{{ message['content'] }}\n{% endfor %}ASSISTANT:"
    )

# ===============================
# Inference function
# ===============================
def vila_infer(image, prompt):
    """Run one VILA generation pass over an uploaded image.

    Args:
        image: numpy H×W×C array from the Gradio image widget, or None
            when nothing was uploaded.
        prompt: user text; may be None or blank, in which case a default
            description request is substituted.

    Returns:
        The model's response as a string, or a usage hint when no image
        was provided.
    """
    if image is None:
        return "Please upload an image."
    # Gradio can hand over None (cleared textbox) as well as "";
    # the bare `prompt.strip()` crashed with AttributeError on None.
    if not prompt or not prompt.strip():
        prompt = "Please describe the image."

    # Normalize to 3-channel RGB so RGBA / grayscale uploads don't break
    # the vision tower.
    pil = Image.fromarray(image).convert("RGB")

    out = model.generate_content(
        prompt=[{
            "from": "human",
            "value": [
                {"type": "image", "value": pil},
                {"type": "text", "value": prompt}
            ]
        }],
        generation_config=None
    )
    return str(out)

# ===============================
# Gradio UI
# ===============================
with gr.Blocks(title="VILA 1.5 3B (HF Space)") as demo:
    gr.Markdown("## 🖼️ VILA-1.5-3B Image Description Demo\nUpload an image and get a description.")
    # Image and prompt side by side; result below the run button.
    with gr.Row():
        image_input = gr.Image(type="numpy", label="Image", height=320)
        prompt_box = gr.Textbox(label="Prompt", value="Please describe the image", lines=2)
    run_button = gr.Button("Run")
    output_box = gr.Textbox(label="Output", lines=8)
    run_button.click(vila_infer, inputs=[image_input, prompt_box], outputs=output_box)

demo.launch()