DivyanshHF commited on
Commit
346da8b
·
verified ·
1 Parent(s): c6110b4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -17
app.py CHANGED
@@ -7,7 +7,7 @@ from PIL import Image
7
  import gradio as gr
8
 
9
  # ===============================
10
- # Make dummy packages for flash_attn and ps3 (CPU-friendly import stubs)
11
  # ===============================
12
  def _mk_pkg(name: str):
13
  m = types.ModuleType(name)
@@ -17,7 +17,18 @@ def _mk_pkg(name: str):
17
  m.__path__ = []
18
  return m
19
 
20
- # --- flash_attn package + submodules ---
 
 
 
 
 
 
 
 
 
 
 
21
  flash_attn_pkg = _mk_pkg("flash_attn")
22
 
23
  flash_attn_interface = types.ModuleType("flash_attn.flash_attn_interface")
@@ -31,7 +42,6 @@ flash_attn_bert_padding.__spec__ = importlib.machinery.ModuleSpec(
31
  )
32
 
33
  def _dummy_func(*args, **kwargs):
34
- # Should never be called on CPU; if it is, fail loudly so we notice.
35
  raise RuntimeError("flash_attn is not available in this environment.")
36
 
37
  flash_attn_interface.flash_attn_unpadded_qkvpacked_func = _dummy_func
@@ -43,7 +53,9 @@ sys.modules["flash_attn"] = flash_attn_pkg
43
  sys.modules["flash_attn.flash_attn_interface"] = flash_attn_interface
44
  sys.modules["flash_attn.bert_padding"] = flash_attn_bert_padding
45
 
46
- # --- ps3 package stub ---
 
 
47
  ps3_pkg = _mk_pkg("ps3")
48
  class _PS3Config: pass
49
  class _PS3VisionConfig: pass
@@ -56,16 +68,20 @@ ps3_pkg.PS3VisionModel = _PS3VisionModel
56
  sys.modules["ps3"] = ps3_pkg
57
 
58
  # ===============================
59
- # Runtime env (CPU-safe defaults)
 
 
60
  # ===============================
61
- os.environ.setdefault("FLASH_ATTENTION", "0")
62
- os.environ.setdefault("XFORMERS_DISABLED", "1")
63
- os.environ.setdefault("ACCELERATE_USE_DEVICE_MAP", "0")
64
- # Uncomment to force CPU even if a GPU is present on the Space
65
- # os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
 
 
66
 
67
  # ===============================
68
- # VILA imports & model load
69
  # ===============================
70
  from llava.model.builder import load_pretrained_model
71
  from llava.constants import DEFAULT_IMAGE_TOKEN
@@ -77,7 +93,6 @@ try:
77
  MODEL_PATH, model_name="", model_base=None
78
  )
79
  except Exception as e:
80
- # Surface a friendly error on the UI instead of crashing the Space
81
  ERR = f"Failed to load model '{MODEL_PATH}': {e}"
82
  def _boot_error_ui():
83
  with gr.Blocks(title="VILA 1.5 3B – Error") as demo:
@@ -105,7 +120,6 @@ def vila_infer(image, prompt):
105
 
106
  pil = Image.fromarray(image).convert("RGB")
107
 
108
- # Minimal multimodal conversation: image + text
109
  try:
110
  out = model.generate_content(
111
  prompt=[{
@@ -115,18 +129,18 @@ def vila_infer(image, prompt):
115
  {"type": "text", "value": prompt}
116
  ]
117
  }],
118
- generation_config=None # use model defaults
119
  )
120
  return str(out).strip()
121
  except Exception as e:
122
  return f"❌ Inference error: {e}"
123
 
124
  # ===============================
125
- # Gradio UI
126
  # ===============================
127
  with gr.Blocks(title="VILA 1.5 3B (HF Space)") as demo:
128
  gr.Markdown("## πŸ–ΌοΈ VILA-1.5-3B β€” Image Description Demo")
129
- gr.Markdown("Upload an image and press **Run**. Leave the prompt as default for simple captioning.")
130
 
131
  with gr.Row():
132
  img = gr.Image(type="numpy", label="Image", height=320)
