Spaces:
Runtime error
Runtime error
File size: 3,871 Bytes
279f604 a709033 89f55ad 24c4395 89f55ad b2016f5 89f55ad b2016f5 24c4395 b2016f5 89f55ad a709033 b2016f5 a709033 b2016f5 89f55ad b2016f5 a709033 b2016f5 a709033 89f55ad b2016f5 89f55ad a709033 b2016f5 a709033 b2016f5 24c4395 a709033 24c4395 b2016f5 24c4395 89f55ad 24c4395 a709033 24c4395 a709033 89f55ad 24c4395 89f55ad 24c4395 89f55ad 24c4395 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
import os
import sys
import types
import importlib.machinery
from PIL import Image
import gradio as gr
# ===============================
# Make a PACKAGE-like dummy flash_attn
# ===============================
def _mk_pkg(name: str) -> types.ModuleType:
    """Build an empty in-memory module that is marked as a package.

    Args:
        name: Fully qualified name given to both the module and its spec.

    Returns:
        A new module whose ``__spec__`` is a loader-less package spec and
        whose ``__path__`` is an empty list. The module is not added to
        ``sys.modules`` here; registration is left to the caller.
    """
    pkg = types.ModuleType(name)
    # A spec with (possibly empty) submodule search locations is what tells
    # importlib that this name is a package rather than a plain module.
    spec = importlib.machinery.ModuleSpec(name, loader=None, is_package=True)
    spec.submodule_search_locations = []
    pkg.__spec__ = spec
    # __path__ is the second package marker.
    pkg.__path__ = []
    return pkg
# Root package: the only one of the three that must look like a package.
flash_attn_pkg = _mk_pkg("flash_attn")

# Submodule: flash_attn.flash_attn_interface (plain module, loader-less spec).
flash_attn_interface = types.ModuleType("flash_attn.flash_attn_interface")
flash_attn_interface.__spec__ = importlib.machinery.ModuleSpec(
    name=flash_attn_interface.__name__,
    loader=None,
)

# Submodule: flash_attn.bert_padding (plain module, loader-less spec).
flash_attn_bert_padding = types.ModuleType("flash_attn.bert_padding")
flash_attn_bert_padding.__spec__ = importlib.machinery.ModuleSpec(
    name=flash_attn_bert_padding.__name__,
    loader=None,
)
def _dummy_func(*args, **kwargs):
    """Placeholder for a flash_attn function; accepts anything and always raises.

    Raises:
        RuntimeError: Unconditionally, whatever arguments are passed.
    """
    # Should never be called on CPU; if it is, fail loudly instead of
    # returning a bogus result.
    raise RuntimeError("flash_attn is not available in this environment.")
# Functions some imports expect to exist; each one raises if actually called.
flash_attn_interface.flash_attn_unpadded_qkvpacked_func = _dummy_func
flash_attn_interface.flash_attn_varlen_qkvpacked_func = _dummy_func
flash_attn_bert_padding.pad_input = _dummy_func
flash_attn_bert_padding.unpad_input = _dummy_func

# Bind each submodule on the root package, as the import system does when it
# really loads a submodule. Without this, attribute-style access such as
# `flash_attn.bert_padding` fails even though the sys.modules entry exists.
flash_attn_pkg.flash_attn_interface = flash_attn_interface
flash_attn_pkg.bert_padding = flash_attn_bert_padding

# Register modules so later `import flash_attn...` statements find the stubs.
sys.modules["flash_attn"] = flash_attn_pkg
sys.modules["flash_attn.flash_attn_interface"] = flash_attn_interface
sys.modules["flash_attn.bert_padding"] = flash_attn_bert_padding
# ===============================
# Runtime env (CPU-friendly)
# ===============================
# Apply each default only when the variable is not already set, so values
# supplied by the hosting environment take precedence.
os.environ.update(
    {
        name: value
        for name, value in (
            ("FLASH_ATTENTION", "0"),
            ("XFORMERS_DISABLED", "1"),
            ("ACCELERATE_USE_DEVICE_MAP", "0"),
        )
        if name not in os.environ
    }
)
# Uncomment to force CPU even if a GPU is present:
# os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
# ===============================
# VILA imports & load
# ===============================
# These imports sit below the flash_attn stub registration and the environment
# defaults above; presumably llava pulls in flash_attn while importing, so keep
# them after that setup rather than moving them to the top of the file.
from llava.model.builder import load_pretrained_model
from llava.constants import DEFAULT_IMAGE_TOKEN
# Model identifier handed to the loader.
MODEL_PATH = "Efficient-Large-Model/VILA1.5-3b"
# The model is loaded once, at import time; the four results are module globals.
# NOTE(review): model_name is passed as an empty string — confirm the loader
# does not rely on it to choose the model class.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    MODEL_PATH, model_name="", model_base=None
)
# Fallback chat template if missing (attribute absent or set to None).
if getattr(tokenizer, "chat_template", None) is None:
    # Jinja-style template: one "ROLE: content" line per message, followed by
    # a trailing "ASSISTANT:" cue.
    tokenizer.chat_template = (
        "{% for message in messages %}{{ message['role'] | upper }}: "
        "{{ message['content'] }}\n{% endfor %}ASSISTANT:"
    )
# ===============================
# Inference function
# ===============================
def vila_infer(image, prompt):
    """Run the loaded model on one image and return its answer as text.

    Args:
        image: Image data accepted by ``PIL.Image.fromarray``, or ``None``
            when nothing was uploaded.
        prompt: Instruction text. ``None``, an empty string or a
            whitespace-only string is replaced by a default instruction.

    Returns:
        A fixed hint string when ``image`` is ``None``; otherwise
        ``str(...)`` of whatever ``model.generate_content`` returns.
    """
    if image is None:
        return "Please upload an image."
    # The prompt can arrive as None rather than "", so check truthiness before
    # calling .strip() on it.
    if not prompt or not prompt.strip():
        prompt = "Please describe the image."
    pil = Image.fromarray(image).convert("RGB")
    # NOTE(review): the payload is a single "human" turn carrying the image and
    # the text; confirm this is the structure generate_content expects.
    out = model.generate_content(
        prompt=[
            {
                "from": "human",
                "value": [
                    {"type": "image", "value": pil},
                    {"type": "text", "value": prompt},
                ],
            }
        ],
        generation_config=None,
    )
    return str(out)
# ===============================
# Gradio UI
# ===============================
# Layout: heading, a row holding the image input and the prompt box, then the
# run button and the output box.
with gr.Blocks(title="VILA 1.5 3B (HF Space)") as demo:
    gr.Markdown("## 🖼️ VILA-1.5-3B Image Description Demo\nUpload an image and get a description.")
    with gr.Row():
        # type="numpy" matches vila_infer, which feeds the value to
        # Image.fromarray.
        img = gr.Image(type="numpy", label="Image", height=320)
        prompt = gr.Textbox(label="Prompt", value="Please describe the image", lines=2)
    btn = gr.Button("Run")
    out = gr.Textbox(label="Output", lines=8)
    # Clicking the button calls vila_infer(img, prompt) and writes the result
    # to the output box.
    btn.click(vila_infer, [img, prompt], out)

# Start the app once the interface has been built.
demo.launch()
|