# VisionLLM / app.py — Hugging Face Space file (commit 346da8b, uploaded by DivyanshHF)
# app.py
import os
import sys
import types
import importlib.machinery
from PIL import Image
import gradio as gr
# ===============================
# Helper to create package-like dummy modules
# ===============================
def _mk_pkg(name: str):
m = types.ModuleType(name)
spec = importlib.machinery.ModuleSpec(name, loader=None, is_package=True)
spec.submodule_search_locations = []
m.__spec__ = spec
m.__path__ = []
return m
# ===============================
# Disable GPU-only/optional paths
# ===============================
# Soft defaults only: setdefault() never overrides values already set by
# the host environment.
_ENV_DEFAULTS = {
    "FLASH_ATTENTION": "0",
    "XFORMERS_DISABLED": "1",
    "ACCELERATE_USE_DEVICE_MAP": "0",
    "DISABLE_TRITON": "1",  # avoid triton kernels
}
for _env_key, _env_val in _ENV_DEFAULTS.items():
    os.environ.setdefault(_env_key, _env_val)
# os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")  # uncomment to force CPU
# ===============================
# flash_attn stubs (package + submodules)
# ===============================
# Register fake "flash_attn" modules in sys.modules so that VILA's imports
# succeed on hardware without flash-attention; any *call* into the stubbed
# functions fails loudly instead of silently computing garbage.
def _dummy_func(*args, **kwargs):
    raise RuntimeError("flash_attn is not available in this environment.")

_fa_pkg = _mk_pkg("flash_attn")

_fa_interface = types.ModuleType("flash_attn.flash_attn_interface")
_fa_interface.__spec__ = importlib.machinery.ModuleSpec(
    "flash_attn.flash_attn_interface", loader=None
)
_fa_interface.flash_attn_unpadded_qkvpacked_func = _dummy_func
_fa_interface.flash_attn_varlen_qkvpacked_func = _dummy_func

_fa_bert_padding = types.ModuleType("flash_attn.bert_padding")
_fa_bert_padding.__spec__ = importlib.machinery.ModuleSpec(
    "flash_attn.bert_padding", loader=None
)
_fa_bert_padding.pad_input = _dummy_func
_fa_bert_padding.unpad_input = _dummy_func

sys.modules["flash_attn"] = _fa_pkg
sys.modules["flash_attn.flash_attn_interface"] = _fa_interface
sys.modules["flash_attn.bert_padding"] = _fa_bert_padding
# ===============================
# ps3 stub (optional vision tower)
# ===============================
# Empty placeholder classes are sufficient here: they only need to exist so
# "from ps3 import PS3Config, ..." statements resolve.
class _PS3Config:
    pass


class _PS3VisionConfig:
    pass


class _PS3ImageProcessor:
    pass


class _PS3VisionModel:
    pass


ps3_pkg = _mk_pkg("ps3")
# Expose each placeholder under its public (non-underscored) name.
for _stub_cls in (_PS3Config, _PS3VisionConfig, _PS3ImageProcessor, _PS3VisionModel):
    setattr(ps3_pkg, _stub_cls.__name__.lstrip("_"), _stub_cls)
sys.modules["ps3"] = ps3_pkg
# ===============================
# Quantization stub to avoid Triton path
# VILA falls back to "from FloatPointQuantizeTorch import *" if Triton import fails.
# Provide a tiny no-op module so imports succeed.
# ===============================
fpqt = types.ModuleType("FloatPointQuantizeTorch")
def _id(x, *a, **k): return x # identity
# names used by llava.model.qfunction
fpqt.block_cut = _id
fpqt.block_quant = _id
fpqt.block_reshape = _id
sys.modules["FloatPointQuantizeTorch"] = fpqt
# ===============================
# Load VILA
# ===============================
from llava.model.builder import load_pretrained_model
from llava.constants import DEFAULT_IMAGE_TOKEN

MODEL_PATH = "Efficient-Large-Model/VILA1.5-3b"

try:
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        MODEL_PATH, model_name="", model_base=None
    )
except Exception as e:
    # Surface the failure in a minimal UI so the Space shows *why* it died,
    # then re-raise so the process still exits with an error.
    ERR = f"Failed to load model '{MODEL_PATH}': {e}"

    def _boot_error_ui():
        """Render a bare Gradio page that reports the model-load failure."""
        with gr.Blocks(title="VILA 1.5 3B – Error") as error_demo:
            gr.Markdown("### ❌ Model failed to load")
            gr.Markdown(ERR)
        error_demo.launch()

    _boot_error_ui()
    raise

# Fallback chat template if missing
if getattr(tokenizer, "chat_template", None) is None:
    tokenizer.chat_template = (
        "{% for message in messages %}{{ message['role'] | upper }}: "
        "{{ message['content'] }}\n{% endfor %}ASSISTANT:"
    )
# ===============================
# Inference
# ===============================
def vila_infer(image, prompt):
    """Run VILA on one image + text prompt and return the reply text.

    Args:
        image: numpy H×W×C array from the Gradio image widget, or None.
        prompt: user instruction; blank/empty input falls back to a default.

    Returns:
        The model's response as a stripped string, or a human-readable
        error message if no image was given or inference failed.
    """
    if image is None:
        return "Please upload an image."
    # Substitute the default instruction for empty or whitespace-only prompts.
    if not prompt or not str(prompt).strip():
        prompt = "Please describe the image."
    rgb = Image.fromarray(image).convert("RGB")
    # VILA's generate_content expects a conversation turn mixing image + text.
    message = {
        "from": "human",
        "value": [
            {"type": "image", "value": rgb},
            {"type": "text", "value": prompt},
        ],
    }
    try:
        reply = model.generate_content(prompt=[message], generation_config=None)
        return str(reply).strip()
    except Exception as e:
        return f"❌ Inference error: {e}"
# ===============================
# UI
# ===============================
with gr.Blocks(title="VILA 1.5 3B (HF Space)") as demo:
    # Fix: the heading string was mojibake ("πŸ–ΌοΈ ... β€”" — UTF-8 bytes
    # decoded with the wrong codec); restored to the intended glyphs.
    gr.Markdown("## 🖼️ VILA-1.5-3B — Image Description Demo")
    gr.Markdown("Upload an image and press **Run**.")
    with gr.Row():
        img = gr.Image(type="numpy", label="Image", height=320)
        prompt = gr.Textbox(label="Prompt", value="Please describe the image", lines=2)
    run_btn = gr.Button("Run")
    out = gr.Textbox(label="Output", lines=10)
    # Wire the button: (image, prompt) -> vila_infer -> output textbox.
    run_btn.click(vila_infer, [img, prompt], out)

demo.launch()