@@ -137,4 +151,4 @@ with gr.Blocks(title="VILA 1.5 3B (HF Space)") as demo:
137
 
138
  run_btn.click(vila_infer, [img, prompt], out)
139
 
140
- demo.launch()
 
7
  import gradio as gr
8
 
9
  # ===============================
10
+ # Helper to create package-like dummy modules
11
  # ===============================
12
  def _mk_pkg(name: str):
13
  m = types.ModuleType(name)
 
17
  m.__path__ = []
18
  return m
19
 
20
+ # ===============================
21
+ # Disable GPU-only/optional paths
22
+ # ===============================
23
+ os.environ.setdefault("FLASH_ATTENTION", "0")
24
+ os.environ.setdefault("XFORMERS_DISABLED", "1")
25
+ os.environ.setdefault("ACCELERATE_USE_DEVICE_MAP", "0")
26
+ os.environ.setdefault("DISABLE_TRITON", "1") # avoid triton kernels
27
+ # os.environ.setdefault("CUDA_VISIBLE_DEVICES", "") # uncomment to force CPU
28
+
29
+ # ===============================
30
+ # flash_attn stubs (package + submodules)
31
+ # ===============================
32
  flash_attn_pkg = _mk_pkg("flash_attn")
33
 
34
  flash_attn_interface = types.ModuleType("flash_attn.flash_attn_interface")
 
42
  )
43
 
44
  def _dummy_func(*args, **kwargs):
 
45
  raise RuntimeError("flash_attn is not available in this environment.")
46
 
47
  flash_attn_interface.flash_attn_unpadded_qkvpacked_func = _dummy_func
 
53
  sys.modules["flash_attn.flash_attn_interface"] = flash_attn_interface
54
  sys.modules["flash_attn.bert_padding"] = flash_attn_bert_padding
55
 
56
+ # ===============================
57
+ # ps3 stub (optional vision tower)
58
+ # ===============================
59
  ps3_pkg = _mk_pkg("ps3")
60
  class _PS3Config: pass
61
  class _PS3VisionConfig: pass
 
68
  sys.modules["ps3"] = ps3_pkg
69
 
70
  # ===============================
71
+ # Quantization stub to avoid Triton path
72
+ # VILA falls back to "from FloatPointQuantizeTorch import *" if Triton import fails.
73
+ # Provide a tiny no-op module so imports succeed.
74
  # ===============================
75
+ fpqt = types.ModuleType("FloatPointQuantizeTorch")
76
+ def _id(x, *a, **k): return x # identity
77
+ # names used by llava.model.qfunction
78
+ fpqt.block_cut = _id
79
+ fpqt.block_quant = _id
80
+ fpqt.block_reshape = _id
81
+ sys.modules["FloatPointQuantizeTorch"] = fpqt
82
 
83
  # ===============================
84
+ # Load VILA
85
  # ===============================
86
  from llava.model.builder import load_pretrained_model
87
  from llava.constants import DEFAULT_IMAGE_TOKEN
 
93
  MODEL_PATH, model_name="", model_base=None
94
  )
95
  except Exception as e:
 
96
  ERR = f"Failed to load model '{MODEL_PATH}': {e}"
97
  def _boot_error_ui():
98
  with gr.Blocks(title="VILA 1.5 3B – Error") as demo:
 
120
 
121
  pil = Image.fromarray(image).convert("RGB")
122
 
 
123
  try:
124
  out = model.generate_content(
125
  prompt=[{
 
129
  {"type": "text", "value": prompt}
130
  ]
131
  }],
132
+ generation_config=None
133
  )
134
  return str(out).strip()
135
  except Exception as e:
136
  return f"❌ Inference error: {e}"
137
 
138
  # ===============================
139
+ # UI
140
  # ===============================
141
  with gr.Blocks(title="VILA 1.5 3B (HF Space)") as demo:
142
  gr.Markdown("## πŸ–ΌοΈ VILA-1.5-3B β€” Image Description Demo")
143
+ gr.Markdown("Upload an image and press **Run**.")
144
 
145
  with gr.Row():
146
  img = gr.Image(type="numpy", label="Image", height=320)
 
151
 
152
  run_btn.click(vila_infer, [img, prompt], out)
153
 
154
+ demo.launch()