|
|
""" |
|
|
Mini-Transformer Embedding Model |
|
|
==================================== |
|
|
A lightweight transformer encoder for generating text embeddings. |
|
|
Built from scratch using PyTorch. |
|
|
|
|
|
Architecture: |
|
|
- Token Embeddings + Sinusoidal Positional Encoding |
|
|
- N Transformer Encoder Layers (Pre-LayerNorm) |
|
|
- Multi-Head Self-Attention |
|
|
- Position-wise Feed-Forward Networks |
|
|
- Mean Pooling + L2 Normalization |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import math |
|
|
from typing import Optional |
|
|
|
|
|
|
|
|
class SinusoidalPositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding from "Attention Is All You Need".

    Adds position information to token embeddings using sin/cos functions
    at different frequencies, allowing the model to understand token order.
    Supports both even and odd ``d_model``.
    """

    def __init__(self, d_model: int, max_seq_len: int = 512, dropout: float = 0.1):
        """
        Args:
            d_model: Embedding dimension.
            max_seq_len: Longest sequence length the precomputed table supports.
            dropout: Dropout rate applied after adding the encoding.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Precompute the [max_seq_len, d_model] encoding table once; it is
        # fixed (non-learnable), so it is registered as a buffer below.
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)

        # Frequencies form a geometric progression: wavelengths from 2*pi
        # up to 10000*2*pi across the embedding dimensions.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        angles = position * div_term  # [max_seq_len, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine column,
        # so trim the last angle column before assigning. (The original code
        # raised a shape-mismatch error whenever d_model was odd.)
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # [1, max_seq_len, d_model] so it broadcasts across the batch dim.
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape [batch_size, seq_len, d_model]
        Returns:
            Tensor with positional encoding added
        """
        # Slice the table to the actual sequence length, then add.
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
|
|
|
|
|
|
|
|
class MultiHeadSelfAttention(nn.Module):
    """
    Multi-Head Self-Attention mechanism.

    Allows the model to jointly attend to information from different
    representation subspaces at different positions.
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        """
        Args:
            d_model: Model dimension; must be divisible by num_heads.
            num_heads: Number of parallel attention heads.
            dropout: Dropout rate applied to the attention weights.

        Raises:
            ValueError: If d_model is not divisible by num_heads.
        """
        super().__init__()
        # Explicit exception instead of `assert`, which is silently stripped
        # when Python runs with -O.
        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head dimension

        # Combined projections for all heads; split into heads in forward().
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # Output projection mixing information across heads.
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)  # sqrt(d_k) scaling from the paper

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            x: Input tensor [batch_size, seq_len, d_model]
            attention_mask: Optional mask [batch_size, seq_len]; positions
                with value 0 are excluded from attention.
        Returns:
            Output tensor [batch_size, seq_len, d_model]
        """
        batch_size, seq_len, _ = x.size()

        # Project inputs to queries, keys and values.
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        # Reshape to [batch, heads, seq_len, d_k].
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product scores: [batch, heads, seq_len, seq_len].
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        if attention_mask is not None:
            # Broadcast [batch, seq_len] -> [batch, 1, 1, seq_len].
            mask = attention_mask.unsqueeze(1).unsqueeze(2)
            # Use the dtype's minimum instead of -inf: for a fully masked
            # row (fully padded sequence) softmax over all -inf yields NaN,
            # whereas a finite fill degrades gracefully to a uniform
            # distribution over (ignored) positions.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Attention-weighted sum of values.
        context = torch.matmul(attn_weights, V)

        # Merge heads back into [batch, seq_len, d_model]; contiguous() is
        # required before view() after the transpose.
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

        output = self.W_o(context)

        return output
|
|
|
|
|
|
|
|
class PositionwiseFeedForward(nn.Module):
    """
    Position-wise Feed-Forward Network.

    Expands each token vector from d_model to d_ff, applies a GELU
    non-linearity (with dropout), and projects back down to d_model.
    The identical transformation is applied at every sequence position.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor [batch_size, seq_len, d_model]
        Returns:
            Output tensor [batch_size, seq_len, d_model]
        """
        # expand -> activate -> regularize -> project back
        hidden = self.dropout(F.gelu(self.linear1(x)))
        return self.linear2(hidden)
