omdeep22 committed on
Commit
0e7ffd2
·
verified ·
1 Parent(s): 27ddd0d

Fix: float causal mask, weight tying, attention_mask, config aliases

Browse files
Files changed (1) hide show
  1. modeling_konkan.py +28 -1
modeling_konkan.py CHANGED
@@ -117,4 +117,31 @@ class KonkanGPT(PreTrainedModel, GenerationMixin):
117
  return CausalLMOutput(loss=loss, logits=logits)
118
 
119
  def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
120
- return {"input_ids": input_ids, "attention_mask": attention_mask}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  return CausalLMOutput(loss=loss, logits=logits)
118
 
119
  def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
120
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
121
+
122
def chat(self, tokenizer, query: str, max_new_tokens: int = 200, temperature: float = 0.7) -> str:
    """Generate a single-turn reply to ``query`` via the tokenizer's chat template.

    The query is wrapped as one ``user`` message, rendered with
    ``tokenizer.apply_chat_template`` (generation prompt included), tokenized,
    and passed to ``self.generate``. Only the newly generated tokens are
    decoded, so the reply can never contain the prompt and is not affected by
    words such as "assistant" appearing inside the generated text.

    Args:
        tokenizer: Tokenizer exposing ``apply_chat_template``, ``__call__``,
            ``decode``, ``pad_token_id`` and ``eos_token_id``.
        query: The user's message.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. A positive value enables nucleus
            sampling (``top_p=0.9``); ``0``, a negative value or ``None``
            switches to greedy decoding instead of raising inside ``generate``.

    Returns:
        The decoded reply with special tokens removed and surrounding
        whitespace stripped.
    """
    device = next(self.parameters()).device

    # 1. Render the single-turn conversation with the chat template.
    messages = [{"role": "user", "content": query}]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # 2. Tokenize. add_special_tokens=False is kept from the original call;
    #    presumably the rendered template already carries its own markers.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(device)
    prompt_len = inputs["input_ids"].shape[-1]

    # Tokenizers without a dedicated pad token reuse EOS for padding.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": 1.2,
        "pad_token_id": pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if temperature is not None and temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=temperature, top_p=0.9)
    else:
        # Sampling requires a strictly positive temperature; decode greedily.
        generation_kwargs["do_sample"] = False

    # 3. Generate.
    outputs = self.generate(**inputs, **generation_kwargs)

    # 4. The returned sequence starts with the prompt tokens; drop them and
    #    decode only the continuation instead of string-splitting the full text.
    new_tokens = outputs[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()