# ShonaLLM — model/modeling_ngwanda.py
# Uploaded via huggingface_hub by takuM23 (commit 49c6760).
import torch
from transformers import PreTrainedModel
from configuration_ngwanda import NgwandaConfig
from models.transformer_blocks import ShonaTransformer
import constants
class NgwandaModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``ShonaTransformer``.

    Exposes the custom transformer through the standard HF interface
    (``config_class``, ``forward``, ``generate``) so it can be loaded with
    ``from_pretrained`` using :class:`NgwandaConfig`.
    """

    config_class = NgwandaConfig

    def __init__(self, config):
        super().__init__(config)
        # Map HF config fields onto the transformer's constructor arguments.
        self.transformer = ShonaTransformer(
            accelerate_instance=None,
            d=config.d_model,
            H=config.attention_heads,
            T=config.sequence_length,
            V=config.vocab_size,
            layers=config.layers,
        )
        # Standard HF weight init / tying hook.
        self.post_init()

    def forward(self, input_ids, **kwargs):
        """Run a forward pass and return an object exposing ``.logits``.

        ``kwargs`` (e.g. ``attention_mask``) are accepted for HF API
        compatibility but are ignored by the underlying transformer.
        """
        logits, _ = self.transformer(input_ids)
        # BUG FIX: the original returned the dynamically created *class*
        # itself, not an instance (missing trailing "()").  Attribute access
        # happened to work on the class object, but it was not a real output
        # instance; instantiate it so callers get a proper object.
        return type('CausalLMOutput', (object,), {'logits': logits})()

    def generate(self, input_ids, max_new_tokens, temperature=1.0, do_sample=True, **kwargs):
        """Autoregressively generate ``max_new_tokens`` tokens.

        Args:
            input_ids: ``(batch, seq)`` tensor of prompt token ids.
            max_new_tokens: number of tokens to append to the prompt.
            temperature: softmax temperature; must be non-zero (division).
            do_sample: multinomial sampling when True, greedy argmax otherwise.

        Returns:
            ``(batch, seq + max_new_tokens)`` tensor of token ids.
        """
        # NOTE: switches the whole module to eval mode as a side effect.
        self.eval()
        generated = input_ids
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Crop the running context to the model's maximum sequence
                # length so the transformer never sees an over-long input.
                if generated.size(1) <= self.config.sequence_length:
                    idx_cond = generated
                else:
                    idx_cond = generated[:, -self.config.sequence_length:]
                logits, _ = self.transformer(idx_cond)
                # Temperature-scale only the final position's logits.
                next_token_logits = logits[:, -1, :] / temperature
                if do_sample:
                    probs = torch.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    # Greedy decoding; scaling does not change the argmax.
                    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
                generated = torch.cat([generated, next_token], dim=1)
        return generated