|
|
|
|
|
|
|
|
class TransformerEncoderLayer(nn.Module):
    """
    Single Transformer Encoder Layer with Pre-LayerNorm.

    Each sublayer (self-attention, then feed-forward) is applied as
    ``x + dropout(sublayer(norm(x)))``: the input is normalized *before*
    the sublayer and added back through a residual connection, which is
    more stable to train than the original Post-LN formulation.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout: float = 0.1
    ):
        super().__init__()

        # One LayerNorm per sublayer, applied before the sublayer (Pre-LN).
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # The two sublayers.
        self.attention = MultiHeadSelfAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)

        # Dropout on each sublayer's output before the residual add.
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            x: Input tensor [batch_size, seq_len, d_model]
            attention_mask: Optional mask [batch_size, seq_len]
        Returns:
            Output tensor [batch_size, seq_len, d_model]
        """
        # Sublayer 1: pre-norm multi-head self-attention + residual.
        residual = x
        x = residual + self.dropout(self.attention(self.norm1(residual), attention_mask))

        # Sublayer 2: pre-norm position-wise feed-forward + residual.
        residual = x
        x = residual + self.dropout(self.feed_forward(self.norm2(residual)))

        return x
|
|
|
|
|
|
|
|
class MiniTransformerEmbedding(nn.Module):
    """
    Mini-Transformer Embedding Model.

    Converts variable-length text sequences into fixed-size dense vectors
    suitable for semantic similarity, search, and clustering tasks.

    Architecture:
    1. Token Embedding Layer (vocab → d_model)
    2. Sinusoidal Positional Encoding
    3. N Transformer Encoder Layers
    4. Mean Pooling (sequence → single vector)
    5. L2 Normalization (for cosine similarity)
    """

    def __init__(
        self,
        vocab_size: int = 30000,
        d_model: int = 256,
        num_heads: int = 4,
        num_layers: int = 4,
        d_ff: int = 1024,
        max_seq_len: int = 128,
        dropout: float = 0.1,
        pad_token_id: int = 0
    ):
        """
        Args:
            vocab_size: Number of rows in the token embedding table.
            d_model: Hidden/embedding dimension used throughout the model.
            num_heads: Attention heads per encoder layer.
            num_layers: Number of stacked encoder layers.
            d_ff: Inner dimension of the feed-forward sublayers.
            max_seq_len: Maximum sequence length supported by the
                positional-encoding table.
            dropout: Dropout rate used across the model.
            pad_token_id: ID of the padding token; its embedding row is
                kept at zero.
        """
        super().__init__()

        self.d_model = d_model
        # Stored so forward() can validate input length with a clear error.
        self.max_seq_len = max_seq_len
        self.pad_token_id = pad_token_id

        # padding_idx zeroes the pad row and blocks its gradient updates.
        self.token_embedding = nn.Embedding(
            vocab_size, d_model, padding_idx=pad_token_id
        )

        self.positional_encoding = SinusoidalPositionalEncoding(
            d_model, max_seq_len, dropout
        )

        self.layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        # Final LayerNorm: standard for Pre-LN stacks, which otherwise end
        # on an un-normalized residual stream.
        self.final_norm = nn.LayerNorm(d_model)

        self._init_weights()

    def _init_weights(self):
        """Initialize weights using Xavier/Glorot initialization."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0, std=0.02)
                if module.padding_idx is not None:
                    # Re-zero the pad row clobbered by the normal_ init above.
                    nn.init.zeros_(module.weight[module.padding_idx])

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Forward pass through the encoder.

        Args:
            input_ids: Token IDs [batch_size, seq_len]
            attention_mask: Mask for padding [batch_size, seq_len]

        Returns:
            Token-level representations [batch_size, seq_len, d_model]

        Raises:
            ValueError: If seq_len exceeds max_seq_len. (Previously this
                surfaced as a cryptic broadcasting error inside the
                positional encoding.)
        """
        seq_len = input_ids.size(1)
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len "
                f"{self.max_seq_len}"
            )

        # Scale embeddings by sqrt(d_model), as in "Attention Is All You Need".
        x = self.token_embedding(input_ids) * math.sqrt(self.d_model)

        # Inject token-order information.
        x = self.positional_encoding(x)

        # Run the encoder stack.
        for layer in self.layers:
            x = layer(x, attention_mask)

        x = self.final_norm(x)

        return x

    def encode(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Encode input tokens to a single embedding vector per sequence.

        Uses mean pooling over non-padded tokens, followed by L2 normalization.

        Args:
            input_ids: Token IDs [batch_size, seq_len]
            attention_mask: Mask for padding [batch_size, seq_len]

        Returns:
            Normalized embeddings [batch_size, d_model]
        """
        token_embeddings = self.forward(input_ids, attention_mask)

        if attention_mask is not None:
            # Mean over real (non-padded) tokens only.
            mask_expanded = attention_mask.unsqueeze(-1).float()

            sum_embeddings = torch.sum(token_embeddings * mask_expanded, dim=1)

            # Clamp so fully-padded sequences do not divide by zero.
            sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)

            embeddings = sum_embeddings / sum_mask
        else:
            # No mask given: plain mean over the whole sequence.
            embeddings = torch.mean(token_embeddings, dim=1)

        # Unit-norm vectors, so dot product == cosine similarity.
        embeddings = F.normalize(embeddings, p=2, dim=1)

        return embeddings
|
|
|