rishabhsetiya committed
Commit eff248a · verified · 1 Parent(s): f58dfef

Create fine_tuning.py

Files changed (1)
  1. fine_tuning.py +194 -0
fine_tuning.py ADDED
import os
import json
import math
import torch
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from datasets import Dataset
import transformers
from transformers import AutoModelForCausalLM, DataCollatorForSeq2Seq, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model

# -----------------------------
# ENVIRONMENT / CACHE
# -----------------------------
# TRANSFORMERS_CACHE is deprecated in recent transformers releases in favor of
# HF_HOME; both are set so either version resolves the same cache directory.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"
os.environ["HF_HOME"] = "/tmp/huggingface_cache"
os.environ["HF_DATASETS_CACHE"] = "/tmp/huggingface_cache"
os.environ["HF_METRICS_CACHE"] = "/tmp/huggingface_cache"
os.environ["WANDB_MODE"] = "disabled"

# -----------------------------
# SETTINGS
# -----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = transformers.AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# -----------------------------
# LoRA / MoE Modules
# -----------------------------
class LoraLinear(nn.Module):
    """A linear layer that contributes only a LoRA delta: a frozen zero base
    weight plus a trainable low-rank update scaled by lora_alpha / r."""

    def __init__(self, in_features, out_features, r=8, lora_alpha=16, lora_dropout=0.05, bias=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.r = r
        self.scaling = lora_alpha / r if r > 0 else 1.0
        # Zero-initialize the frozen base weight (torch.empty would leave
        # uninitialized memory that corrupts the forward pass): each expert
        # should add only its LoRA delta, since the real pretrained weights
        # live in the wrapped base_linear of MoELoRALinear.
        self.weight = nn.Parameter(torch.zeros(out_features, in_features), requires_grad=False)
        self.bias = nn.Parameter(torch.zeros(out_features), requires_grad=False) if bias else None

        if r > 0:
            self.lora_A = nn.Parameter(torch.zeros((r, in_features)))
            self.lora_B = nn.Parameter(torch.zeros((out_features, r)))
            # Standard LoRA init: A random, B zero, so the delta starts at zero.
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_A, self.lora_B, self.lora_dropout = None, None, None

    def forward(self, x):
        result = F.linear(x, self.weight, self.bias)
        if self.r > 0:
            lora_out = self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T
            result = result + self.scaling * lora_out
        return result

class MoELoRALinear(nn.Module):
    """Wraps a frozen base linear layer with a mixture of LoRA experts,
    combined by a learned softmax gate over the input features."""

    def __init__(self, base_linear, r, num_experts=2, k=1, lora_alpha=16, lora_dropout=0.05):
        super().__init__()
        self.base_linear = base_linear
        self.num_experts = num_experts
        # k is accepted for a future top-k routing scheme but is currently
        # unused: the forward pass uses dense softmax gating over all experts.
        self.k = k
        self.experts = nn.ModuleList([
            LoraLinear(base_linear.in_features, base_linear.out_features, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
            for _ in range(num_experts)
        ])
        self.gate = nn.Linear(base_linear.in_features, num_experts)

    def forward(self, x):
        base_out = self.base_linear(x)
        gate_scores = torch.softmax(self.gate(x), dim=-1)
        expert_out = 0
        for i, expert in enumerate(self.experts):
            # Each expert contributes its LoRA delta, weighted per token by the gate.
            expert_out += gate_scores[..., i:i+1] * expert(x)
        return base_out + expert_out

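# Note: because each LoraLinear expert starts with a zero frozen weight and a
# zero-initialized lora_B, MoELoRALinear reproduces base_linear exactly at
# initialization; training only learns the gated LoRA deltas on top of it.
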
def replace_proj_with_moe_lora(model, r=8, num_experts=2, k=1, lora_alpha=16, lora_dropout=0.05):
    """Swap the MLP up/down projections of every decoder layer for MoE-LoRA
    wrappers. Assumes a Llama-style architecture (model.model.layers[i].mlp)."""
    for layer in model.model.layers:
        for proj_name in ["up_proj", "down_proj"]:
            old = getattr(layer.mlp, proj_name)
            moe = MoELoRALinear(
                base_linear=old,
                r=r,
                num_experts=num_experts,
                k=k,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
            ).to(next(old.parameters()).device)
            setattr(layer.mlp, proj_name, moe)
    return model

# -----------------------------
# DATA PREPROCESSING
# -----------------------------
def preprocess(example):
    tokens = tokenizer(example['text'], truncation=True, padding=False)
    text = example['text']
    assistant_index = text.find("<|assistant|>")
    # Tokenize the prefix with special tokens so its length (BOS included)
    # lines up with tokens['input_ids']; tokenizing without special tokens,
    # as an earlier revision did, leaves the mask off by one. Token merges at
    # the split point may still shift the boundary slightly for some inputs.
    prefix_ids = tokenizer(text[:assistant_index])['input_ids']
    prefix_len = len(prefix_ids)
    labels = tokens['input_ids'].copy()
    labels[:prefix_len] = [-100] * prefix_len
    tokens['labels'] = labels
    return tokens

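# For a sample formatted as "<|system|>\n...</s>\n<|user|>\nQ</s>\n<|assistant|>\nA</s>",
# every token before "<|assistant|>" gets label -100, so the loss is computed
# only on the assistant tag and the answer tokens.
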
# -----------------------------
# LOAD & TRAIN MODEL
# -----------------------------
def load_and_train(model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
    current_dir = os.path.dirname(os.path.abspath(__file__))
    json_file_path = os.path.join(current_dir, 'makemytrip_qa_full.json')

    with open(json_file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    df = pd.DataFrame(data)
    print(f"Loaded dataset containing {len(df)} questions")

    system_prompt = "You are a helpful assistant that provides financial data from MakeMyTrip reports."
    training_data = [
        {"text": f"<|system|>\n{system_prompt}</s>\n<|user|>\n{row['question']}</s>\n<|assistant|>\n{row['answer']}</s>"}
        for _, row in df.iterrows()
    ]
    dataset = Dataset.from_list(training_data)
    tokenized_dataset = dataset.map(preprocess, remove_columns=["text"])

    base_model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to(device)
    model = replace_proj_with_moe_lora(base_model)
    # PEFT additionally attaches standard LoRA adapters to the attention output projections.
    peft_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05, target_modules=["o_proj"], bias="none", task_type="CAUSAL_LM")
    model = get_peft_model(model, peft_config)

    model.config.use_cache = False
    model.gradient_checkpointing_disable()

    # DataCollatorForLanguageModeling(mlm=False) would rebuild labels from
    # input_ids, discarding the prompt mask built in preprocess;
    # DataCollatorForSeq2Seq pads and preserves the precomputed labels.
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, label_pad_token_id=-100)

    training_args = TrainingArguments(
        learning_rate=5e-5,
        output_dir="./results",
        num_train_epochs=2,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,  # effective batch size of 4
        logging_steps=1,
        save_steps=10,
        save_total_limit=2,
        fp16=(device == "cuda"),  # fp16=True raises an error on CPU-only runs
        bf16=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator
    )

    print("Training started")
    trainer.train()
    model.eval()
    return model, tokenizer, device

# -----------------------------
# GENERATE ANSWER
# -----------------------------
def generate_answer(model, tokenizer, device, prompt, max_tokens=200):
    if prompt.strip() == "":
        return "Please enter a prompt!"

    system_prompt = "You are a helpful assistant that provides financial data from MakeMyTrip reports."
    messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}]
    input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(input_text, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            top_p=0.9,
            temperature=0.7,
        )

    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # TinyLlama's chat template renders role markers like <|assistant|> as plain
    # text, so they survive skip_special_tokens and can anchor the answer split.
    answer_start_token = '<|assistant|>'
    answer_start_index = decoded_output.rfind(answer_start_token)
    if answer_start_index != -1:
        generated_answer = decoded_output[answer_start_index + len(answer_start_token):].strip()
        # skip_special_tokens normally strips the EOS already; this is a fallback.
        if generated_answer.endswith('</s>'):
            generated_answer = generated_answer[:-len('</s>')].strip()
    else:
        generated_answer = "Could not extract answer from model output."

    return generated_answer
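
# Minimal usage sketch (assumes makemytrip_qa_full.json sits next to this
# file; the sample question below is hypothetical):
if __name__ == "__main__":
    model, tokenizer, device = load_and_train()
    print(generate_answer(model, tokenizer, device, "What was MakeMyTrip's total revenue last fiscal year?"))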