| """
|
| PixelArtGen — Text Encoder
|
|
|
| A small transformer encoder that converts text prompts into
|
| contextual embeddings for conditioning the pixel art decoder.
|
| """
|
|
|
| import math
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from typing import List
|
|
|
|
|
class TextTokenizer:
    """Simple word-level tokenizer for text prompts."""

    def __init__(self, vocab: List[str]):
        self.word2idx = {w: i for i, w in enumerate(vocab)}
        self.idx2word = {i: w for i, w in enumerate(vocab)}
        # Special-token indices; the integer fallbacks assume the vocab
        # begins <pad>, <sos>, <eos>, <unk>, in that order.
        self.pad_idx = self.word2idx.get("<pad>", 0)
        self.sos_idx = self.word2idx.get("<sos>", 1)
        self.eos_idx = self.word2idx.get("<eos>", 2)
        self.unk_idx = self.word2idx.get("<unk>", 3)
        self.vocab_size = len(vocab)
    def encode(self, text: str, max_len: int = 32) -> torch.Tensor:
        """Tokenize a text prompt, padding or truncating it to max_len."""
        words = text.lower().strip().split()
        tokens = [self.sos_idx]
        for w in words:
            tokens.append(self.word2idx.get(w, self.unk_idx))
        tokens.append(self.eos_idx)

        if len(tokens) > max_len:
            # Truncate but keep a trailing <eos> so the sequence stays well formed.
            tokens = tokens[:max_len - 1] + [self.eos_idx]
        else:
            tokens += [self.pad_idx] * (max_len - len(tokens))

        return torch.tensor(tokens, dtype=torch.long)

    def encode_batch(self, texts: List[str], max_len: int = 32) -> torch.Tensor:
        """Encode a batch of text prompts into a (batch, max_len) long tensor."""
        return torch.stack([self.encode(t, max_len) for t in texts])
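# A quick sketch of the round trip (the toy vocabulary here is hypothetical;
# it just puts the four special tokens where the defaults in __init__ expect
# them):
#
#   vocab = ["<pad>", "<sos>", "<eos>", "<unk>", "a", "pixel", "knight"]
#   tok = TextTokenizer(vocab)
#   tok.encode("a pixel knight", max_len=8)
#   # -> tensor([1, 4, 5, 6, 2, 0, 0, 0])
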
class TextEncoder(nn.Module):
    """
    Small transformer encoder for text prompts.

    Architecture:
    - Word embeddings + sinusoidal positional encoding
    - N transformer encoder layers with multi-head attention
    - Output: sequence of contextual embeddings (batch, seq_len, d_model)
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 256,
        nhead: int = 4,
        num_layers: int = 3,
        dim_feedforward: int = 512,
        max_seq_len: int = 32,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model
        # padding_idx=0 matches TextTokenizer's default <pad> index, so pad
        # embeddings stay zero and receive no gradient updates.
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        # Sinusoidal positions (class defined at the bottom of this module).
        self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_seq_len)
        self.dropout = nn.Dropout(dropout)

        # Pre-norm layers (norm_first=True) leave the residual stream
        # unnormalized, so a final LayerNorm is applied after the stack.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.norm = nn.LayerNorm(d_model)
    def forward(self, text_tokens: torch.Tensor) -> torch.Tensor:
        """
        Args:
            text_tokens: (batch, seq_len) long tensor of word indices
        Returns:
            (batch, seq_len, d_model) contextual embeddings
        """
        # True where a position is padding; index 0 matches the embedding's
        # padding_idx and the tokenizer's default <pad> index.
        pad_mask = text_tokens == 0

        # Scale embeddings by sqrt(d_model), as in "Attention Is All You Need",
        # then add positional information.
        x = self.embedding(text_tokens) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        x = self.dropout(x)

        # Padded positions are excluded from attention via the key padding mask.
        x = self.transformer(x, src_key_padding_mask=pad_mask)
        x = self.norm(x)
        return x
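# Shape sketch with the defaults above (the vocab size is illustrative):
#
#   enc = TextEncoder(vocab_size=1000)
#   tokens = torch.ones(2, 32, dtype=torch.long)  # token id 1, so no padding
#   enc(tokens).shape
#   # -> torch.Size([2, 32, 256])
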
class SinusoidalPositionalEncoding(nn.Module):
    """
    Standard sinusoidal positional encoding (Vaswani et al., 2017):

        PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
        PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
    """

    def __init__(self, d_model: int, max_len: int = 512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Geometric ladder of frequencies, computed in log space for stability.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        pe = pe.unsqueeze(0)  # (1, max_len, d_model), broadcast over batch
        # Buffer, not Parameter: moves with .to(device) but is never trained.
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Add the first seq_len positions; x is (batch, seq_len, d_model).
        return x + self.pe[:, : x.size(1)]
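
# Minimal smoke test, runnable as a script. The toy vocabulary and prompts
# are made up for illustration; real training would use a corpus-derived vocab.
if __name__ == "__main__":
    vocab = ["<pad>", "<sos>", "<eos>", "<unk>",
             "a", "red", "pixel", "knight", "castle"]
    tokenizer = TextTokenizer(vocab)
    encoder = TextEncoder(vocab_size=tokenizer.vocab_size)
    encoder.eval()  # disable dropout for a deterministic check

    tokens = tokenizer.encode_batch(["a red knight", "a pixel castle"])
    with torch.no_grad():
        embeddings = encoder(tokens)

    print(tokens.shape)      # torch.Size([2, 32])
    print(embeddings.shape)  # torch.Size([2, 32, 256])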