# VisionLLM / app.py
import os
import sys
import types
import importlib.machinery
from PIL import Image
import gradio as gr
# ===============================
# Make a PACKAGE-like dummy flash_attn
# ===============================
def _mk_pkg(name: str):
    m = types.ModuleType(name)
    # Mark as a package: give it a spec with submodule locations and a __path__
    spec = importlib.machinery.ModuleSpec(name, loader=None, is_package=True)
    spec.submodule_search_locations = []  # important: tells importlib it's a package
    m.__spec__ = spec
    m.__path__ = []  # also marks as package
    return m
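# With submodule_search_locations and __path__ set, the import system treats
# the stub as a package, so dotted imports such as
# `import flash_attn.flash_attn_interface` resolve through sys.modules instead
# of failing on a missing file on disk.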
# Root package
flash_attn_pkg = _mk_pkg("flash_attn")
# Submodule: flash_attn.flash_attn_interface
flash_attn_interface = types.ModuleType("flash_attn.flash_attn_interface")
flash_attn_interface.__spec__ = importlib.machinery.ModuleSpec(
    "flash_attn.flash_attn_interface", loader=None
)
# Submodule: flash_attn.bert_padding
flash_attn_bert_padding = types.ModuleType("flash_attn.bert_padding")
flash_attn_bert_padding.__spec__ = importlib.machinery.ModuleSpec(
    "flash_attn.bert_padding", loader=None
)
def _dummy_func(*args, **kwargs):
    # Should never be called on CPU; if it is, fail loudly
    raise RuntimeError("flash_attn is not available in this environment.")
# Functions some imports expect to exist:
flash_attn_interface.flash_attn_unpadded_qkvpacked_func = _dummy_func
flash_attn_interface.flash_attn_varlen_qkvpacked_func = _dummy_func
flash_attn_bert_padding.pad_input = _dummy_func
flash_attn_bert_padding.unpad_input = _dummy_func
# Register modules
sys.modules["flash_attn"] = flash_attn_pkg
sys.modules["flash_attn.flash_attn_interface"] = flash_attn_interface
sys.modules["flash_attn.bert_padding"] = flash_attn_bert_padding
# ===============================
# Runtime env (CPU-friendly)
# ===============================
os.environ.setdefault("FLASH_ATTENTION", "0")
os.environ.setdefault("XFORMERS_DISABLED", "1")
os.environ.setdefault("ACCELERATE_USE_DEVICE_MAP", "0")
# Uncomment to force CPU even if a GPU is present:
# os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
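# Optional CPU tuning (a hedged sketch, not required): capping torch's thread
# count can avoid oversubscription on small Space instances.
# import torch
# torch.set_num_threads(max(1, (os.cpu_count() or 2) // 2))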
# ===============================
# VILA imports & load
# ===============================
from llava.model.builder import load_pretrained_model
from llava.constants import DEFAULT_IMAGE_TOKEN
MODEL_PATH = "Efficient-Large-Model/VILA1.5-3b"
tokenizer, model, image_processor, context_len = load_pretrained_model(
    MODEL_PATH, model_name="", model_base=None
)
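# Inference-only usage. Assuming the builder returns a torch.nn.Module (hedged:
# many builders already do this themselves), disable dropout explicitly.
model.eval()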
# Fallback chat template if missing
if getattr(tokenizer, "chat_template", None) is None:
    tokenizer.chat_template = (
        "{% for message in messages %}{{ message['role'] | upper }}: "
        "{{ message['content'] }}\n{% endfor %}ASSISTANT:"
    )
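# Hedged preview of the fallback template: for [{"role": "user", "content": "Hi"}]
# it should render "USER: Hi\nASSISTANT:". apply_chat_template is a standard
# transformers tokenizer method; guard in case this tokenizer class lacks it.
if hasattr(tokenizer, "apply_chat_template"):
    print(tokenizer.apply_chat_template(
        [{"role": "user", "content": "Hi"}], tokenize=False
    ))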
# ===============================
# Inference function
# ===============================
def vila_infer(image, prompt):
    if image is None:
        return "Please upload an image."
    if not prompt or not prompt.strip():  # guard against None as well as ""
        prompt = "Please describe the image."
    pil = Image.fromarray(image).convert("RGB")
    out = model.generate_content(
        prompt=[{
            "from": "human",
            "value": [
                {"type": "image", "value": pil},
                {"type": "text", "value": prompt},
            ],
        }],
        generation_config=None,
    )
    return str(out)
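# Quick smoke test without the UI (hedged example; numpy ships with gradio's
# dependencies):
# import numpy as np
# print(vila_infer(np.zeros((224, 224, 3), dtype=np.uint8), "Describe the image."))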
# ===============================
# Gradio UI
# ===============================
with gr.Blocks(title="VILA 1.5 3B (HF Space)") as demo:
    gr.Markdown("## 🖼️ VILA-1.5-3B Image Description Demo\nUpload an image and get a description.")
    with gr.Row():
        img = gr.Image(type="numpy", label="Image", height=320)
        prompt = gr.Textbox(label="Prompt", value="Please describe the image", lines=2)
    btn = gr.Button("Run")
    out = gr.Textbox(label="Output", lines=8)
    btn.click(vila_infer, [img, prompt], out)
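# Optional (hedged): explicit queueing keeps concurrent users from invoking the
# CPU-bound model simultaneously; recent Gradio versions enable a queue by default.
demo.queue()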
demo.launch()