"""
PixelArtGen — Text Encoder
A small transformer encoder that converts text prompts into
contextual embeddings for conditioning the pixel art decoder.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
class TextTokenizer:
"""Simple word-level tokenizer for text prompts."""
def __init__(self, vocab: List[str]):
self.word2idx = {w: i for i, w in enumerate(vocab)}
self.idx2word = {i: w for i, w in enumerate(vocab)}
self.pad_idx = self.word2idx.get("<pad>", 0)
self.sos_idx = self.word2idx.get("<sos>", 1)
self.eos_idx = self.word2idx.get("<eos>", 2)
self.unk_idx = self.word2idx.get("<unk>", 3)
self.vocab_size = len(vocab)
def encode(self, text: str, max_len: int = 32) -> torch.Tensor:
"""Tokenize and pad a text prompt."""
words = text.lower().strip().split()
tokens = [self.sos_idx]
for w in words:
tokens.append(self.word2idx.get(w, self.unk_idx))
tokens.append(self.eos_idx)
        # Pad or truncate (keep <eos> as the final token when truncating)
        if len(tokens) > max_len:
            tokens = tokens[:max_len]
            tokens[-1] = self.eos_idx
        else:
            tokens += [self.pad_idx] * (max_len - len(tokens))
return torch.tensor(tokens, dtype=torch.long)
def encode_batch(self, texts: List[str], max_len: int = 32) -> torch.Tensor:
"""Encode a batch of text prompts."""
return torch.stack([self.encode(t, max_len) for t in texts])
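# Illustrative usage of TextTokenizer (toy vocabulary for demonstration only,
# not the model's released vocab):
#   tok = TextTokenizer(["<pad>", "<sos>", "<eos>", "<unk>", "red", "knight"])
#   tok.encode("red knight", max_len=8)
#   # -> tensor([1, 4, 5, 2, 0, 0, 0, 0])  i.e. <sos> red knight <eos> <pad>*4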
class TextEncoder(nn.Module):
"""
Small transformer encoder for text prompts.
Architecture:
- Word embeddings + sinusoidal positional encoding
- N transformer encoder layers with multi-head attention
    - Output: sequence of contextual embeddings (batch, seq_len, d_model)
    A usage sketch follows the class definition below.
    """
def __init__(
self,
vocab_size: int,
d_model: int = 256,
nhead: int = 4,
num_layers: int = 3,
dim_feedforward: int = 512,
max_seq_len: int = 32,
dropout: float = 0.1,
):
super().__init__()
self.d_model = d_model
self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_seq_len)
self.dropout = nn.Dropout(dropout)
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
batch_first=True,
norm_first=True,
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
self.norm = nn.LayerNorm(d_model)
def forward(self, text_tokens: torch.Tensor) -> torch.Tensor:
"""
Args:
text_tokens: (batch, seq_len) long tensor of word indices
Returns:
(batch, seq_len, d_model) contextual embeddings
"""
        # Build key-padding mask: True marks positions the attention should ignore
        # (assumes pad index 0, matching padding_idx=0 in the embedding)
        pad_mask = (text_tokens == 0)
# Embed + positional encode
x = self.embedding(text_tokens) * math.sqrt(self.d_model)
x = self.pos_encoding(x)
x = self.dropout(x)
# Transformer encode
x = self.transformer(x, src_key_padding_mask=pad_mask)
x = self.norm(x)
return x
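# Illustrative end-to-end sketch (uses the toy `tok` from the example above;
# shapes assume the default hyperparameters):
#   encoder = TextEncoder(vocab_size=tok.vocab_size)
#   tokens = tok.encode_batch(["red knight", "blue slime"], max_len=32)  # (2, 32)
#   context = encoder(tokens)  # (2, 32, 256)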
class SinusoidalPositionalEncoding(nn.Module):
"""Standard sinusoidal positional encoding."""
def __init__(self, d_model: int, max_len: int = 512):
super().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # (1, max_len, d_model)
self.register_buffer("pe", pe)
def forward(self, x: torch.Tensor) -> torch.Tensor:
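        """Add positional encodings to x of shape (batch, seq_len, d_model)."""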
return x + self.pe[:, :x.size(1)]
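# Minimal smoke test (a sketch with an assumed toy vocabulary, not part of the
# released artifacts): tokenize two prompts and check the encoder's output shape.
if __name__ == "__main__":
    vocab = ["<pad>", "<sos>", "<eos>", "<unk>", "a", "red", "knight", "blue", "slime"]
    tok = TextTokenizer(vocab)
    encoder = TextEncoder(vocab_size=tok.vocab_size)
    tokens = tok.encode_batch(["a red knight", "a blue slime"], max_len=32)  # (2, 32)
    with torch.no_grad():
        context = encoder(tokens)
    assert context.shape == (2, 32, encoder.d_model)
    print("context embeddings:", tuple(context.shape))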