Model Description
This is a Transformer-based causal language model trained on the Jamaican Tiny Stories dataset. The model architecture is a decoder-only transformer with positional encoding.
- Model type: Causal Language Model
- Architecture: Transformer Decoder
- Tokenizer: Based on the tokenizer from "roneneldan/TinyStories-33M"
- Vocabulary Size: 50257
Training Code
https://colab.research.google.com/drive/1BcWKK3jBNvJk-uDpG3zFyN3HWsLP8nPo?usp=sharing
Training Data
The model was trained on the "Jaia-Admin/JamaicanTinyStories_400K_sample" dataset, which contains Jamaican Patois stories.
Training Procedure
- Epochs: 20 (though training was interrupted)
- Batch Size: 16
- Optimizer: AdamW with a learning rate of 5e-5
- Loss Function: Cross-Entropy Loss
- Mixed Precision Training: Enabled using `torch.amp.autocast` and `GradScaler`
- Learning Rate Scheduler: ReduceLROnPlateau
Model Parameters
- Embedding Size: 768
- Number of Layers: 3
- Number of Heads: 8
- Feed-Forward Dimension: 768 * 4 = 3072
- Maximum Sequence Length: 512
Usage
The model can be used to generate text based on a given prompt. The generate method in the TransformerGenerator class handles the text generation with options for controlling the output (temperature, top-k, top-p).
from transformers import AutoTokenizer
import math
import torch
import os
from torch import nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM
# Run on GPU when available; this module-level `device` is also used by the
# model-loading and generation code further down the file.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the tokenizer (downloads from the Hugging Face Hub on first use)
model_name = "JAMAICA-AI/JamaicanTinyStories"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Define model parameters (should match the trained model checkpoint;
# mismatched sizes would make load_state_dict fail later)
vocab_size = tokenizer.vocab_size
embed_size = 768
num_layers = 3
num_heads = 8
ff_dim = 768 * 4   # feed-forward hidden dim = 4x the embedding size, per the card above
max_len = 512      # maximum sequence length the positional table will cover
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to token embeddings.

    This is the non-learned encoding from "Attention Is All You Need":
    even channels hold sine terms, odd channels hold cosine terms, each at
    a geometrically decreasing frequency.
    """

    def __init__(self, d_model=768, dropout=0.1, max_len=1024):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # One frequency per sin/cos channel pair: exp(-2i * ln(10000) / d_model)
        frequencies = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * frequencies)
        table[:, 1::2] = torch.cos(positions * frequencies)
        # Registered as a buffer (not a Parameter): it follows .to(device)/
        # state_dict but is never updated by the optimizer.
        self.register_buffer('pe', table)

    def forward(self, x):
        # x: (batch, seq_len, d_model) — add the first seq_len rows of the
        # table, broadcast over the batch dimension.
        return self.dropout(x + self.pe[:x.size(1), :].unsqueeze(0))
class TransformerGenerator(nn.Module):
    """Decoder-only causal language model.

    Uses nn.TransformerDecoder in self-attention-only fashion: the same
    (embedded) sequence is fed as both target and memory, with a causal
    mask on both paths so position i can only attend to positions <= i.
    """

    def __init__(self, vocab_size, embed_size, num_layers, num_heads, ff_dim, max_len=512):
        super(TransformerGenerator, self).__init__()
        self.token_embedding = nn.Embedding(vocab_size, embed_size)
        self.pos_encoding = PositionalEncoding(d_model=embed_size, max_len=max_len)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_size,
            nhead=num_heads,
            dim_feedforward=ff_dim,
            batch_first=True,
            activation=F.gelu,
            bias=False
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.max_len = max_len

    def forward(self, inputs, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """Return logits of shape (batch, seq_len, vocab_size).

        inputs: (batch, seq_len) token ids. If tgt_mask is None a causal
        mask is built automatically; a caller-supplied mask is honored
        (the original ignored these arguments entirely).
        """
        x = self.token_embedding(inputs)
        x = self.pos_encoding(x)
        if tgt_mask is None:
            # BUG FIX: build the mask on the *input's* device, not the
            # module-level global `device` (which broke when the model was
            # moved to a different device than the global default).
            seq_len = x.size(1)
            tgt_mask = torch.full(
                (seq_len, seq_len), float('-inf'), device=x.device
            ).triu(diagonal=1)
        if memory_mask is None:
            # In a decoder-only model, target and memory are the same
            # sequence, so the memory path needs the same causal mask.
            memory_mask = tgt_mask
        output = self.transformer_decoder(
            x, x,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.fc_out(output)

    def generate(self, prompt_ids, max_length=100, temperature=1.0, top_k=50, top_p=0.95, eos_token_id=None):
        """Autoregressively sample up to `max_length` new tokens.

        prompt_ids: (batch, prompt_len) token ids. Returns the prompt with
        the sampled continuation appended. Sampling is controlled by
        `temperature`, `top_k`, and `top_p` (nucleus); pass None to disable
        top-k / top-p. Stops early when every row emits `eos_token_id`.
        """
        self.eval()
        with torch.no_grad():
            input_ids = prompt_ids.to(self.fc_out.weight.device)
            for _ in range(max_length):
                # BUG FIX: the positional table only covers self.max_len
                # positions; stop instead of crashing past that point.
                if input_ids.size(1) >= self.max_len:
                    break
                # forward() builds the causal mask itself; the original
                # built a bool mask, masked_fill'ed it with -inf (a no-op
                # on a bool tensor) and passed it to a forward that
                # discarded it anyway.
                logits = self.forward(input_ids)[:, -1, :] / temperature
                if top_k is not None:
                    # Clamp so topk never exceeds the vocabulary size.
                    k = min(top_k, logits.size(-1))
                    kth_value = torch.topk(logits, k)[0][:, [-1]]
                    logits[logits < kth_value] = -float('Inf')
                if top_p is not None:
                    sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    # Remove tokens with cumulative probability above the
                    # threshold, but always keep the most-probable token.
                    sorted_remove = cumulative_probs > top_p
                    sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
                    sorted_remove[..., 0] = False
                    # BUG FIX: scatter the mask back to vocabulary order
                    # per row; the original fancy-indexed across rows and
                    # corrupted every batch element for batch > 1.
                    remove = torch.zeros_like(sorted_remove).scatter(
                        -1, sorted_indices, sorted_remove
                    )
                    logits = logits.masked_fill(remove, -float('Inf'))
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                input_ids = torch.cat([input_ids, next_token], dim=1)
                # BUG FIX: .item() raised for batch > 1; stop once every
                # row has produced the end-of-sequence token.
                if eos_token_id is not None and bool((next_token == eos_token_id).all()):
                    break
        return input_ids
# Load the tokenizer and the trained weights from the Hugging Face Hub.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# The base tokenizer ships without a pad token; reuse EOS for padding, the
# usual convention for GPT-style models. (The original also had no-op
# self-assignments of bos_token/eos_token and a duplicated vocab_size line.)
tokenizer.pad_token = tokenizer.eos_token
vocab_size = tokenizer.vocab_size
model = TransformerGenerator(
    vocab_size=vocab_size,
    embed_size=768,
    num_layers=3,
    num_heads=8,
    ff_dim=768 * 4,
    max_len=512
)
model_path = hf_hub_download(repo_id=model_name, filename="pytorch_model.bin")
# weights_only=True: a state dict contains only tensors, and this blocks
# arbitrary-code-execution via pickle in a downloaded checkpoint.
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
model.to(device)
def text(prompt):
    """Generate a continuation of `prompt`, print it, and return it.

    Returns the decoded string (the original returned None; callers that
    ignored the return value are unaffected). Relies on the module-level
    `model`, `tokenizer`, and `device`.
    """
    # model.generate() already switches to eval mode and runs under
    # torch.no_grad(), so no extra wrapping is needed here.
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    # Generate text (output includes the prompt tokens followed by the
    # sampled continuation)
    output = model.generate(input_ids, max_length=100, temperature=0.7, eos_token_id=tokenizer.eos_token_id)
    # Decode and print the generated text
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    print(generated_text)
    return generated_text

# Example usage
text("Once upon a time")
- Downloads last month
- -
Inference Providers
NEW
This model isn't deployed by any Inference Provider.
🙋
Ask for provider support