DivyanshHF committed (verified)
Commit b2016f5 · 1 Parent(s): 89f55ad

Update app.py

Files changed (1): app.py (+38, -12)
app.py CHANGED
@@ -6,43 +6,69 @@ from PIL import Image
 import gradio as gr
 
 # ===============================
-# Patch flash_attn for CPU runtime
+# Make a PACKAGE-like dummy flash_attn
 # ===============================
-dummy_flash_attn = types.ModuleType("flash_attn")
-dummy_flash_attn.__spec__ = importlib.machinery.ModuleSpec("flash_attn", loader=None)
+def _mk_pkg(name: str):
+    m = types.ModuleType(name)
+    # Mark as a package: give it a spec with submodule locations and a __path__
+    spec = importlib.machinery.ModuleSpec(name, loader=None, is_package=True)
+    spec.submodule_search_locations = []  # important: tells importlib it's a package
+    m.__spec__ = spec
+    m.__path__ = []  # also marks as package
+    return m
 
-dummy_interface = types.ModuleType("flash_attn.flash_attn_interface")
-dummy_interface.__spec__ = importlib.machinery.ModuleSpec(
+# Root package
+flash_attn_pkg = _mk_pkg("flash_attn")
+
+# Submodule: flash_attn.flash_attn_interface
+flash_attn_interface = types.ModuleType("flash_attn.flash_attn_interface")
+flash_attn_interface.__spec__ = importlib.machinery.ModuleSpec(
     "flash_attn.flash_attn_interface", loader=None
 )
 
+# Submodule: flash_attn.bert_padding
+flash_attn_bert_padding = types.ModuleType("flash_attn.bert_padding")
+flash_attn_bert_padding.__spec__ = importlib.machinery.ModuleSpec(
+    "flash_attn.bert_padding", loader=None
+)
+
 def _dummy_func(*args, **kwargs):
+    # Should never be called on CPU; if it is, let’s fail loudly
     raise RuntimeError("flash_attn is not available in this environment.")
 
-dummy_interface.flash_attn_unpadded_qkvpacked_func = _dummy_func
-dummy_interface.flash_attn_varlen_qkvpacked_func = _dummy_func
+# Functions some imports expect to exist:
+flash_attn_interface.flash_attn_unpadded_qkvpacked_func = _dummy_func
+flash_attn_interface.flash_attn_varlen_qkvpacked_func = _dummy_func
+flash_attn_bert_padding.pad_input = _dummy_func
+flash_attn_bert_padding.unpad_input = _dummy_func
 
-sys.modules["flash_attn"] = dummy_flash_attn
-sys.modules["flash_attn.flash_attn_interface"] = dummy_interface
+# Register modules
+sys.modules["flash_attn"] = flash_attn_pkg
+sys.modules["flash_attn.flash_attn_interface"] = flash_attn_interface
+sys.modules["flash_attn.bert_padding"] = flash_attn_bert_padding
 
 # ===============================
-# Hugging Face model setup
+# Runtime env (CPU-friendly)
 # ===============================
 os.environ.setdefault("FLASH_ATTENTION", "0")
 os.environ.setdefault("XFORMERS_DISABLED", "1")
 os.environ.setdefault("ACCELERATE_USE_DEVICE_MAP", "0")
+# Uncomment to force CPU even if a GPU is present:
+# os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
 
+# ===============================
+# VILA imports & load
+# ===============================
 from llava.model.builder import load_pretrained_model
 from llava.constants import DEFAULT_IMAGE_TOKEN
 
 MODEL_PATH = "Efficient-Large-Model/VILA1.5-3b"
 
-# Load model + tokenizer + image processor
 tokenizer, model, image_processor, context_len = load_pretrained_model(
     MODEL_PATH, model_name="", model_base=None
 )
 
-# Add a fallback chat template
+# Fallback chat template if missing
 if getattr(tokenizer, "chat_template", None) is None:
     tokenizer.chat_template = (
         "{% for message in messages %}{{ message['role'] | upper }}: "