Spaces: Runtime error
```python
# app.py
import os
import sys
import types
import importlib.machinery

from PIL import Image
import gradio as gr

# ===============================
# Helper to create package-like dummy modules
# ===============================
def _mk_pkg(name: str):
    m = types.ModuleType(name)
    spec = importlib.machinery.ModuleSpec(name, loader=None, is_package=True)
    spec.submodule_search_locations = []
    m.__spec__ = spec
    m.__path__ = []
    return m
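# Note: is_package=True gives the spec submodule_search_locations, and the
# empty __path__ marks the module as a package, so dotted imports such as
# "import flash_attn.bert_padding" can resolve to the stubs registered in
# sys.modules below (Python consults sys.modules before the filesystem).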
# ===============================
# Disable GPU-only/optional paths
# ===============================
os.environ.setdefault("FLASH_ATTENTION", "0")
os.environ.setdefault("XFORMERS_DISABLED", "1")
os.environ.setdefault("ACCELERATE_USE_DEVICE_MAP", "0")
os.environ.setdefault("DISABLE_TRITON", "1")  # avoid Triton kernels
# os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")  # uncomment to force CPU
# ===============================
# flash_attn stubs (package + submodules)
# ===============================
flash_attn_pkg = _mk_pkg("flash_attn")

flash_attn_interface = types.ModuleType("flash_attn.flash_attn_interface")
flash_attn_interface.__spec__ = importlib.machinery.ModuleSpec(
    "flash_attn.flash_attn_interface", loader=None
)

flash_attn_bert_padding = types.ModuleType("flash_attn.bert_padding")
flash_attn_bert_padding.__spec__ = importlib.machinery.ModuleSpec(
    "flash_attn.bert_padding", loader=None
)

def _dummy_func(*args, **kwargs):
    raise RuntimeError("flash_attn is not available in this environment.")

flash_attn_interface.flash_attn_unpadded_qkvpacked_func = _dummy_func
flash_attn_interface.flash_attn_varlen_qkvpacked_func = _dummy_func
flash_attn_bert_padding.pad_input = _dummy_func
flash_attn_bert_padding.unpad_input = _dummy_func

sys.modules["flash_attn"] = flash_attn_pkg
sys.modules["flash_attn.flash_attn_interface"] = flash_attn_interface
sys.modules["flash_attn.bert_padding"] = flash_attn_bert_padding
# ===============================
# ps3 stub (optional vision tower)
# ===============================
ps3_pkg = _mk_pkg("ps3")

class _PS3Config: pass
class _PS3VisionConfig: pass
class _PS3ImageProcessor: pass
class _PS3VisionModel: pass

ps3_pkg.PS3Config = _PS3Config
ps3_pkg.PS3VisionConfig = _PS3VisionConfig
ps3_pkg.PS3ImageProcessor = _PS3ImageProcessor
ps3_pkg.PS3VisionModel = _PS3VisionModel
sys.modules["ps3"] = ps3_pkg
# ===============================
# Quantization stubs to avoid Triton/Torch custom kernels.
# VILA sometimes imports:
#   - from .FloatPointQuantizeTriton import *
#   - from FloatPointQuantizeTriton import *
#   - from FloatPointQuantizeTorch import *
# Provide both names (absolute and package-qualified) with no-op funcs.
# ===============================
def _mk_fpq_module(mod_name: str):
    mod = types.ModuleType(mod_name)
    # Provide the APIs qfunction expects, as identity pass-throughs
    def _id(x, *a, **k):
        return x
    mod.block_cut = _id
    mod.block_quant = _id
    mod.block_reshape = _id
    return mod

# Absolute names
sys.modules["FloatPointQuantizeTorch"] = _mk_fpq_module("FloatPointQuantizeTorch")
sys.modules["FloatPointQuantizeTriton"] = _mk_fpq_module("FloatPointQuantizeTriton")
# Package-qualified under llava.model
sys.modules["llava.model.FloatPointQuantizeTorch"] = sys.modules["FloatPointQuantizeTorch"]
sys.modules["llava.model.FloatPointQuantizeTriton"] = sys.modules["FloatPointQuantizeTriton"]
# ===============================
# Load VILA
# ===============================
from llava.model.builder import load_pretrained_model
from llava.constants import DEFAULT_IMAGE_TOKEN

MODEL_PATH = "Efficient-Large-Model/VILA1.5-3b"

try:
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        MODEL_PATH, model_name="", model_base=None
    )
except Exception as e:
    ERR = f"Failed to load model '{MODEL_PATH}': {e}"
    def _boot_error_ui():
        with gr.Blocks(title="VILA 1.5 3B – Error") as demo:
            gr.Markdown("### ❌ Model failed to load")
            gr.Markdown(ERR)
        demo.launch()
    _boot_error_ui()
    raise
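# Hedged extras (assumptions, not guaranteed by the builder): force inference
# mode, and fall back to float32 on CPU, where half precision is often slow or
# unsupported. Drop these lines if the builder already handles dtype/device.
import torch
model.eval()
if not torch.cuda.is_available():
    model = model.float()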
# Fallback chat template if the tokenizer ships without one
if getattr(tokenizer, "chat_template", None) is None:
    tokenizer.chat_template = (
        "{% for message in messages %}{{ message['role'] | upper }}: "
        "{{ message['content'] }}\n{% endfor %}ASSISTANT:"
    )
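# Illustration of what the fallback template renders (assuming the generation
# path goes through apply_chat_template):
#   tokenizer.apply_chat_template([{"role": "user", "content": "hi"}], tokenize=False)
#   -> "USER: hi\nASSISTANT:"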
# ===============================
# Inference
# ===============================
def vila_infer(image, prompt):
    if image is None:
        return "Please upload an image."
    if not prompt or not str(prompt).strip():
        prompt = "Please describe the image."
    pil = Image.fromarray(image).convert("RGB")
    try:
        out = model.generate_content(
            prompt=[{
                "from": "human",
                "value": [
                    {"type": "image", "value": pil},
                    {"type": "text", "value": prompt},
                ]
            }],
            generation_config=None,  # default decoding
        )
        return str(out).strip()
    except Exception as e:
        return f"❌ Inference error: {e}"
# ===============================
# UI
# ===============================
with gr.Blocks(title="VILA 1.5 3B (HF Space)") as demo:
    gr.Markdown("## 🖼️ VILA-1.5-3B – Image Description Demo")
    gr.Markdown("Upload an image and press **Run**.")
    with gr.Row():
        img = gr.Image(type="numpy", label="Image", height=320)
        prompt = gr.Textbox(label="Prompt", value="Please describe the image", lines=2)
    run_btn = gr.Button("Run")
    out = gr.Textbox(label="Output", lines=10)
    run_btn.click(vila_infer, [img, prompt], out)

demo.launch()
```
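
If the Space still shows "Runtime error" at startup, a quick smoke test outside Gradio helps separate model problems from UI problems. This is a sketch, not part of the app; the blank `np.zeros` array stands in for a real upload:

```python
# Run ad hoc after the model has loaded (comment out demo.launch() first so the
# script does not block on the Gradio server).
import numpy as np

blank = np.zeros((224, 224, 3), dtype=np.uint8)  # stand-in for an uploaded image
print(vila_infer(blank, "What is in this image?"))
```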