DivyanshHF committed (verified)
Commit a709033 · 1 Parent(s): d3ff5e8

Update app.py

Files changed (1): app.py (+40 -30)
app.py CHANGED
@@ -1,62 +1,72 @@
 import os
-
-# ===== Disable GPU-specific optional deps for Hugging Face Spaces =====
-os.environ["FLASH_ATTENTION"] = "0"
-os.environ["DISABLE_FLASH_ATTN"] = "1"
-os.environ["XFORMERS_DISABLED"] = "1"
-os.environ["ACCELERATE_USE_DEVICE_MAP"] = "0"
-
-# Optional: force CPU if GPU not available
-# os.environ["CUDA_VISIBLE_DEVICES"] = ""
-
+import sys
+import types
 import gradio as gr
 from PIL import Image
 
-# ---- VILA imports ----
+# ======================
+# Disable FlashAttention
+# ======================
+sys.modules["flash_attn"] = types.ModuleType("flash_attn")
+sys.modules["flash_attn.flash_attn_interface"] = types.ModuleType("flash_attn.flash_attn_interface")
+
+def _dummy_func(*args, **kwargs):
+    raise RuntimeError("FlashAttention is not available in this environment.")
+
+sys.modules["flash_attn.flash_attn_interface"].flash_attn_unpadded_qkvpacked_func = _dummy_func
+sys.modules["flash_attn.flash_attn_interface"].flash_attn_varlen_qkvpacked_func = _dummy_func
+
+# ======================
+# CPU-only settings
+# ======================
+os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
+os.environ.setdefault("FLASH_ATTENTION", "0")
+os.environ.setdefault("XFORMERS_DISABLED", "1")
+os.environ.setdefault("ACCELERATE_USE_DEVICE_MAP", "0")
+
+# ======================
+# VILA imports
+# ======================
 from llava.model.builder import load_pretrained_model
 from llava.constants import DEFAULT_IMAGE_TOKEN
 
-# === Load VILA 1.5-3B ===
 MODEL_PATH = "Efficient-Large-Model/VILA1.5-3b"
+
 tokenizer, model, image_processor, context_len = load_pretrained_model(
     MODEL_PATH, model_name="", model_base=None
 )
 
-# === Fallback chat template (in case checkpoint doesn't have one) ===
+# Add fallback chat template if missing
 if getattr(tokenizer, "chat_template", None) is None:
     tokenizer.chat_template = (
         "{% for message in messages %}{{ message['role'] | upper }}: "
         "{{ message['content'] }}\n{% endfor %}ASSISTANT:"
     )
 
-# === Inference function ===
 def vila_infer(image, prompt, max_new_tokens, temperature):
     if image is None:
-        return "Please upload an image."
+        return "Please upload an image."
     if not prompt.strip():
         prompt = "Please describe the image."
 
     pil = Image.fromarray(image).convert("RGB")
 
-    # Prepare multimodal input for VILA
-    conversation = [{
-        "from": "human",
-        "value": [
-            {"type": "image", "value": pil},
-            {"type": "text", "value": prompt}
-        ]
-    }]
-
-    # Generate output
+    # Minimal conversation: image + prompt
     out = model.generate_content(
-        prompt=conversation,
-        generation_config=None
+        prompt=[{
+            "from": "human",
+            "value": [
+                {"type": "image", "value": pil},
+                {"type": "text", "value": prompt}
+            ]
+        }],
+        generation_config={"max_new_tokens": max_new_tokens, "temperature": temperature}
     )
+
     return str(out)
 
-# === Gradio UI ===
-with gr.Blocks(title="VILA 1.5 3B Demo") as demo:
-    gr.Markdown("## 🖼️ VILA-1.5-3B — Image Understanding Demo\nUpload an image and ask a question.")
+with gr.Blocks(title="VILA 1.5 3B (CPU, HF Space)") as demo:
+    gr.Markdown("## 🖼️ VILA-1.5-3B — Image Captioning\nUpload an image and get a description.")
 
     with gr.Row():
         img = gr.Image(type="numpy", label="Image", height=320)
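The `sys.modules` stubbing introduced in this commit works because Python consults the `sys.modules` cache before searching for a real package, so a later `import flash_attn` inside llava resolves to the stub instead of raising ImportError; only code paths that actually call the stubbed kernels fail, with a clear message. A minimal self-contained sketch of the same trick (the package name `gpu_only_lib` is made up for illustration):

import sys
import types

# Register a stand-in module before anything tries to import the absent package.
stub = types.ModuleType("gpu_only_lib")

def _unavailable(*args, **kwargs):
    raise RuntimeError("gpu_only_lib is not installed in this environment.")

stub.fast_kernel = _unavailable
sys.modules["gpu_only_lib"] = stub

import gpu_only_lib  # resolves to the stub; no ImportError

try:
    gpu_only_lib.fast_kernel()
except RuntimeError as err:
    print(err)  # gpu_only_lib is not installed in this environment.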
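The fallback chat template is plain Jinja, so its output can be previewed with any `transformers` tokenizer; in the sketch below the `gpt2` checkpoint is only a convenient stand-in for rendering, not the model this Space uses:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer, for rendering only
tok.chat_template = (
    "{% for message in messages %}{{ message['role'] | upper }}: "
    "{{ message['content'] }}\n{% endfor %}ASSISTANT:"
)

messages = [{"role": "user", "content": "Describe this image."}]
print(tok.apply_chat_template(messages, tokenize=False))
# USER: Describe this image.
# ASSISTANT: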
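The diff is truncated inside the UI block. For orientation only, a typical way the remaining widgets and event wiring of such a Blocks app look in Gradio is sketched below; the widget names, slider ranges, and defaults are assumptions, not the commit's actual code:

import gradio as gr

def vila_infer(image, prompt, max_new_tokens, temperature):
    return "stub"  # stands in for the real inference function above

with gr.Blocks(title="VILA 1.5 3B (CPU, HF Space)") as demo:
    with gr.Row():
        img = gr.Image(type="numpy", label="Image", height=320)
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="Please describe the image.")
            max_tokens = gr.Slider(16, 512, value=128, step=8, label="Max new tokens")
            temperature = gr.Slider(0.0, 1.5, value=0.2, step=0.05, label="Temperature")
            run = gr.Button("Run")
    output = gr.Textbox(label="Output")
    run.click(fn=vila_infer, inputs=[img, prompt, max_tokens, temperature], outputs=output)

demo.launch()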