from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from .config import LlamaConfig
from .model import LlamaModel


class LlamaForCausalLM(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.model = LlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Share the input embedding matrix with the output projection when requested.
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        self._init_weights()

    def _init_weights(self):
        """Initialize weights for all layers."""
        # Token embeddings use a small normal init (std = 1/24).
        if hasattr(self.model, 'embed_tokens'):
            nn.init.normal_(self.model.embed_tokens.weight, mean=0.0, std=0.041666666666666664)

        for module in self.modules():
            if isinstance(module, nn.Linear):
                # Skip any linear whose weight is shared with the embedding matrix
                # (the tied lm_head); otherwise the xavier init below would overwrite
                # the embedding init above.
                if hasattr(self.model, 'embed_tokens') and module.weight is self.model.embed_tokens.weight:
                    continue
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the final hidden states together with the lm_head weight matrix.

        Logits (and any loss against ``labels``) are computed by the caller from
        this pair, as ``generate`` does below; ``labels`` is accepted only for
        interface compatibility and is not used here.
        """
        hidden_states = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        return hidden_states, self.lm_head.weight

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: int = 30,
        temperature: float = 0.0,
    ) -> torch.LongTensor:
        self.eval()
        bsz, seq_len = input_ids.shape

        # Positions 0 .. seq_len-1, broadcast across the batch.
        position_ids = repeat(
            torch.arange(seq_len, device=input_ids.device),
            'l -> b l',
            b=bsz,
        )

        for _ in range(max_new_tokens):
            # No KV cache: the full sequence is re-encoded on every step.
            hidden_states, classifier_weights = self.forward(input_ids, position_ids=position_ids)

            # Project only the last position's hidden state onto the vocabulary.
            next_token_logits = hidden_states[:, -1] @ classifier_weights.T

            if temperature == 0:
                # Greedy decoding.
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            else:
                # Temperature sampling.
                scaled_logits = next_token_logits / temperature
                probs = torch.softmax(scaled_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            # Append the sampled token and its position, then continue decoding.
            input_ids = torch.cat([input_ids, next_token], dim=1)
            new_position_ids = position_ids[:, -1:] + 1
            position_ids = torch.cat([position_ids, new_position_ids], dim=1)

        return input_ids
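

# --- Usage sketch -----------------------------------------------------------
# A minimal, hypothetical example of how this class might be exercised.  It
# assumes LlamaConfig accepts hidden_size, vocab_size, and tie_word_embeddings
# as constructor keyword arguments (an assumption about config.py, not something
# this file guarantees) and that LlamaModel maps token ids to hidden states of
# shape (batch, seq_len, hidden_size).
if __name__ == "__main__":
    # Assumed LlamaConfig constructor signature.
    config = LlamaConfig(hidden_size=256, vocab_size=32000, tie_word_embeddings=True)
    model = LlamaForCausalLM(config)

    prompt = torch.randint(0, config.vocab_size, (2, 8))  # dummy token ids
    position_ids = repeat(torch.arange(prompt.shape[1]), 'l -> b l', b=prompt.shape[0])

    # forward() returns hidden states plus the classifier weights; logits and
    # the loss are computed by the caller.
    hidden_states, classifier_weights = model(prompt, position_ids=position_ids)
    logits = hidden_states @ classifier_weights.T  # (2, 8, vocab_size)
    loss = F.cross_entropy(
        logits[:, :-1].reshape(-1, config.vocab_size),
        prompt[:, 1:].reshape(-1),
    )

    generated = model.generate(prompt, max_new_tokens=5, temperature=0.8)
    print(loss.item(), generated.shape)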