| """ | |
| PixelArtGen — Text Encoder | |
| A small transformer encoder that converts text prompts into | |
| contextual embeddings for conditioning the pixel art decoder. | |
| """ | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import List | |


class TextTokenizer:
    """Simple word-level tokenizer for text prompts."""

    def __init__(self, vocab: List[str]):
        self.word2idx = {w: i for i, w in enumerate(vocab)}
        self.idx2word = {i: w for i, w in enumerate(vocab)}
        # Special-token indices fall back to 0-3 if the tokens are missing from
        # the vocabulary, so the vocab should start with <pad>, <sos>, <eos>, <unk>.
        self.pad_idx = self.word2idx.get("<pad>", 0)
        self.sos_idx = self.word2idx.get("<sos>", 1)
        self.eos_idx = self.word2idx.get("<eos>", 2)
        self.unk_idx = self.word2idx.get("<unk>", 3)
        self.vocab_size = len(vocab)

    def encode(self, text: str, max_len: int = 32) -> torch.Tensor:
        """Tokenize and pad a text prompt."""
        words = text.lower().strip().split()
        tokens = [self.sos_idx]
        for w in words:
            tokens.append(self.word2idx.get(w, self.unk_idx))
        tokens.append(self.eos_idx)
        # Pad or truncate (keep <eos> as the final token when truncating)
        if len(tokens) > max_len:
            tokens = tokens[:max_len - 1] + [self.eos_idx]
        else:
            tokens += [self.pad_idx] * (max_len - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)

    def encode_batch(self, texts: List[str], max_len: int = 32) -> torch.Tensor:
        """Encode a batch of text prompts."""
        return torch.stack([self.encode(t, max_len) for t in texts])
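

# Example usage (illustrative; assumes a toy vocabulary whose first four
# entries are <pad>, <sos>, <eos>, <unk>, matching the default indices above):
#   tok = TextTokenizer(["<pad>", "<sos>", "<eos>", "<unk>", "pixel", "knight"])
#   tok.encode("pixel knight")          # shape (32,): [1, 4, 5, 2, 0, 0, ...]
#   tok.encode_batch(["pixel knight"])  # shape (1, 32)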


class TextEncoder(nn.Module):
    """
    Small transformer encoder for text prompts.

    Architecture:
        - Word embeddings + sinusoidal positional encoding
        - N transformer encoder layers with multi-head attention
        - Output: sequence of contextual embeddings (batch, seq_len, d_model)
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 256,
        nhead: int = 4,
        num_layers: int = 3,
        dim_feedforward: int = 512,
        max_seq_len: int = 32,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model
        # padding_idx must match TextTokenizer.pad_idx (0 by default)
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_seq_len)
        self.dropout = nn.Dropout(dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, text_tokens: torch.Tensor) -> torch.Tensor:
        """
        Args:
            text_tokens: (batch, seq_len) long tensor of word indices

        Returns:
            (batch, seq_len, d_model) contextual embeddings
        """
        # Create padding mask (True = position is ignored by attention)
        pad_mask = (text_tokens == 0)  # pad_idx = 0
        # Embed and add positional encoding
        x = self.embedding(text_tokens) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        x = self.dropout(x)
        # Transformer encode
        x = self.transformer(x, src_key_padding_mask=pad_mask)
        x = self.norm(x)
        return x
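

# Shape sketch (illustrative): for token ids of shape (batch, seq_len),
# TextEncoder.forward returns (batch, seq_len, d_model) embeddings. Padded
# positions are excluded from attention via src_key_padding_mask but still
# appear in the output, so a downstream decoder conditioning on this sequence
# would typically reuse the same padding mask.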


class SinusoidalPositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding."""

    def __init__(self, d_model: int, max_len: int = 512):
        super().__init__()
        # pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        # pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Add the (non-learned) encoding for the first seq_len positions
        return x + self.pe[:, :x.size(1)]
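

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the original module).
    # The toy vocabulary below is an assumption for demonstration; the real
    # project vocabulary may differ, but its first four entries must be
    # <pad>, <sos>, <eos>, <unk> so the default special-token indices line up.
    vocab = ["<pad>", "<sos>", "<eos>", "<unk>",
             "a", "pixel", "art", "knight", "castle", "red", "blue"]
    tokenizer = TextTokenizer(vocab)
    encoder = TextEncoder(vocab_size=tokenizer.vocab_size)

    prompts = ["a red pixel art knight", "a blue castle"]
    tokens = tokenizer.encode_batch(prompts)   # (2, 32)
    with torch.no_grad():
        embeddings = encoder(tokens)           # (2, 32, 256)
    print(tokens.shape, embeddings.shape)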