Update modelling_sparrow.py
Browse files- modelling_sparrow.py +74 -38
modelling_sparrow.py
CHANGED
|
@@ -214,14 +214,7 @@ class SparrowModel(PreTrainedModel):
|
|
| 214 |
self.decoder.append(SparrowDecoderLayer(config=self.config, layer_idx=layer_idx))
|
| 215 |
|
| 216 |
self.norm = RMSNorm(dim=self.config.hidden_size)
|
| 217 |
-
self.output = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=self.config.mlp_bias)
|
| 218 |
-
self.token_embedding.weight = self.output.weight
|
| 219 |
self.apply(self.weights_init)
|
| 220 |
-
self.loss = None
|
| 221 |
-
|
| 222 |
-
for pn, p in self.named_parameters():
|
| 223 |
-
if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
|
| 224 |
-
torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * self.config.num_hidden_layers))
|
| 225 |
|
| 226 |
def weights_init(self, module):
|
| 227 |
if isinstance(module, nn.Linear):
|
|
@@ -233,51 +226,94 @@ class SparrowModel(PreTrainedModel):
|
|
| 233 |
if module.padding_idx is not None:
|
| 234 |
module.weight.data[module.padding_idx].zero_()
|
| 235 |
|
| 236 |
-
def forward(self, input_ids,
|
| 237 |
x = self.dropout(self.token_embedding(input_ids))
|
| 238 |
|
| 239 |
for idx, layer in enumerate(self.decoder):
|
| 240 |
x = layer(x=x, use_kv_cache=use_kv_cache)
|
| 241 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
if labels is not None:
|
| 243 |
logits = self.output(x)
|
| 244 |
-
self.loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=
|
| 245 |
else:
|
| 246 |
logits = self.output(x[:, [-1], :])
|
| 247 |
self.loss = None
|
| 248 |
|
| 249 |
return CausalLMOutputWithPast(self.loss, logits)
|
| 250 |
|
| 251 |
-
@torch.
|
| 252 |
-
def generate(self, input_ids, eos, max_new_tokens, temperature=0.7, top_k=None,
|
| 253 |
-
use_kv_cache=True):
|
| 254 |
-
|
| 255 |
s = input_ids.shape[1]
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
else:
|
| 267 |
-
logits = logits / temperature
|
| 268 |
-
if top_k is not None:
|
| 269 |
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
| 270 |
-
logits[logits < v[:, [-1]]] = -float('Inf')
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
if idx_next == eos:
|
| 276 |
break
|
| 277 |
-
|
| 278 |
-
input_ids = torch.cat((input_ids, idx_next), dim=1)
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
if not stream:
|
| 283 |
-
yield input_ids[:, s:]
|
|
|
|
| 214 |
self.decoder.append(SparrowDecoderLayer(config=self.config, layer_idx=layer_idx))
|
| 215 |
|
| 216 |
self.norm = RMSNorm(dim=self.config.hidden_size)
|
|
|
|
|
|
|
| 217 |
self.apply(self.weights_init)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
|
| 219 |
def weights_init(self, module):
|
| 220 |
if isinstance(module, nn.Linear):
|
|
|
|
| 226 |
if module.padding_idx is not None:
|
| 227 |
module.weight.data[module.padding_idx].zero_()
|
| 228 |
|
| 229 |
+
def forward(self, input_ids, use_kv_cache=False):
    """Run token ids through the embedding, all decoder layers, and the final norm.

    Args:
        input_ids: LongTensor of token ids — presumably shape (batch, seq_len);
            confirm against the tokenizer/caller.
        use_kv_cache: forwarded to every decoder layer; NOTE(review) — looks like
            it enables incremental-decoding state inside SparrowDecoderLayer,
            confirm there.

    Returns:
        RMS-normalised hidden states from the last decoder layer
        (batch, seq_len, hidden_size).
    """
    x = self.dropout(self.token_embedding(input_ids))
    # The enumerate() index was never used — iterate the layers directly.
    for layer in self.decoder:
        x = layer(x=x, use_kv_cache=use_kv_cache)
    return self.norm(x)
|
| 236 |
+
|
| 237 |
+
class SparrowModelForCausalLM(SparrowModel):
    """Causal-LM head on top of SparrowModel: tied LM head, loss computation,
    and autoregressive generation (greedy sampling or beam search)."""

    def __init__(self, config):
        super().__init__(config)
        # LM head projecting hidden states to vocabulary logits.
        self.output = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=self.config.mlp_bias)
        # Weight tying: the input embedding shares the LM head's weight matrix.
        # NOTE(review): this replaces the embedding weights initialised by the
        # base class with nn.Linear's default init — confirm intended.
        self.token_embedding.weight = self.output.weight
        # Last computed loss; set as a side effect of forward() when labels given.
        self.loss = None

        # GPT-2-style scaled init for residual-path projections
        # (std scaled by 1/sqrt(2 * num_layers)).
        for pn, p in self.named_parameters():
            if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * self.config.num_hidden_layers))

    def forward(self, input_ids, labels=None, use_kv_cache=False):
        """Compute logits (and, when labels are given, cross-entropy loss).

        With labels: logits over the full sequence, loss stored on self.loss
        (ignore_index=-100). Without labels: logits for the last position only
        (inference fast path) and self.loss reset to None.
        """
        x = super().forward(input_ids, use_kv_cache)

        if labels is not None:
            logits = self.output(x)
            self.loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-100)
        else:
            # `[-1]` (list index) keeps the sequence dim: shape (batch, 1, vocab).
            logits = self.output(x[:, [-1], :])
            self.loss = None

        return CausalLMOutputWithPast(self.loss, logits)

    @torch.no_grad()
    def generate(self, input_ids, eos=1, max_new_tokens=50, temperature=0.7, top_k=None, repetition_penalty=1.,
                 use_kv_cache=True, use_beam_search=False, beam_size=3):
        """Autoregressive decoding.

        Beam search returns the best sequence's new-token ids as a list;
        greedy/sampled search returns the generated token ids as a list
        (eos itself is not included).

        NOTE(review): both loops bound generation at max_new_tokens - 1 new
        tokens — confirm the off-by-one is intentional (eos counted?).
        NOTE(review): the full sequence is re-fed each step even with
        use_kv_cache=True — cache benefit depends on layer internals; verify.
        """
        # Prompt length; used to slice off the prompt from beam-search output.
        s = input_ids.shape[1]

        if use_beam_search:
            sequences = [(input_ids, 0)]  # List of (sequence, cumulative log probability)
            for _ in range(max_new_tokens - 1):
                all_candidates = []
                for seq, score in sequences:
                    inference_res = self(seq, labels=None, use_kv_cache=use_kv_cache)
                    # Last-position logits; this is a view, edited in place below.
                    logits = inference_res.logits[:, -1, :]

                    # NOTE(review): dividing penalises positive logits but
                    # *boosts* negative ones; CTRL/HF multiply negatives by the
                    # penalty instead — confirm intended.
                    if repetition_penalty != 1.0:
                        for token in set(seq.tolist()[0]):
                            logits[:, token] /= repetition_penalty

                    logits = logits / temperature if temperature > 0 else logits
                    probs = F.log_softmax(logits, dim=-1)
                    # Expand each beam by its beam_size most likely next tokens.
                    top_log_prob, idx_next = torch.topk(probs, beam_size, dim=-1)

                    for i in range(beam_size):
                        next_seq = torch.cat((seq, idx_next[:, i].unsqueeze(1)), dim=1)
                        next_score = score + top_log_prob[:, i].item()
                        all_candidates.append((next_seq, next_score))

                # Keep the beam_size highest-scoring candidates.
                sequences = sorted(all_candidates, key=lambda x: x[1], reverse=True)[:beam_size]
                # Stop early once every surviving beam ends in eos.
                if all(seq[0][:, -1].item() == eos for seq in sequences):
                    break

            best_seq = sequences[0][0]
            # Strip the prompt; return only newly generated ids.
            return best_seq.tolist()[0][s:]

        # Greedy search (default)
        generated_tokens = []
        while len(generated_tokens) < max_new_tokens - 1:
            inference_res = self(input_ids, labels=None, use_kv_cache=use_kv_cache)
            logits = inference_res.logits[:, -1, :]

            # Same repetition-penalty caveat as in the beam-search branch.
            if repetition_penalty != 1.0:
                for token in set(input_ids.tolist()[0]):
                    logits[:, token] /= repetition_penalty

            if temperature == 0.0:
                # Temperature 0 => deterministic argmax decoding.
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None:
                    # Mask everything below the k-th largest logit.
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)

            # eos terminates generation and is NOT appended to the output.
            if idx_next.item() == eos:
                break

            input_ids = torch.cat((input_ids, idx_next), dim=1)
            generated_tokens.append(idx_next.item())

        return generated_tokens
|
|
|
|
|
|