import os

import torch
from torch import nn
from transformers import GPT2Tokenizer
from datasets import load_dataset

# GPT-2 ships without a padding token; reuse EOS for padding.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
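# Sanity check (illustrative; not in the original script): padding now
# resolves to the EOS id, which is 50256 for GPT-2.
assert tokenizer.pad_token_id == tokenizer.eos_token_id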


class ETGAA(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_layers, num_heads,
                 intermediate_size, graph_attention_heads, adaptive_embedding_size,
                 max_seq_length=1024):
        super().__init__()

        # Token embeddings plus learned positional embeddings (added here:
        # without them the attention layers below have no notion of token order).
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.position_embedding = nn.Embedding(max_seq_length, hidden_size)

        # Stack of standard Transformer encoder layers. batch_first=True so
        # all tensors are (batch, seq, hidden), matching the embedding output.
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=hidden_size,
                nhead=num_heads,
                dim_feedforward=intermediate_size,
                activation="gelu",
                batch_first=True,
            )
            for _ in range(num_layers)
        ])

        # Additional self-attention block applied on top of the encoder stack.
        self.graph_attention = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=graph_attention_heads,
            batch_first=True,
        )

        # Projection into the adaptive embedding space. The LayerNorm and the
        # output head must be sized to its output (adaptive_embedding_size,
        # not hidden_size), or the forward pass fails with a shape mismatch.
        self.adaptive_embedding = nn.Linear(hidden_size, adaptive_embedding_size)
        self.layer_norm = nn.LayerNorm(adaptive_embedding_size)
        self.fc_out = nn.Linear(adaptive_embedding_size, vocab_size)

    def forward(self, input_ids, attention_mask=None):
        seq_len = input_ids.size(1)
        positions = torch.arange(seq_len, device=input_ids.device)
        x = self.embedding(input_ids) + self.position_embedding(positions)
        # Causal mask: each position may attend only to itself and earlier
        # positions, as required for next-token language modeling.
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=input_ids.device),
            diagonal=1,
        )
        # Padding positions (attention_mask == 0) are excluded from attention.
        key_padding_mask = attention_mask == 0 if attention_mask is not None else None
        for layer in self.layers:
            x = layer(x, src_mask=causal_mask, src_key_padding_mask=key_padding_mask)
        x, _ = self.graph_attention(x, x, x, attn_mask=causal_mask,
                                    key_padding_mask=key_padding_mask,
                                    need_weights=False)
        x = self.adaptive_embedding(x)
        x = self.layer_norm(x)
        logits = self.fc_out(x)
        return logits


model = ETGAA(
    vocab_size=tokenizer.vocab_size,
    hidden_size=768,
    num_layers=24,
    num_heads=12,
    intermediate_size=3072,
    graph_attention_heads=8,
    adaptive_embedding_size=512,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
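# Optional shape sanity check with a tiny dummy batch (illustrative; not in
# the original script):
with torch.no_grad():
    dummy = torch.randint(0, tokenizer.vocab_size, (2, 16), device=device)
    assert model(dummy).shape == (2, 16, tokenizer.vocab_size)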

dataset = load_dataset("wikitext", "wikitext-103-v1", split="train[:5%]")
# WikiText is line-based and contains many empty rows; drop them so no
# sequence consists purely of padding.
dataset = dataset.filter(lambda example: len(example["text"].strip()) > 0)


def encode(examples):
    return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=1024)


dataset = dataset.map(encode, batched=True)
# Keep the attention mask so padding can be excluded from attention and loss.
dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])

train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
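# Quick look at the batch layout (illustrative; not in the original script):
first_batch = next(iter(train_dataloader))
print(first_batch["input_ids"].shape)  # torch.Size([64, 1024])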

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# Padding targets are set to -100 below, which CrossEntropyLoss ignores.
criterion = nn.CrossEntropyLoss(ignore_index=-100)

model.train()
for step, batch in enumerate(train_dataloader):
    optimizer.zero_grad()
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    outputs = model(input_ids, attention_mask=attention_mask)

    # Next-token objective: the logits at position i are scored against the
    # token at position i + 1, so logits and labels are shifted by one.
    labels = input_ids[:, 1:].clone()
    labels[attention_mask[:, 1:] == 0] = -100
    loss = criterion(
        outputs[:, :-1].reshape(-1, tokenizer.vocab_size),
        labels.reshape(-1),
    )
    loss.backward()
    optimizer.step()

    if step % 100 == 0:
        print(f"Step {step}: Loss {loss.item():.4f}")

    if step > 500000:
        break

# ETGAA is a plain nn.Module, so it has no save_pretrained(); persist the
# weights with torch.save instead. The tokenizer keeps its HF method.
os.makedirs("./etgaa_model", exist_ok=True)
torch.save(model.state_dict(), "./etgaa_model/pytorch_model.bin")
tokenizer.save_pretrained("./etgaa_model")
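# Reloading later (sketch; assumes the same hyperparameters as above):
#   model = ETGAA(vocab_size=tokenizer.vocab_size, hidden_size=768,
#                 num_layers=24, num_heads=12, intermediate_size=3072,
#                 graph_attention_heads=8, adaptive_embedding_size=512)
#   model.load_state_dict(torch.load("./etgaa_model/pytorch_model.bin"))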