DivyanshHF committed on
Commit
c6110b4
·
verified ·
1 Parent(s): 61a367a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -33
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import sys
3
  import types
@@ -6,67 +7,85 @@ from PIL import Image
6
  import gradio as gr
7
 
8
  # ===============================
9
- # Make a PACKAGE-like dummy flash_attn
10
  # ===============================
11
  def _mk_pkg(name: str):
12
  m = types.ModuleType(name)
13
- # Mark as a package: give it a spec with submodule locations and a __path__
14
  spec = importlib.machinery.ModuleSpec(name, loader=None, is_package=True)
15
- spec.submodule_search_locations = [] # important: tells importlib it's a package
16
  m.__spec__ = spec
17
- m.__path__ = [] # also marks as package
18
  return m
19
 
20
- # Root package
21
  flash_attn_pkg = _mk_pkg("flash_attn")
22
 
23
- # Submodule: flash_attn.flash_attn_interface
24
  flash_attn_interface = types.ModuleType("flash_attn.flash_attn_interface")
25
  flash_attn_interface.__spec__ = importlib.machinery.ModuleSpec(
26
  "flash_attn.flash_attn_interface", loader=None
27
  )
28
 
29
- # Submodule: flash_attn.bert_padding
30
  flash_attn_bert_padding = types.ModuleType("flash_attn.bert_padding")
31
  flash_attn_bert_padding.__spec__ = importlib.machinery.ModuleSpec(
32
  "flash_attn.bert_padding", loader=None
33
  )
34
 
35
  def _dummy_func(*args, **kwargs):
36
- # Should never be called on CPU; if it is, let’s fail loudly
37
  raise RuntimeError("flash_attn is not available in this environment.")
38
 
39
- # Functions some imports expect to exist:
40
  flash_attn_interface.flash_attn_unpadded_qkvpacked_func = _dummy_func
41
  flash_attn_interface.flash_attn_varlen_qkvpacked_func = _dummy_func
42
  flash_attn_bert_padding.pad_input = _dummy_func
43
  flash_attn_bert_padding.unpad_input = _dummy_func
44
 
45
- # Register modules
46
  sys.modules["flash_attn"] = flash_attn_pkg
47
  sys.modules["flash_attn.flash_attn_interface"] = flash_attn_interface
48
  sys.modules["flash_attn.bert_padding"] = flash_attn_bert_padding
49
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  # ===============================
51
- # Runtime env (CPU-friendly)
52
  # ===============================
53
  os.environ.setdefault("FLASH_ATTENTION", "0")
54
  os.environ.setdefault("XFORMERS_DISABLED", "1")
55
  os.environ.setdefault("ACCELERATE_USE_DEVICE_MAP", "0")
56
- # Uncomment to force CPU even if a GPU is present:
57
  # os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
58
 
59
  # ===============================
60
- # VILA imports & load
61
  # ===============================
62
  from llava.model.builder import load_pretrained_model
63
  from llava.constants import DEFAULT_IMAGE_TOKEN
64
 
65
  MODEL_PATH = "Efficient-Large-Model/VILA1.5-3b"
66
 
67
- tokenizer, model, image_processor, context_len = load_pretrained_model(
68
- MODEL_PATH, model_name="", model_base=None
69
- )
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  # Fallback chat template if missing
72
  if getattr(tokenizer, "chat_template", None) is None:
@@ -76,38 +95,46 @@ if getattr(tokenizer, "chat_template", None) is None:
76
  )
77
 
78
  # ===============================
79
- # Inference function
80
  # ===============================
81
  def vila_infer(image, prompt):
82
  if image is None:
83
  return "Please upload an image."
84
- if not prompt.strip():
85
  prompt = "Please describe the image."
86
 
87
  pil = Image.fromarray(image).convert("RGB")
88
 
89
- out = model.generate_content(
90
- prompt=[{
91
- "from": "human",
92
- "value": [
93
- {"type": "image", "value": pil},
94
- {"type": "text", "value": prompt}
95
- ]
96
- }],
97
- generation_config=None
98
- )
99
- return str(out)
 
 
 
 
100
 
101
  # ===============================
102
  # Gradio UI
103
  # ===============================
104
  with gr.Blocks(title="VILA 1.5 3B (HF Space)") as demo:
105
- gr.Markdown("## 🖼️ VILA-1.5-3B Image Description Demo\nUpload an image and get a description.")
 
 
106
  with gr.Row():
107
  img = gr.Image(type="numpy", label="Image", height=320)
108
  prompt = gr.Textbox(label="Prompt", value="Please describe the image", lines=2)
109
- btn = gr.Button("Run")
110
- out = gr.Textbox(label="Output", lines=8)
111
- btn.click(vila_infer, [img, prompt], out)
 
 
112
 
113
  demo.launch()
 
1
+ # app.py
2
  import os
3
  import sys
4
  import types
 
7
  import gradio as gr
8
 
9
  # ===============================
10
+ # Make dummy packages for flash_attn and ps3 (CPU-friendly import stubs)
11
  # ===============================
12
# ===============================
# Dummy packages for flash_attn and ps3 (CPU-friendly import stubs).
# These let downstream imports succeed on hosts where the real CUDA
# packages are not installed.
# NOTE(review): this relies on `importlib.machinery` — confirm that
# `import importlib.machinery` (or equivalent) appears in the import block.
# ===============================
def _mk_pkg(name: str):
    """Create an empty module object that importlib treats as a package."""
    module = types.ModuleType(name)
    spec = importlib.machinery.ModuleSpec(name, loader=None, is_package=True)
    # A (possibly empty) submodule_search_locations list is what tells
    # importlib the spec describes a package, not a plain module.
    spec.submodule_search_locations = []
    module.__spec__ = spec
    module.__path__ = []  # marks the module object itself as a package
    return module

