DivyanshHF committed on
Commit
f3b369a
·
verified ·
1 Parent(s): a46c5a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -13
app.py CHANGED
@@ -68,17 +68,28 @@ ps3_pkg.PS3VisionModel = _PS3VisionModel
68
  sys.modules["ps3"] = ps3_pkg
69
 
70
  # ===============================
71
- # Quantization stub to avoid Triton path
72
- # VILA falls back to "from FloatPointQuantizeTorch import *" if Triton import fails.
73
- # Provide a tiny no-op module so imports succeed.
74
- # ===============================
75
- fpqt = types.ModuleType("FloatPointQuantizeTorch")
76
- def _id(x, *a, **k): return x # identity
77
- # names used by llava.model.qfunction
78
- fpqt.block_cut = _id
79
- fpqt.block_quant = _id
80
- fpqt.block_reshape = _id
81
- sys.modules["FloatPointQuantizeTorch"] = fpqt
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  # ===============================
84
  # Load VILA
@@ -112,13 +123,15 @@ if getattr(tokenizer, "chat_template", None) is None:
112
  # ===============================
113
  # Inference
114
  # ===============================
 
 
115
  def vila_infer(image, prompt):
116
  if image is None:
117
  return "Please upload an image."
118
  if not prompt or not str(prompt).strip():
119
  prompt = "Please describe the image."
120
 
121
- pil = Image.fromarray(image).convert("RGB")
122
 
123
  try:
124
  out = model.generate_content(
@@ -129,7 +142,7 @@ def vila_infer(image, prompt):
129
  {"type": "text", "value": prompt}
130
  ]
131
  }],
132
- generation_config=None
133
  )
134
  return str(out).strip()
135
  except Exception as e:
 
68
  sys.modules["ps3"] = ps3_pkg
69
 
70
  # ===============================
71
+ # Quantization stubs to avoid Triton/Torch custom kernels
72
+ # VILA sometimes imports:
73
+ # - from .FloatPointQuantizeTriton import *
74
+ # - from FloatPointQuantizeTriton import *
75
+ # - from FloatPointQuantizeTorch import *
76
+ # Provide both names (absolute and package-qualified) with no-op funcs.
77
+ # ===============================
78
+ def _mk_fpq_module(mod_name: str):
79
+ mod = types.ModuleType(mod_name)
80
+ # Provide the APIs qfunction expects
81
+ def _id(x, *a, **k): return x
82
+ mod.block_cut = _id
83
+ mod.block_quant = _id
84
+ mod.block_reshape = _id
85
+ return mod
86
+
87
+ # Absolute names
88
+ sys.modules["FloatPointQuantizeTorch"] = _mk_fpq_module("FloatPointQuantizeTorch")
89
+ sys.modules["FloatPointQuantizeTriton"] = _mk_fpq_module("FloatPointQuantizeTriton")
90
+ # Package-qualified under llava.model
91
+ sys.modules["llava.model.FloatPointQuantizeTorch"] = sys.modules["FloatPointQuantizeTorch"]
92
+ sys.modules["llava.model.FloatPointQuantizeTriton"] = sys.modules["FloatPointQuantizeTriton"]
93
 
94
  # ===============================
95
  # Load VILA
 
123
  # ===============================
124
  # Inference
125
  # ===============================
126
+ from PIL import Image as _PILImage
127
+
128
  def vila_infer(image, prompt):
129
  if image is None:
130
  return "Please upload an image."
131
  if not prompt or not str(prompt).strip():
132
  prompt = "Please describe the image."
133
 
134
+ pil = _PILImage.fromarray(image).convert("RGB")
135
 
136
  try:
137
  out = model.generate_content(
 
142
  {"type": "text", "value": prompt}
143
  ]
144
  }],
145
+ generation_config=None # default decoding
146
  )
147
  return str(out).strip()
148
  except Exception as e: