"""
PixelArtGen — Text Encoder

A small transformer encoder that converts text prompts into contextual
embeddings used to condition the pixel-art decoder.
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List


class TextTokenizer:
    """Simple word-level tokenizer for text prompts.

    Prompts are lower-cased and split on whitespace; each word is mapped to
    its index in ``vocab`` (unknown words map to the UNK index). Special
    tokens are looked up by name in ``vocab``; when a name is absent the
    fallback index is used (pad=0, sos=1, eos=2, unk=3).
    """

    PAD_TOKEN = "<pad>"
    SOS_TOKEN = "<sos>"
    EOS_TOKEN = "<eos>"
    UNK_TOKEN = "<unk>"

    def __init__(self, vocab: List[str]):
        self.word2idx = {w: i for i, w in enumerate(vocab)}
        self.idx2word = {i: w for i, w in enumerate(vocab)}
        self.pad_idx = self.word2idx.get(self.PAD_TOKEN, 0)
        self.sos_idx = self.word2idx.get(self.SOS_TOKEN, 1)
        self.eos_idx = self.word2idx.get(self.EOS_TOKEN, 2)
        self.unk_idx = self.word2idx.get(self.UNK_TOKEN, 3)
        self.vocab_size = len(vocab)

    def encode(self, text: str, max_len: int = 32) -> torch.Tensor:
        """Tokenize a prompt into a fixed-length index tensor.

        The result is ``[SOS, w1, ..., wk, EOS, PAD, ...]``. If that is longer
        than ``max_len`` it is truncated, but EOS is kept as the final token.

        Args:
            text: Raw prompt; lower-cased and split on whitespace.
            max_len: Exact length of the returned tensor (must be >= 1).

        Returns:
            1-D ``torch.long`` tensor of length ``max_len``.

        Raises:
            ValueError: if ``max_len < 1``.
        """
        if max_len < 1:
            raise ValueError(f"max_len must be >= 1, got {max_len}")

        # str.split() with no argument already discards surrounding whitespace.
        words = text.lower().split()
        tokens = [self.sos_idx]
        tokens.extend(self.word2idx.get(w, self.unk_idx) for w in words)
        tokens.append(self.eos_idx)

        if len(tokens) > max_len:
            # Keep EOS so a cut-off prompt is still explicitly terminated.
            tokens = tokens[: max_len - 1] + [self.eos_idx]
        else:
            tokens += [self.pad_idx] * (max_len - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)

    def encode_batch(self, texts: List[str], max_len: int = 32) -> torch.Tensor:
        """Encode several prompts into a ``(len(texts), max_len)`` long tensor.

        An empty input yields an empty ``(0, max_len)`` tensor rather than
        failing inside ``torch.stack``.
        """
        if not texts:
            return torch.empty((0, max_len), dtype=torch.long)
        return torch.stack([self.encode(t, max_len) for t in texts])


class TextEncoder(nn.Module):
    """
    Small transformer encoder for text prompts.

    Architecture:
        - Word embeddings scaled by sqrt(d_model) + sinusoidal positional encoding
        - ``num_layers`` pre-LayerNorm (``norm_first=True``) transformer encoder layers
        - Final LayerNorm
        - Output: sequence of contextual embeddings (batch, seq_len, d_model)

    Args:
        vocab_size: Number of token ids the embedding table accepts.
        d_model: Embedding / hidden size.
        nhead: Number of attention heads.
        num_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden size of each layer's feed-forward block.
        max_seq_len: Longest sequence the positional encoding supports.
        dropout: Dropout probability (embeddings and inside each layer).
        pad_idx: Index of the padding token; should match the tokenizer's
            ``pad_idx``. These positions are excluded from attention as keys
            and use the embedding's ``padding_idx`` row.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 256,
        nhead: int = 4,
        num_layers: int = 3,
        dim_feedforward: int = 512,
        max_seq_len: int = 32,
        dropout: float = 0.1,
        pad_idx: int = 0,
    ):
        super().__init__()
        self.d_model = d_model
        self.pad_idx = pad_idx

        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_seq_len)
        self.dropout = nn.Dropout(dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )
        # Nested-tensor fast path is unavailable with norm_first=True anyway;
        # disabling it explicitly avoids a UserWarning at construction.
        self.transformer = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers, enable_nested_tensor=False
        )
        # Pre-LN stacks leave the final output un-normalized, hence this norm.
        self.norm = nn.LayerNorm(d_model)

    def forward(self, text_tokens: torch.Tensor) -> torch.Tensor:
        """
        Args:
            text_tokens: (batch, seq_len) long tensor of word indices,
                with ``seq_len <= max_seq_len``.

        Returns:
            (batch, seq_len, d_model) contextual embeddings.

        Raises:
            ValueError: if ``seq_len`` exceeds the positional-encoding length.
        """
        # True = key position ignored by attention.
        pad_mask = text_tokens == self.pad_idx

        x = self.embedding(text_tokens) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        x = self.dropout(x)

        x = self.transformer(x, src_key_padding_mask=pad_mask)
        return self.norm(x)


class SinusoidalPositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding (sin on even dims, cos on odd dims).

    Works for both even and odd ``d_model``.
    """

    def __init__(self, d_model: int, max_len: int = 512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd (cosine) column than even one.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encodings to ``x`` of shape (batch, seq_len, d_model).

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds positional encoding max_len {max_len}"
            )
        return x + self.pe[:, :seq_len]