Ayushnangia commited on
Commit
2bdb98a
·
verified ·
1 Parent(s): 7916fa2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -2
app.py CHANGED
@@ -1,5 +1,7 @@
1
  # login as a privileged user.
2
  import os
 
 
3
  HF_TOKEN = os.environ.get("HF_TOKEN")
4
 
5
  from huggingface_hub import login
@@ -50,7 +52,17 @@ if torch.cuda.is_available():
50
  model = AutoModelForCausalLM.from_pretrained(
51
  model_id, device_map="cuda", torch_dtype=torch.bfloat16
52
  )
53
- reft_model = ReftModel.load("Ayushnangia/Lossfunk-Residency-Llama-3-8B-Instruct", model, from_huggingface_hub=True)
 
 
 
 
 
 
 
 
 
 
54
  reft_model.set_device("cuda")
55
  tokenizer = AutoTokenizer.from_pretrained(model_id)
56
  tokenizer.use_default_system_prompt = True
@@ -209,7 +221,6 @@ chat_interface = gr.ChatInterface(
209
 
210
  with gr.Blocks(css="style.css") as demo:
211
  gr.Markdown(DESCRIPTION)
212
- gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
213
  chat_interface.render()
214
  gr.Markdown(LICENSE)
215
 
 
1
  # login as a privileged user.
2
  import os
3
+ import subprocess
4
+
5
  HF_TOKEN = os.environ.get("HF_TOKEN")
6
 
7
  from huggingface_hub import login
 
52
  model = AutoModelForCausalLM.from_pretrained(
53
  model_id, device_map="cuda", torch_dtype=torch.bfloat16
54
  )
55
+ # Define local repository path for the ReFT model files
56
+ repo_local_path = "./Lossfunk-Residency-Llama-3-8B-Instruct"
57
+ if not os.path.exists(repo_local_path):
58
+ print("Local repository not found. Cloning repository using git lfs...")
59
+ subprocess.run(
60
+ ["git", "lfs", "clone", "https://huggingface.co/Ayushnangia/Lossfunk-Residency-Llama-3-8B-Instruct", repo_local_path],
61
+ check=True
62
+ )
63
+
64
+ # Load the ReFT model from the local repository
65
+ reft_model = ReftModel.load(repo_local_path, model, from_huggingface_hub=False)
66
  reft_model.set_device("cuda")
67
  tokenizer = AutoTokenizer.from_pretrained(model_id)
68
  tokenizer.use_default_system_prompt = True
 
221
 
222
  with gr.Blocks(css="style.css") as demo:
223
  gr.Markdown(DESCRIPTION)
 
224
  chat_interface.render()
225
  gr.Markdown(LICENSE)
226