Spaces:
Runtime error
Runtime error
File size: 5,645 Bytes
c6110b4 279f604 a709033 89f55ad 24c4395 89f55ad 346da8b 89f55ad b2016f5 c6110b4 b2016f5 c6110b4 b2016f5 24c4395 346da8b b2016f5 89f55ad a709033 b2016f5 a709033 89f55ad b2016f5 a709033 b2016f5 a709033 346da8b c6110b4 89f55ad f3b369a a709033 b2016f5 346da8b b2016f5 24c4395 a709033 c6110b4 24c4395 b2016f5 24c4395 89f55ad c6110b4 89f55ad f3b369a 89f55ad 24c4395 a709033 c6110b4 24c4395 f3b369a 24c4395 c6110b4 f3b369a c6110b4 24c4395 89f55ad 346da8b 89f55ad c6110b4 346da8b c6110b4 24c4395 c6110b4 24c4395 346da8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
# app.py
import os
import sys
import types
import importlib.machinery
from PIL import Image
import gradio as gr
# ===============================
# Helper to create package-like dummy modules
# ===============================
def _mk_pkg(name: str):
m = types.ModuleType(name)
spec = importlib.machinery.ModuleSpec(name, loader=None, is_package=True)
spec.submodule_search_locations = []
m.__spec__ = spec
m.__path__ = []
return m
# ===============================
# Disable GPU-only/optional paths
# ===============================
# Defaults only: setdefault leaves any value already present in the
# environment untouched.
_OPTIONAL_PATH_DEFAULTS = {
    "FLASH_ATTENTION": "0",
    "XFORMERS_DISABLED": "1",
    "ACCELERATE_USE_DEVICE_MAP": "0",
    "DISABLE_TRITON": "1",  # avoid triton kernels
}
for _flag, _default in _OPTIONAL_PATH_DEFAULTS.items():
    os.environ.setdefault(_flag, _default)
# os.environ.setdefault("CUDA_VISIBLE_DEVICES", "") # uncomment to force CPU
# ===============================
# flash_attn stubs (package + submodules)
# ===============================
# Stand-in package plus the two submodules that get stubbed here.
flash_attn_pkg = _mk_pkg("flash_attn")

flash_attn_interface = types.ModuleType("flash_attn.flash_attn_interface")
flash_attn_interface.__spec__ = importlib.machinery.ModuleSpec(
    "flash_attn.flash_attn_interface", loader=None
)

flash_attn_bert_padding = types.ModuleType("flash_attn.bert_padding")
flash_attn_bert_padding.__spec__ = importlib.machinery.ModuleSpec(
    "flash_attn.bert_padding", loader=None
)


def _dummy_func(*args, **kwargs):
    """Placeholder for every stubbed flash_attn callable.

    Accepts any arguments and always raises, so importing the stubs succeeds
    while actually calling into them fails loudly.

    Raises:
        RuntimeError: Always.
    """
    raise RuntimeError("flash_attn is not available in this environment.")


flash_attn_interface.flash_attn_unpadded_qkvpacked_func = _dummy_func
flash_attn_interface.flash_attn_varlen_qkvpacked_func = _dummy_func
flash_attn_bert_padding.pad_input = _dummy_func
flash_attn_bert_padding.unpad_input = _dummy_func

# A real import binds each submodule as an attribute of its parent package.
# Modules placed straight into sys.modules skip that step, so do it by hand;
# otherwise attribute access such as `flash_attn.bert_padding.pad_input`
# raises AttributeError.
flash_attn_pkg.flash_attn_interface = flash_attn_interface
flash_attn_pkg.bert_padding = flash_attn_bert_padding

sys.modules["flash_attn"] = flash_attn_pkg
sys.modules["flash_attn.flash_attn_interface"] = flash_attn_interface
sys.modules["flash_attn.bert_padding"] = flash_attn_bert_padding
# ===============================
# ps3 stub (optional vision tower)
# ===============================
ps3_pkg = _mk_pkg("ps3")


class _PS3Config:
    """Empty placeholder exposed as ``ps3.PS3Config``."""


class _PS3VisionConfig:
    """Empty placeholder exposed as ``ps3.PS3VisionConfig``."""


class _PS3ImageProcessor:
    """Empty placeholder exposed as ``ps3.PS3ImageProcessor``."""


class _PS3VisionModel:
    """Empty placeholder exposed as ``ps3.PS3VisionModel``."""


for _placeholder in (
    _PS3Config,
    _PS3VisionConfig,
    _PS3ImageProcessor,
    _PS3VisionModel,
):
    # The public attribute is the class name minus its leading underscore.
    setattr(ps3_pkg, _placeholder.__name__[1:], _placeholder)

