# OpenGPT / model / layers.py
# Uploaded by VolodymyrPugachov ("Upload 17 files", commit 6810eb1, verified)
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttention(nn.Module):
    """Multi-head self-attention with a causal (lower-triangular) mask.

    Queries, keys and values are produced by one fused linear layer, split
    into ``n_heads`` heads of size ``hidden_dim // n_heads``, combined with
    scaled dot-product attention, merged back and passed through an output
    projection. The same dropout module is applied to the attention weights
    and to the projected output.

    Args:
        hidden_dim: Size of the last dimension of the input and output.
        n_heads: Number of attention heads; must divide ``hidden_dim``.
        dropout: Dropout probability.
    """

    def __init__(self, hidden_dim: int, n_heads: int, dropout: float):
        super().__init__()
        assert hidden_dim % n_heads == 0, "hidden_dim must be divisible by n_heads"
        self.n_heads = n_heads
        self.head_dim = hidden_dim // n_heads
        # Linear projection for query, key, value (combined for efficiency)
        self.qkv = nn.Linear(hidden_dim, hidden_dim * 3)
        # Linear projection for output
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape ``(batch, seq_length, hidden_dim)``.

        Returns:
            Tensor of shape ``(batch, seq_length, hidden_dim)``.
        """
        batch_size, seq_length, hidden_dim = x.size()

        # Project queries, keys, and values: (batch, seq_length, 3*hidden_dim)
        qkv = self.qkv(x)

        # Split into Q, K, V and reshape for multi-head attention
        qkv = qkv.reshape(batch_size, seq_length, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, batch, n_heads, seq_length, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each (batch, n_heads, seq_length, head_dim)

        # Scaled dot-product attention: (batch, n_heads, seq_length, seq_length)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Causal mask: position i may only attend to positions <= i. The
        # diagonal is always kept, so no softmax row is entirely -inf.
        causal_mask = torch.tril(
            torch.ones(seq_length, seq_length, dtype=torch.bool, device=x.device)
        )
        attn_scores = attn_scores.masked_fill(~causal_mask, float('-inf'))

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Weighted sum of values: (batch, n_heads, seq_length, head_dim)
        attn_output = torch.matmul(attn_weights, v)

        # Combine heads: move the head axis next to head_dim, then flatten the
        # two into hidden_dim. The tensor is 4-D, so the valid axes are 0..3.
        attn_output = attn_output.permute(0, 2, 1, 3).reshape(
            batch_size, seq_length, hidden_dim
        )

        # Final linear projection and dropout
        output = self.out_proj(attn_output)
        output = self.dropout(output)
        return output
class FeedForward(nn.Module):
    """Position-wise MLP: expand to ``4 * hidden_dim``, GELU, project back.

    Args:
        hidden_dim: Size of the last dimension of the input and output.
        dropout: Dropout probability applied to the MLP output.
    """

    def __init__(self, hidden_dim: int, dropout: float):
        super().__init__()
        inner_dim = 4 * hidden_dim
        self.fc1 = nn.Linear(hidden_dim, inner_dim)
        self.fc2 = nn.Linear(inner_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``dropout(fc2(gelu(fc1(x))))``; the last dimension is unchanged."""
        expanded = self.fc1(x)
        projected = self.fc2(F.gelu(expanded))
        return self.dropout(projected)
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: attention, then feed-forward.

    Each sub-layer is applied to a layer-normalised copy of its input and the
    result is added back to the un-normalised input (residual connection).

    Args:
        hidden_dim: Size of the last dimension of the input and output.
        n_heads: Number of attention heads passed to ``SelfAttention``.
        dropout: Dropout probability passed to both sub-layers.
    """

    def __init__(self, hidden_dim: int, n_heads: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)
        self.attn = SelfAttention(hidden_dim, n_heads, dropout)
        self.ff = FeedForward(hidden_dim, dropout)

    def forward(self, x):
        """Run the attention and feed-forward sub-layers, each with a residual."""
        # Self-attention sub-layer with residual connection
        x = x + self.attn(self.ln1(x))
        # Feed-forward sub-layer with residual connection
        return x + self.ff(self.ln2(x))