# app.py
"""Gradio demo for VILA-1.5-3B.

Runs in environments without GPU-only optional dependencies by registering
stub modules (flash_attn, ps3, quantization kernels) in sys.modules BEFORE
importing the VILA/llava code that would otherwise try to import them.
"""
import os
import sys
import types
import importlib.machinery

from PIL import Image
import gradio as gr


# ===============================
# Helper to create package-like dummy modules
# ===============================
def _mk_pkg(name: str):
    """Return an empty module object that importlib treats as a package.

    Giving the module a package-style __spec__ (is_package=True) and a
    __path__ allows `import name.submodule` statements to resolve against
    the stub submodules registered in sys.modules below.
    """
    m = types.ModuleType(name)
    spec = importlib.machinery.ModuleSpec(name, loader=None, is_package=True)
    spec.submodule_search_locations = []
    m.__spec__ = spec
    m.__path__ = []
    return m


# ===============================
# Disable GPU-only/optional paths
# ===============================
os.environ.setdefault("FLASH_ATTENTION", "0")
os.environ.setdefault("XFORMERS_DISABLED", "1")
os.environ.setdefault("ACCELERATE_USE_DEVICE_MAP", "0")
os.environ.setdefault("DISABLE_TRITON", "1")  # avoid triton kernels
# os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")  # uncomment to force CPU

# ===============================
# flash_attn stubs (package + submodules)
# ===============================
flash_attn_pkg = _mk_pkg("flash_attn")

flash_attn_interface = types.ModuleType("flash_attn.flash_attn_interface")
flash_attn_interface.__spec__ = importlib.machinery.ModuleSpec(
    "flash_attn.flash_attn_interface", loader=None
)

flash_attn_bert_padding = types.ModuleType("flash_attn.bert_padding")
flash_attn_bert_padding.__spec__ = importlib.machinery.ModuleSpec(
    "flash_attn.bert_padding", loader=None
)


def _dummy_func(*args, **kwargs):
    """Placeholder for flash_attn APIs: fail loudly if actually invoked."""
    raise RuntimeError("flash_attn is not available in this environment.")


flash_attn_interface.flash_attn_unpadded_qkvpacked_func = _dummy_func
flash_attn_interface.flash_attn_varlen_qkvpacked_func = _dummy_func
flash_attn_bert_padding.pad_input = _dummy_func
flash_attn_bert_padding.unpad_input = _dummy_func

sys.modules["flash_attn"] = flash_attn_pkg
sys.modules["flash_attn.flash_attn_interface"] = flash_attn_interface
sys.modules["flash_attn.bert_padding"] = flash_attn_bert_padding

# ===============================
# ps3 stub (optional vision tower)
# ===============================
ps3_pkg = _mk_pkg("ps3")


class _PS3Config:
    pass


class _PS3VisionConfig:
    pass


class _PS3ImageProcessor:
    pass


class _PS3VisionModel:
    pass


ps3_pkg.PS3Config = _PS3Config
ps3_pkg.PS3VisionConfig = _PS3VisionConfig
ps3_pkg.PS3ImageProcessor = _PS3ImageProcessor
ps3_pkg.PS3VisionModel = _PS3VisionModel
sys.modules["ps3"] = ps3_pkg


# ===============================
# Quantization stubs to avoid Triton/Torch custom kernels
# VILA sometimes imports:
#   - from .FloatPointQuantizeTriton import *
#   - from FloatPointQuantizeTriton import *
#   - from FloatPointQuantizeTorch import *
# Provide both names (absolute and package-qualified) with no-op funcs.
# ===============================
def _mk_fpq_module(mod_name: str):
    """Build a stub quantization module whose APIs are identity functions."""
    mod = types.ModuleType(mod_name)

    # Provide the APIs qfunction expects; identity = "no quantization".
    def _id(x, *a, **k):
        return x

    mod.block_cut = _id
    mod.block_quant = _id
    mod.block_reshape = _id
    return mod


# Absolute names
sys.modules["FloatPointQuantizeTorch"] = _mk_fpq_module("FloatPointQuantizeTorch")
sys.modules["FloatPointQuantizeTriton"] = _mk_fpq_module("FloatPointQuantizeTriton")
# Package-qualified under llava.model (same module objects, second name)
sys.modules["llava.model.FloatPointQuantizeTorch"] = sys.modules["FloatPointQuantizeTorch"]
sys.modules["llava.model.FloatPointQuantizeTriton"] = sys.modules["FloatPointQuantizeTriton"]

# ===============================
# Load VILA
# ===============================
# NOTE: these imports must come AFTER the sys.modules stubs above.
from llava.model.builder import load_pretrained_model
from llava.constants import DEFAULT_IMAGE_TOKEN

MODEL_PATH = "Efficient-Large-Model/VILA1.5-3b"

try:
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        MODEL_PATH, model_name="", model_base=None
    )
except Exception as e:
    ERR = f"Failed to load model '{MODEL_PATH}': {e}"

    def _boot_error_ui():
        """Show a minimal error page instead of the demo when loading fails."""
        with gr.Blocks(title="VILA 1.5 3B – Error") as demo:
            gr.Markdown("### ❌ Model failed to load")
            gr.Markdown(ERR)
        demo.launch()

    _boot_error_ui()
    raise  # re-raise after the error UI exits so logs keep the traceback

# Fallback chat template if missing
if getattr(tokenizer, "chat_template", None) is None:
    tokenizer.chat_template = (
        "{% for message in messages %}{{ message['role'] | upper }}: "
        "{{ message['content'] }}\n{% endfor %}ASSISTANT:"
    )


# ===============================
# Inference
# ===============================
def vila_infer(image, prompt):
    """Run VILA on one image + prompt and return the generated text.

    Args:
        image: numpy array from gr.Image(type="numpy"), or None if the user
            has not uploaded anything.
        prompt: user-supplied text; blank/whitespace falls back to a default
            description request.

    Returns:
        The model's response as a stripped string, or a human-readable
        error message (never raises — errors are surfaced in the UI).
    """
    if image is None:
        return "Please upload an image."
    if not prompt or not str(prompt).strip():
        prompt = "Please describe the image."

    # gr.Image(type="numpy") hands us an HxWxC array; VILA wants RGB PIL.
    pil = Image.fromarray(image).convert("RGB")

    try:
        out = model.generate_content(
            prompt=[{
                "from": "human",
                "value": [
                    {"type": "image", "value": pil},
                    {"type": "text", "value": prompt}
                ]
            }],
            generation_config=None  # default decoding
        )
        return str(out).strip()
    except Exception as e:
        return f"❌ Inference error: {e}"


# ===============================
# UI
# ===============================
with gr.Blocks(title="VILA 1.5 3B (HF Space)") as demo:
    gr.Markdown("## 🖼️ VILA-1.5-3B — Image Description Demo")
    gr.Markdown("Upload an image and press **Run**.")
    with gr.Row():
        img = gr.Image(type="numpy", label="Image", height=320)
        prompt = gr.Textbox(label="Prompt", value="Please describe the image", lines=2)
    run_btn = gr.Button("Run")
    out = gr.Textbox(label="Output", lines=10)
    run_btn.click(vila_infer, [img, prompt], out)

demo.launch()