import torch
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput

from configuration_ngwanda import NgwandaConfig
from models.transformer_blocks import ShonaTransformer
import constants


class NgwandaModel(PreTrainedModel):
    config_class = NgwandaConfig

    def __init__(self, config):
        super().__init__(config)
        self.transformer = ShonaTransformer(
            accelerate_instance=None,
            d=config.d_model,
            H=config.attention_heads,
            T=config.sequence_length,
            V=config.vocab_size,
            layers=config.layers,
        )
        self.post_init()

    def forward(self, input_ids, **kwargs):
        # The backbone returns a tuple; only the logits (first element) are used here.
        logits, _ = self.transformer(input_ids)
        return CausalLMOutput(logits=logits)

    def generate(self, input_ids, max_new_tokens, temperature=1.0, do_sample=True, **kwargs):
        generated = input_ids
        self.eval()
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Crop the running sequence to the model's maximum context length.
                if generated.size(1) <= self.config.sequence_length:
                    idx_cond = generated
                else:
                    idx_cond = generated[:, -self.config.sequence_length:]
                logits, _ = self.transformer(idx_cond)

                # Take the logits at the last position and apply temperature scaling.
                next_token_logits = logits[:, -1, :] / temperature

                if do_sample:
                    probs = torch.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

                generated = torch.cat([generated, next_token], dim=1)
        return generated
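

# Minimal usage sketch (illustrative only): the hyperparameter values below are
# assumptions for demonstration, not values taken from this repository's configs,
# and it is assumed that NgwandaConfig accepts these field names as keyword
# arguments. The snippet only shows how the wrapper and its generate() loop are
# intended to be called.
if __name__ == "__main__":
    config = NgwandaConfig(
        d_model=256,          # assumed value
        attention_heads=4,    # assumed value
        sequence_length=128,  # assumed value
        vocab_size=32000,     # assumed value
        layers=4,             # assumed value
    )
    model = NgwandaModel(config)

    # Random prompt of 8 token ids standing in for real tokenizer output.
    prompt = torch.randint(0, config.vocab_size, (1, 8))
    output_ids = model.generate(prompt, max_new_tokens=16, temperature=0.8, do_sample=True)
    print(output_ids.shape)  # should be (1, 8 + 16) = (1, 24)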