Ryan Robson committed on
Commit
2f0f602
·
1 Parent(s): 2f155f3

Fix LoRA model loading

Browse files

- Load base Mistral-7B model first
- Apply LoRA adapter with PeftModel
- Add peft to requirements
- Fix torch_dtype deprecation warning

Files changed (2) hide show
  1. app.py +11 -5
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,21 +1,27 @@
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
3
  import torch
4
 
5
  print("πŸ”„ Loading CCISD TEKS Educational Assistant...")
6
  print(" This may take 1-2 minutes on first launch...")
7
 
8
- # Load your trained model
9
- MODEL_NAME = "robworks-software/ccisd-teks-educator-mistral7b"
 
10
 
11
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
12
  model = AutoModelForCausalLM.from_pretrained(
13
- MODEL_NAME,
14
- torch_dtype=torch.float16,
15
  device_map="auto",
16
  low_cpu_mem_usage=True
17
  )
18
 
 
 
 
19
  print("βœ… Model loaded successfully!")
20
 
21
  def chat(message, history):
 
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ from peft import PeftModel
4
  import torch
5
 
6
  print("πŸ”„ Loading CCISD TEKS Educational Assistant...")
7
  print(" This may take 1-2 minutes on first launch...")
8
 
9
+ # Load base model + LoRA adapter
10
+ BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
11
+ ADAPTER_MODEL = "robworks-software/ccisd-teks-educator-mistral7b"
12
 
13
+ print(f"πŸ“₯ Loading base model: {BASE_MODEL}...")
14
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
15
  model = AutoModelForCausalLM.from_pretrained(
16
+ BASE_MODEL,
17
+ dtype=torch.float16,
18
  device_map="auto",
19
  low_cpu_mem_usage=True
20
  )
21
 
22
+ print(f"πŸ”§ Loading LoRA adapter: {ADAPTER_MODEL}...")
23
+ model = PeftModel.from_pretrained(model, ADAPTER_MODEL)
24
+
25
  print("βœ… Model loaded successfully!")
26
 
27
  def chat(message, history):
requirements.txt CHANGED
@@ -2,5 +2,6 @@ gradio>=4.0.0
2
  transformers>=4.35.0
3
  torch>=2.0.0
4
  accelerate>=0.24.0
 
5
  sentencepiece>=0.1.99
6
  protobuf>=3.20.0
 
2
  transformers>=4.35.0
3
  torch>=2.0.0
4
  accelerate>=0.24.0
5
+ peft>=0.7.0
6
  sentencepiece>=0.1.99
7
  protobuf>=3.20.0