ColdSlim committed on
Commit
639fd4b
·
verified ·
1 Parent(s): 4fcbf22

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -27
app.py CHANGED
@@ -130,15 +130,11 @@ def format_derm_disclaimer(ans: str) -> str:
130
  )
131
 
132
  def _load_base_plus_lora(dtype: torch.dtype = torch.float16):
133
- """
134
- Load base model on GPU and attach LoRA adapters from ADAPTER_ID.
135
- Returns the Peft-wrapped model (eval mode).
136
- """
137
  logger.info(f"Loading BASE on GPU: {BASE_MODEL_ID}")
138
  base = VisionTextModelClass.from_pretrained(
139
  BASE_MODEL_ID,
140
  torch_dtype=dtype,
141
- device_map="cuda", # ZeroGPU worker device
142
  trust_remote_code=True,
143
  low_cpu_mem_usage=True,
144
  )
@@ -148,6 +144,7 @@ def _load_base_plus_lora(dtype: torch.dtype = torch.float16):
148
  model.eval()
149
  return model
150
 
 
151
  # ---------------------------
152
  # Inference (ZeroGPU-safe: only here we touch CUDA)
153
  # ---------------------------
@@ -205,34 +202,62 @@ def create_interface() -> gr.Blocks:
205
 
206
  def _get_sanitized_adapter_dir(adapter_id: str) -> str:
207
  """
208
- Download the adapter repo and remove unsupported keys from adapter_config.json
209
- so PEFT can load it.
210
  """
211
  repo_dir = snapshot_download(adapter_id)
212
- tmp_dir = tempfile.mkdtemp(prefix="peft_adapter_")
213
- shutil.copytree(repo_dir, os.path.join(tmp_dir, "adapter"), dirs_exist_ok=True)
214
- cfg_path = os.path.join(tmp_dir, "adapter", "adapter_config.json")
215
-
216
- try:
217
- with open(cfg_path, "r") as f:
218
- cfg = json.load(f)
219
- except Exception as e:
220
- raise RuntimeError(f"Failed to read adapter_config.json: {e}")
221
-
222
- # Remove keys PEFT LoraConfig doesn't recognize
223
- for k in ["corda_config", "CoRDA_config"]:
224
- if k in cfg:
225
- cfg.pop(k)
226
-
227
- # If DoRA wasn’t used, drop its block too
228
- if str(cfg.get("use_dora", "false")).lower() in ["false", "0", "no"]:
229
- if "dora_config" in cfg:
230
- cfg.pop("dora_config")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
  with open(cfg_path, "w") as f:
233
  json.dump(cfg, f, indent=2)
234
 
235
- return os.path.join(tmp_dir, "adapter")
236
 
237
  def main():
238
  demo = create_interface()
 
130
  )
131
 
132
  def _load_base_plus_lora(dtype: torch.dtype = torch.float16):
 
 
 
 
133
  logger.info(f"Loading BASE on GPU: {BASE_MODEL_ID}")
134
  base = VisionTextModelClass.from_pretrained(
135
  BASE_MODEL_ID,
136
  torch_dtype=dtype,
137
+ device_map="cuda",
138
  trust_remote_code=True,
139
  low_cpu_mem_usage=True,
140
  )
 
144
  model.eval()
145
  return model
146
 
147
+
148
  # ---------------------------
149
  # Inference (ZeroGPU-safe: only here we touch CUDA)
150
  # ---------------------------
 
202
 
203
def _get_sanitized_adapter_dir(adapter_id: str) -> str:
    """
    Download the adapter repo and sanitize its adapter_config.json so PEFT's
    LoraConfig can parse it.

    Unknown / experimental top-level blocks (e.g. 'corda_config', 'eva_config')
    are stripped, a few required fields are defaulted if absent, and
    stringly-typed booleans are normalized to real bools.

    Returns:
        Path to a temporary local directory containing the sanitized adapter.

    Raises:
        RuntimeError: if the downloaded repo has no adapter_config.json.
    """
    repo_dir = snapshot_download(adapter_id)
    # Work on a copy so the hub cache itself is never mutated.
    tmp_root = tempfile.mkdtemp(prefix="peft_adapter_")
    adapter_dir = os.path.join(tmp_root, "adapter")
    shutil.copytree(repo_dir, adapter_dir, dirs_exist_ok=True)

    cfg_path = os.path.join(adapter_dir, "adapter_config.json")
    if not os.path.isfile(cfg_path):
        raise RuntimeError(f"adapter_config.json not found in adapter repo: {adapter_id}")

    with open(cfg_path, "r") as f:
        cfg = json.load(f)

    # Minimal, widely-supported PEFT LoRA keys:
    allowed = {
        "peft_type", "task_type",
        "r", "lora_alpha", "lora_dropout",
        "target_modules", "bias",
        "inference_mode",
        "base_model_name_or_path",
        "fan_in_fan_out",
        "modules_to_save",
        "layers_to_transform",
        "layers_pattern",
        "use_rslora",
        "rank_dropout", "module_dropout",
        "init_lora_weights",
        "use_dora",  # keep if your PEFT version supports DoRA; harmless otherwise if False
    }

    # If DoRA isn't actually used, drop its block to avoid parser issues.
    if str(cfg.get("use_dora", "false")).lower() in ("false", "0", "no"):
        cfg.pop("dora_config", None)

    # Drop any unknown top-level configs (e.g. 'corda_config', 'CoRDA_config',
    # 'eva_config') that older LoraConfig versions reject.
    for k in [k for k in cfg if k not in allowed]:
        cfg.pop(k, None)

    # Ensure required fields exist.
    cfg.setdefault("peft_type", "LORA")
    cfg.setdefault("task_type", "CAUSAL_LM")
    cfg.setdefault("bias", "none")
    cfg.setdefault("inference_mode", True)

    # Normalize booleans in case they were serialized as strings.
    for k in ("inference_mode", "use_rslora", "use_dora", "fan_in_fan_out"):
        if k in cfg and isinstance(cfg[k], str):
            cfg[k] = cfg[k].lower() in ("true", "1", "yes")

    with open(cfg_path, "w") as f:
        json.dump(cfg, f, indent=2)

    return adapter_dir
261
 
262
  def main():
263
  demo = create_interface()