sys.modules["ps3"] = ps3_pkg
# ===============================
# Quantization stubs to avoid Triton/Torch custom kernels
# VILA sometimes imports:
# - from .FloatPointQuantizeTriton import *
# - from FloatPointQuantizeTriton import *
# - from FloatPointQuantizeTorch import *
# Provide both names (absolute and package-qualified) with no-op funcs.
# ===============================
def _mk_fpq_module(mod_name: str):
mod = types.ModuleType(mod_name)
# Provide the APIs qfunction expects
def _id(x, *a, **k): return x
mod.block_cut = _id
mod.block_quant = _id
mod.block_reshape = _id
return mod
# Absolute names
# Register one stub per quantization module under both its absolute name and
# its package-qualified name below llava.model; both keys share one object.
for _fpq_name in ("FloatPointQuantizeTorch", "FloatPointQuantizeTriton"):
    _fpq_stub = _mk_fpq_module(_fpq_name)
    sys.modules[_fpq_name] = _fpq_stub
    sys.modules[f"llava.model.{_fpq_name}"] = _fpq_stub
# ===============================
# Load VILA
# ===============================
from llava.model.builder import load_pretrained_model
from llava.constants import DEFAULT_IMAGE_TOKEN
MODEL_PATH = "Efficient-Large-Model/VILA1.5-3b"

try:
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        MODEL_PATH,
        model_name="",
        model_base=None,
    )
except Exception as exc:
    # Stored at module level so the error page below can display it.
    ERR = f"Failed to load model '{MODEL_PATH}': {exc}"

    def _boot_error_ui():
        """Launch a minimal Gradio page that shows the load failure."""
        with gr.Blocks(title="VILA 1.5 3B β Error") as demo:
            gr.Markdown("### β Model failed to load")
            gr.Markdown(ERR)
        # NOTE(review): source indentation was lost; launch() is assumed to
        # sit after the Blocks context, as is conventional - confirm.
        demo.launch()

    _boot_error_ui()
    # Re-raise once the error page returns so the failure is not swallowed.
    raise
# Install a minimal role-prefixed template when the tokenizer has none.
if getattr(tokenizer, "chat_template", None) is None:
    tokenizer.chat_template = "".join(
        (
            "{% for message in messages %}",
            "{{ message['role'] | upper }}: {{ message['content'] }}\n",
            "{% endfor %}ASSISTANT:",
        )
    )
# ===============================
# Inference
# ===============================
from PIL import Image as _PILImage
def vila_infer(image, prompt):
    """Ask the loaded model to respond to ``prompt`` about ``image``.

    Args:
        image: Value accepted by ``PIL.Image.fromarray`` (the UI passes a
            NumPy array), or ``None`` when nothing was uploaded.
        prompt: Instruction text. An empty or whitespace-only value is
            replaced by a default description request.

    Returns:
        The model output converted to ``str`` and stripped, or a
        human-readable message when no image was supplied or when image
        conversion or generation fails. Failures are reported as text rather
        than raised.
    """
    if image is None:
        return "Please upload an image."
    if not prompt or not str(prompt).strip():
        prompt = "Please describe the image."
    try:
        # Conversion happens inside the try so that an input PIL cannot
        # convert is reported the same way as a generation failure instead
        # of escaping as an unhandled exception.
        pil = _PILImage.fromarray(image).convert("RGB")
        out = model.generate_content(
            prompt=[{
                "from": "human",
                "value": [
                    {"type": "image", "value": pil},
                    {"type": "text", "value": prompt},
                ],
            }],
            generation_config=None,  # default decoding
        )
        return str(out).strip()
    except Exception as e:
        return f"β Inference error: {e}"
# ===============================
# UI
# ===============================
# Page layout: heading, image + prompt inputs, a Run button and the output box.
# NOTE(review): source indentation was lost; which components sit inside the
# Row, and launch() placement after the Blocks context, are reconstructed -
# confirm against the deployed app.
with gr.Blocks(title="VILA 1.5 3B (HF Space)") as demo:
    gr.Markdown("## πΌοΈ VILA-1.5-3B β Image Description Demo")
    gr.Markdown("Upload an image and press **Run**.")

    with gr.Row():
        img = gr.Image(type="numpy", label="Image", height=320)
        prompt = gr.Textbox(
            label="Prompt",
            value="Please describe the image",
            lines=2,
        )

    run_btn = gr.Button("Run")
    out = gr.Textbox(label="Output", lines=10)

    # Clicking Run sends the image and prompt to vila_infer and shows its text.
    run_btn.click(fn=vila_infer, inputs=[img, prompt], outputs=out)

demo.launch()
|