# --- flash_attn package + submodules ---
flash_attn_pkg = _mk_pkg("flash_attn")

flash_attn_interface = types.ModuleType("flash_attn.flash_attn_interface")
flash_attn_interface.__spec__ = importlib.machinery.ModuleSpec(
    "flash_attn.flash_attn_interface", loader=None
)

flash_attn_bert_padding = types.ModuleType("flash_attn.bert_padding")
flash_attn_bert_padding.__spec__ = importlib.machinery.ModuleSpec(
    "flash_attn.bert_padding", loader=None
)

def _dummy_func(*args, **kwargs):
    # Should never be called on CPU; if it is, fail loudly so we notice.
    raise RuntimeError("flash_attn is not available in this environment.")

# Attributes some imports expect to find on the stub modules.
flash_attn_interface.flash_attn_unpadded_qkvpacked_func = _dummy_func
flash_attn_interface.flash_attn_varlen_qkvpacked_func = _dummy_func
flash_attn_bert_padding.pad_input = _dummy_func
flash_attn_bert_padding.unpad_input = _dummy_func

# Register the stubs so later `import flash_attn...` statements hit them.
for _mod_name, _mod in (
    ("flash_attn", flash_attn_pkg),
    ("flash_attn.flash_attn_interface", flash_attn_interface),
    ("flash_attn.bert_padding", flash_attn_bert_padding),
):
    sys.modules[_mod_name] = _mod

# --- ps3 package stub ---
ps3_pkg = _mk_pkg("ps3")

class _PS3Config:
    pass

class _PS3VisionConfig:
    pass

class _PS3ImageProcessor:
    pass

class _PS3VisionModel:
    pass

ps3_pkg.PS3Config = _PS3Config
ps3_pkg.PS3VisionConfig = _PS3VisionConfig
ps3_pkg.PS3ImageProcessor = _PS3ImageProcessor
ps3_pkg.PS3VisionModel = _PS3VisionModel
sys.modules["ps3"] = ps3_pkg
57
+
58
# ===============================
# Runtime env (CPU-safe defaults).
# setdefault keeps any value already provided by the Space configuration.
# ===============================
for _env_key, _env_value in (
    ("FLASH_ATTENTION", "0"),
    ("XFORMERS_DISABLED", "1"),
    ("ACCELERATE_USE_DEVICE_MAP", "0"),
):
    os.environ.setdefault(_env_key, _env_value)
# Uncomment to force CPU even if a GPU is present on the Space
# os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
66
 
67
  # ===============================
68
+ # VILA imports & model load
69
  # ===============================
70
  from llava.model.builder import load_pretrained_model
71
  from llava.constants import DEFAULT_IMAGE_TOKEN
72
 
73
  MODEL_PATH = "Efficient-Large-Model/VILA1.5-3b"
74
 
75
try:
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        MODEL_PATH, model_name="", model_base=None
    )
except Exception as e:
    # Surface a friendly error on the UI instead of crashing the Space.
    # NOTE(review): demo.launch() blocks, so the trailing `raise` only fires
    # after the error UI is shut down — confirm this is the intended lifecycle.
    ERR = f"Failed to load model '{MODEL_PATH}': {e}"
    with gr.Blocks(title="VILA 1.5 3B – Error") as _err_demo:
        gr.Markdown("### ❌ Model failed to load")
        gr.Markdown(ERR)
    _err_demo.launch()
    raise
89
 
90
  # Fallback chat template if missing
91
  if getattr(tokenizer, "chat_template", None) is None:
 
95
  )
96
 
97
  # ===============================
98
+ # Inference
99
  # ===============================
100
  def vila_infer(image, prompt):
101
  if image is None:
102
  return "Please upload an image."
103
+ if not prompt or not str(prompt).strip():
104
  prompt = "Please describe the image."
105
 
106
  pil = Image.fromarray(image).convert("RGB")
107
 
108
+ # Minimal multimodal conversation: image + text
109
+ try:
110
+ out = model.generate_content(
111
+ prompt=[{
112
+ "from": "human",
113
+ "value": [
114
+ {"type": "image", "value": pil},
115
+ {"type": "text", "value": prompt}
116
+ ]
117
+ }],
118
+ generation_config=None # use model defaults
119
+ )
120
+ return str(out).strip()
121
+ except Exception as e:
122
+ return f"❌ Inference error: {e}"
123
 
124
# ===============================
# Gradio UI
# ===============================
with gr.Blocks(title="VILA 1.5 3B (HF Space)") as demo:
    gr.Markdown("## 🖼️ VILA-1.5-3B Image Description Demo")
    gr.Markdown("Upload an image and press **Run**. Leave the prompt as default for simple captioning.")

    with gr.Row():
        image_input = gr.Image(type="numpy", label="Image", height=320)
        prompt_box = gr.Textbox(label="Prompt", value="Please describe the image", lines=2)

    submit_btn = gr.Button("Run")
    output_box = gr.Textbox(label="Output", lines=10)

    # Wire the button to inference: (image, prompt) -> output text.
    submit_btn.click(vila_infer, [image_input, prompt_box], output_box)

demo.launch()