ColdSlim commited on
Commit
4fcbf22
·
verified ·
1 Parent(s): 84f8470

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -1
app.py CHANGED
@@ -17,6 +17,8 @@ from typing import Optional, Tuple
17
  import gradio as gr
18
  import spaces
19
  import torch
 
 
20
  from PIL import Image
21
  from peft import PeftModel
22
  from transformers import AutoProcessor
@@ -141,7 +143,8 @@ def _load_base_plus_lora(dtype: torch.dtype = torch.float16):
141
  low_cpu_mem_usage=True,
142
  )
143
  logger.info(f"Attaching LoRA adapters: {ADAPTER_ID}")
144
- model = PeftModel.from_pretrained(base, ADAPTER_ID, is_trainable=False)
 
145
  model.eval()
146
  return model
147
 
@@ -200,6 +203,37 @@ def create_interface() -> gr.Blocks:
200
  )
201
  return demo
202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  def main():
204
  demo = create_interface()
205
  demo.launch(
 
17
  import gradio as gr
18
  import spaces
19
  import torch
20
+ import json, tempfile, shutil
21
+ from huggingface_hub import snapshot_download
22
  from PIL import Image
23
  from peft import PeftModel
24
  from transformers import AutoProcessor
 
143
  low_cpu_mem_usage=True,
144
  )
145
  logger.info(f"Attaching LoRA adapters: {ADAPTER_ID}")
146
+ adapter_path = _get_sanitized_adapter_dir(ADAPTER_ID)
147
+ model = PeftModel.from_pretrained(base, adapter_path, is_trainable=False)
148
  model.eval()
149
  return model
150
 
 
203
  )
204
  return demo
205
 
206
+ def _get_sanitized_adapter_dir(adapter_id: str) -> str:
207
+ """
208
+ Download the adapter repo and remove unsupported keys from adapter_config.json
209
+ so PEFT can load it.
210
+ """
211
+ repo_dir = snapshot_download(adapter_id)
212
+ tmp_dir = tempfile.mkdtemp(prefix="peft_adapter_")
213
+ shutil.copytree(repo_dir, os.path.join(tmp_dir, "adapter"), dirs_exist_ok=True)
214
+ cfg_path = os.path.join(tmp_dir, "adapter", "adapter_config.json")
215
+
216
+ try:
217
+ with open(cfg_path, "r") as f:
218
+ cfg = json.load(f)
219
+ except Exception as e:
220
+ raise RuntimeError(f"Failed to read adapter_config.json: {e}")
221
+
222
+ # Remove keys PEFT LoraConfig doesn't recognize
223
+ for k in ["corda_config", "CoRDA_config"]:
224
+ if k in cfg:
225
+ cfg.pop(k)
226
+
227
+ # If DoRA wasn’t used, drop its block too
228
+ if str(cfg.get("use_dora", "false")).lower() in ["false", "0", "no"]:
229
+ if "dora_config" in cfg:
230
+ cfg.pop("dora_config")
231
+
232
+ with open(cfg_path, "w") as f:
233
+ json.dump(cfg, f, indent=2)
234
+
235
+ return os.path.join(tmp_dir, "adapter")
236
+
237
  def main():
238
  demo = create_interface()
239
  demo.launch(