FarhanAK128 committed on
Commit
19cd814
·
verified ·
1 Parent(s): 9bc876e

Update model_class.py

Browse files
Files changed (1) hide show
  1. model_class.py +74 -0
model_class.py CHANGED
@@ -174,3 +174,77 @@ class CustomGPT(
174
  x = self.final_norm(x)
175
  logits = self.out_head(x) #[2,4,50257]
176
  return logits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  x = self.final_norm(x)
175
  logits = self.out_head(x) #[2,4,50257]
176
  return logits
177
+
178
+ def format_input(self, entry):
179
+ instruction_text = (
180
+ f"Below is an instruction that describes a task. "
181
+ f"Write a response that appropriately completes the request."
182
+ f"\n\n### Instruction:\n{entry['instruction']}"
183
+ )
184
+ input_text = f"\n\n### Input:\n{entry['input']}" if entry["input"] else ""
185
+ return instruction_text + input_text
186
+
187
+ def text_to_token_ids(self, text, tokenizer):
188
+ encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
189
+ encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension
190
+ return encoded_tensor
191
+
192
+ def token_ids_to_text(self, token_ids, tokenizer):
193
+ flat = token_ids.squeeze(0) # remove batch dimension
194
+ return tokenizer.decode(flat.tolist())
195
+
196
+ def generate(self, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):
197
+ for _ in range(max_new_tokens):
198
+ idx_cond = idx[:, -context_size:]
199
+ with torch.no_grad():
200
+ logits = self(idx_cond)
201
+ logits = logits[:, -1, :]
202
+
203
+ if top_k is not None:
204
+ # Keep only top_k values
205
+ top_logits, _ = torch.topk(logits, top_k)
206
+ min_val = top_logits[:, -1] # select the last element i.e., the smallest from each batch's output
207
+ logits = torch.where(logits < min_val, torch.tensor(float("-inf")).to(logits.device), logits)
208
+
209
+ # New: Apply temperature scaling
210
+ if temperature > 0.0:
211
+ logits = logits / temperature
212
+
213
+ # Apply softmax to get probabilities
214
+ probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)
215
+
216
+ # Sample from the distribution
217
+ idx_next = torch.multinomial(probs, num_samples=1) # (batch_size, 1)
218
+
219
+ # Otherwise same as before: get idx of the vocab entry with the highest logits value
220
+ else:
221
+ idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch_size, 1)
222
+
223
+ if idx_next == eos_id: # Stop generating early if end-of-sequence token is encountered and eos_id is specified
224
+ break
225
+
226
+ # Same as before: append sampled index to the running sequence
227
+ idx = torch.cat((idx, idx_next), dim=1) # (batch_size, num_tokens+1)
228
+
229
+ return idx
230
+
231
+ def generate_response(self, input_entry, temperature=0.0, topk=None):
232
+ current_device = next(self.parameters()).device
233
+ self.eval()
234
+ input_text = self.format_input(entry)
235
+
236
+ token_ids = generate(
237
+ idx=self.text_to_token_ids(input_text, tokenizer).to(current_device),
238
+ max_new_tokens=256,
239
+ context_size=1024,
240
+ temperatue=temperature,
241
+ topk=topk,
242
+ eos_id=50256
243
+ )
244
+ generated_text = self.token_ids_to_text(token_ids, tokenizer)
245
+ response_text = (
246
+ generated_text[len(input_text):]
247
+ .replace("### Response:", "")
248
+ .strip()
249
+ )
250
+ return response_text.strip()