DivyanshHF committed on
Commit 89f55ad · verified · 1 Parent(s): a709033

Update app.py

Files changed (1)
  1. app.py +34 -31
app.py CHANGED
@@ -1,49 +1,58 @@
 import os
 import sys
 import types
-import gradio as gr
+import importlib.machinery
 from PIL import Image
+import gradio as gr
 
-# ======================
-# Disable FlashAttention
-# ======================
-sys.modules["flash_attn"] = types.ModuleType("flash_attn")
-sys.modules["flash_attn.flash_attn_interface"] = types.ModuleType("flash_attn.flash_attn_interface")
+# ===============================
+# Patch flash_attn for CPU runtime
+# ===============================
+dummy_flash_attn = types.ModuleType("flash_attn")
+dummy_flash_attn.__spec__ = importlib.machinery.ModuleSpec("flash_attn", loader=None)
+
+dummy_interface = types.ModuleType("flash_attn.flash_attn_interface")
+dummy_interface.__spec__ = importlib.machinery.ModuleSpec(
+    "flash_attn.flash_attn_interface", loader=None
+)
 
 def _dummy_func(*args, **kwargs):
-    raise RuntimeError("FlashAttention is not available in this environment.")
+    raise RuntimeError("flash_attn is not available in this environment.")
+
+dummy_interface.flash_attn_unpadded_qkvpacked_func = _dummy_func
+dummy_interface.flash_attn_varlen_qkvpacked_func = _dummy_func
 
-sys.modules["flash_attn.flash_attn_interface"].flash_attn_unpadded_qkvpacked_func = _dummy_func
-sys.modules["flash_attn.flash_attn_interface"].flash_attn_varlen_qkvpacked_func = _dummy_func
+sys.modules["flash_attn"] = dummy_flash_attn
+sys.modules["flash_attn.flash_attn_interface"] = dummy_interface
 
-# ======================
-# CPU-only settings
-# ======================
-os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
+# ===============================
+# Hugging Face model setup
+# ===============================
 os.environ.setdefault("FLASH_ATTENTION", "0")
 os.environ.setdefault("XFORMERS_DISABLED", "1")
 os.environ.setdefault("ACCELERATE_USE_DEVICE_MAP", "0")
 
-# ======================
-# VILA imports
-# ======================
 from llava.model.builder import load_pretrained_model
 from llava.constants import DEFAULT_IMAGE_TOKEN
 
 MODEL_PATH = "Efficient-Large-Model/VILA1.5-3b"
 
+# Load model + tokenizer + image processor
 tokenizer, model, image_processor, context_len = load_pretrained_model(
     MODEL_PATH, model_name="", model_base=None
 )
 
-# Add fallback chat template if missing
+# Add a fallback chat template
 if getattr(tokenizer, "chat_template", None) is None:
     tokenizer.chat_template = (
         "{% for message in messages %}{{ message['role'] | upper }}: "
         "{{ message['content'] }}\n{% endfor %}ASSISTANT:"
     )
 
-def vila_infer(image, prompt, max_new_tokens, temperature):
+# ===============================
+# Inference function
+# ===============================
+def vila_infer(image, prompt):
     if image is None:
         return "Please upload an image."
     if not prompt.strip():
@@ -51,7 +60,6 @@ def vila_infer(image, prompt, max_new_tokens, temperature):
 
     pil = Image.fromarray(image).convert("RGB")
 
-    # Minimal conversation: image + prompt
     out = model.generate_content(
         prompt=[{
             "from": "human",
@@ -60,25 +68,20 @@
                 {"type": "text", "value": prompt}
             ]
         }],
-        generation_config={"max_new_tokens": max_new_tokens, "temperature": temperature}
+        generation_config=None
     )
-
     return str(out)
 
-with gr.Blocks(title="VILA 1.5 3B (CPU, HF Space)") as demo:
-    gr.Markdown("## 🖼️ VILA-1.5-3B — Image Captioning\nUpload an image and get a description.")
-
+# ===============================
+# Gradio UI
+# ===============================
+with gr.Blocks(title="VILA 1.5 3B (HF Space)") as demo:
+    gr.Markdown("## 🖼️ VILA-1.5-3B Image Description Demo\nUpload an image and get a description.")
     with gr.Row():
         img = gr.Image(type="numpy", label="Image", height=320)
         prompt = gr.Textbox(label="Prompt", value="Please describe the image", lines=2)
-
-    with gr.Row():
-        max_new = gr.Slider(16, 256, value=96, step=1, label="Max new tokens")
-        temp = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Temperature")
-
     btn = gr.Button("Run")
     out = gr.Textbox(label="Output", lines=8)
-
-    btn.click(vila_infer, [img, prompt, max_new, temp], out)
+    btn.click(vila_infer, [img, prompt], out)
 
 demo.launch()
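
A note on the stub pattern in this diff: the old code inserted bare types.ModuleType objects into sys.modules, and a module created that way has __spec__ set to None. importlib.util.find_spec() raises ValueError for a module cached in sys.modules whose __spec__ is None, which is how Transformers-style optional-dependency probes can crash even though the import itself succeeds; attaching a real ModuleSpec, as the new code does, keeps those probes working. Below is a minimal, self-contained sketch of the same pattern; the module name fake_dep and the attribute heavy_kernel are placeholder names, not part of this app.

import sys
import types
import importlib.machinery
import importlib.util

# Stand-in module; "fake_dep" is a placeholder for any optional dependency.
stub = types.ModuleType("fake_dep")
# A real ModuleSpec keeps importlib.util.find_spec("fake_dep") from raising
# ValueError on the cached module (its __spec__ would otherwise be None).
stub.__spec__ = importlib.machinery.ModuleSpec("fake_dep", loader=None)

def _unavailable(*args, **kwargs):
    # Fail loudly only if the stubbed function is actually called.
    raise RuntimeError("fake_dep is not available in this environment.")

stub.heavy_kernel = _unavailable  # placeholder attribute a caller might look up
sys.modules["fake_dep"] = stub

import fake_dep  # resolves to the stub, no ImportError
print(importlib.util.find_spec("fake_dep"))  # ModuleSpec(name='fake_dep', loader=None)

Raising only at call time, rather than at import time, lets the rest of the app load normally on a CPU-only Space while still surfacing a clear error if any code path actually reaches the stubbed kernels.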