# NGen2-170M / model.py
import os
import torch
from torch import nn
from transformers import GPT2Tokenizer
from datasets import load_dataset
# Load GPT-2 tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# GPT-2 has no padding token, so reuse the end-of-text token for padding
tokenizer.pad_token = tokenizer.eos_token
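# Quick sanity check (illustrative addition, not part of the original upload):
# padded positions should now reuse the EOS id (50256) and be masked out.
sample = tokenizer("Hello world", padding="max_length", max_length=8)
print(sample["input_ids"])       # real token ids followed by repeated 50256
print(sample["attention_mask"])  # 1 for real tokens, 0 for padded positions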
# Define the ETGAA model architecture
class ETGAA(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_layers, num_heads, intermediate_size, graph_attention_heads, adaptive_embedding_size):
        super(ETGAA, self).__init__()
        # Token embedding layer
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        # Transformer encoder layers (batch_first so inputs are [batch, seq, hidden])
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=hidden_size,
                nhead=num_heads,
                dim_feedforward=intermediate_size,
                activation='gelu',
                batch_first=True
            ) for _ in range(num_layers)
        ])
        # Graph-attention layer (multi-head self-attention over the encoder output)
        self.graph_attention = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=graph_attention_heads,
            batch_first=True
        )
        # Adaptive embedding projection (hidden_size -> adaptive_embedding_size)
        self.adaptive_embedding = nn.Linear(hidden_size, adaptive_embedding_size)
        # Layer normalization over the projected representation
        self.layer_norm = nn.LayerNorm(adaptive_embedding_size)
        # Output projection to vocabulary logits
        self.fc_out = nn.Linear(adaptive_embedding_size, vocab_size)
    def forward(self, input_ids, attention_mask=None):
        # Token embeddings: [batch, seq, hidden]
        x = self.embedding(input_ids)
        seq_len = x.size(1)
        # Causal mask so each position only attends to earlier positions
        # (needed for the next-token training objective below)
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1)
        # Positions where attention_mask == 0 are padding and are excluded from attention
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None
        # Transformer encoder layers
        for layer in self.layers:
            x = layer(x, src_mask=causal_mask, src_key_padding_mask=key_padding_mask)
        # Graph attention over the encoder output
        x, _ = self.graph_attention(x, x, x, attn_mask=causal_mask, key_padding_mask=key_padding_mask, need_weights=False)
        # Adaptive embedding projection
        x = self.adaptive_embedding(x)
        # Layer normalization
        x = self.layer_norm(x)
        # Output layer: [batch, seq, vocab]
        logits = self.fc_out(x)
        return logits
# Initialize model
model = ETGAA(
    vocab_size=tokenizer.vocab_size,
    hidden_size=768,
    num_layers=24,
    num_heads=12,
    intermediate_size=3072,
    graph_attention_heads=8,
    adaptive_embedding_size=512
)
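# Optional sanity checks before training (illustrative addition): report the parameter
# count and confirm that a dummy forward pass yields [batch, seq, vocab_size] logits.
num_params = sum(p.numel() for p in model.parameters())
print(f"Parameter count: {num_params / 1e6:.1f}M")
with torch.no_grad():
    dummy_ids = torch.randint(0, tokenizer.vocab_size, (2, 16))
    print(model(dummy_ids).shape)  # expected: torch.Size([2, 16, 50257])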
# Load dataset and prepare data loader
dataset = load_dataset("wikitext", "wikitext-103-v1", split="train[:5%]")
# Drop blank rows (wikitext-103 contains many empty lines that would tokenize to pure padding)
dataset = dataset.filter(lambda example: len(example["text"].strip()) > 0)
def encode(examples):
    return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=1024)
dataset = dataset.map(encode, batched=True)
dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
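# Optional: peek at one batch to confirm tensor shapes (illustrative addition).
first_batch = next(iter(train_dataloader))
print(first_batch["input_ids"].shape, first_batch["attention_mask"].shape)  # [64, 1024] each (last batch may be smaller)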
# Training loop: causal language modeling (predict token t+1 from positions <= t)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()
for step, batch in enumerate(train_dataloader):
    optimizer.zero_grad()
    input_ids = batch["input_ids"].to(device, dtype=torch.long)
    attention_mask = batch["attention_mask"].to(device)
    logits = model(input_ids, attention_mask=attention_mask)
    # Shift so that each position predicts the following token
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = input_ids[:, 1:].contiguous()
    loss = loss_fn(shift_logits.view(-1, tokenizer.vocab_size), shift_labels.view(-1))
    loss.backward()
    optimizer.step()
    if step % 100 == 0:
        print(f"Step {step}: Loss {loss.item():.4f}")
    if step > 500000:
        break
# Save the model weights and tokenizer (ETGAA is a plain nn.Module, so use torch.save
# rather than save_pretrained, which only exists on Hugging Face PreTrainedModel classes)
os.makedirs("./etgaa_model", exist_ok=True)
torch.save(model.state_dict(), "./etgaa_model/pytorch_model.bin")
tokenizer.save_pretrained("./etgaa_model")
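# Illustrative greedy-decoding sketch (an assumed usage example, not part of the original
# script): repeatedly run the model on the growing sequence and append the argmax token
# until EOS or a token budget is reached.
def greedy_generate(prompt, max_new_tokens=50):
    model.eval()
    ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(ids)
            next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            ids = torch.cat([ids, next_id], dim=1)
            if next_id.item() == tokenizer.eos_token_id:
                break
    return tokenizer.decode(ids[0], skip_special_tokens=True)

print(greedy_generate("The history of"))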