arjun-ms commited on
Commit
c4e10fc
·
verified ·
1 Parent(s): e9b5538

feat: merge lora adapters to mistral base model

Browse files
Files changed (1) hide show
  1. app.py +36 -16
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import torch
3
  import gradio as gr
4
  from huggingface_hub import login
@@ -8,11 +9,17 @@ from peft import PeftModel
8
  # HF token from environment
9
  token = os.environ.get("HF_TOKEN")
10
  if token:
 
11
  login(token)
 
 
12
 
13
  # Model & adapter
14
- base_id = "mistralai/Mistral-7B-Instruct-v0.2"
15
- lora_id = "arjun-ms/mistral-lora-ipc"
 
 
 
16
 
17
  # ✅ Quantization config
18
  bnb_config = BitsAndBytesConfig(
@@ -23,24 +30,37 @@ bnb_config = BitsAndBytesConfig(
23
  llm_int8_enable_fp32_cpu_offload=True # CPU offloading
24
  )
25
 
26
- # Tokenizer
27
- tokenizer = AutoTokenizer.from_pretrained(base_id)
28
 
29
- # Base model with quantization
30
- base_model = AutoModelForCausalLM.from_pretrained(
31
- base_id,
32
- quantization_config=bnb_config,
33
- device_map="auto", # Automatically offloads to CPU/GPU/disk
34
- trust_remote_code=True,
35
- torch_dtype=torch.float16 # Helps reduce memory use
36
- )
37
 
 
 
 
 
 
 
 
 
38
 
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
 
42
- # Apply your LoRA adapter
43
- model = PeftModel.from_pretrained(base_model, lora_id)
44
 
45
  # Response generator
46
  def generate_response(prompt):
@@ -60,8 +80,8 @@ iface = gr.Interface(
60
  fn=generate_response,
61
  inputs="text",
62
  outputs="text",
63
- title="IPC LoRA on Mistral 7B",
64
  description="LoRA fine-tuned Mistral 7B for Indian Penal Code questions"
65
  )
66
 
67
- iface.launch()
 
1
  import os
2
+ from pydantic import BaseModel
3
  import torch
4
  import gradio as gr
5
  from huggingface_hub import login
 
9
  # HF token from environment
10
  token = os.environ.get("HF_TOKEN")
11
  if token:
12
+ print("HF TOKEN FOUND AND LOADING....")
13
  login(token)
14
+ else:
15
+ print("NO HF TOKEN FOUND!")
16
 
17
  # Model & adapter
18
+ # base_id = "mistralai/Mistral-7B-Instruct-v0.2"
19
+ # lora_id = "arjun-ms/mistral-lora-ipc"
20
+
21
+ # Merged model path
22
+ merged_model_id = "arjun-ms/mistral-7b-ipc-merged"
23
 
24
  # ✅ Quantization config
25
  bnb_config = BitsAndBytesConfig(
 
30
  llm_int8_enable_fp32_cpu_offload=True # CPU offloading
31
  )
32
 
 
 
33
 
34
+ # # Tokenizer
35
+ # tokenizer = AutoTokenizer.from_pretrained(base_id)
 
 
 
 
 
 
36
 
37
+ # # Base model with quantization
38
+ # base_model = AutoModelForCausalLM.from_pretrained(
39
+ # base_id,
40
+ # quantization_config=bnb_config,
41
+ # device_map="auto", # Automatically offloads to CPU/GPU/disk
42
+ # trust_remote_code=True,
43
+ # torch_dtype=torch.float16 # Helps reduce memory use
44
+ # )
45
 
46
 
47
+ # Apply your LoRA adapter
48
+ # model = PeftModel.from_pretrained(base_model, lora_id)
49
+
50
+
51
+ # ✅ Tokenizer for merged model
52
+ tokenizer = AutoTokenizer.from_pretrained(merged_model_id)
53
+
54
+ # ✅ Load merged model
55
+ model = AutoModelForCausalLM.from_pretrained(
56
+ merged_model_id,
57
+ quantization_config=bnb_config,
58
+ device_map="auto",
59
+ trust_remote_code=True,
60
+ torch_dtype=torch.float16
61
+ )
62
 
63
 
 
 
64
 
65
  # Response generator
66
  def generate_response(prompt):
 
80
  fn=generate_response,
81
  inputs="text",
82
  outputs="text",
83
+ title="IPC Mistral 7B",
84
  description="LoRA fine-tuned Mistral 7B for Indian Penal Code questions"
85
  )
86
 
87
+ iface.launch()