"""
Text encoder for conditioning the diffusion model
Uses a simple transformer architecture
"""
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
class PositionalEncoding(nn.Module):
"""Sinusoidal positional encoding"""
def __init__(self, d_model: int, max_len: int = 5000):
super().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
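        # div_term[i] = 1 / 10000^(2i/d_model), computed in log space for
        # numerical stability; even dims get sin, odd dims get cos (d_model
        # is assumed to be even)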
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.pe[:, :x.size(1)]
class TransformerEncoderLayer(nn.Module):
"""Single transformer encoder layer"""
def __init__(
self,
d_model: int,
num_heads: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
):
super().__init__()
self.self_attn = nn.MultiheadAttention(
d_model, num_heads, dropout=dropout, batch_first=True
)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
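        # mask: optional (B, seq_len) bool key padding mask, True = ignore.
        # This is a post-norm block (residual add, then LayerNorm), as in the
        # original Transformer.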
# Self attention
x2, _ = self.self_attn(x, x, x, key_padding_mask=mask)
x = x + self.dropout1(x2)
x = self.norm1(x)
# Feed forward
x2 = self.linear2(self.dropout(F.gelu(self.linear1(x))))
x = x + self.dropout2(x2)
x = self.norm2(x)
return x
class TextEncoder(nn.Module):
"""
Transformer-based text encoder for conditioning
    Similar to the CLIP text encoder, but simplified
"""
def __init__(
self,
vocab_size: int = 49408,
max_length: int = 77,
embed_dim: int = 512,
num_layers: int = 6,
num_heads: int = 8,
        dropout: float = 0.1,
        pad_token_id: int = 2,
):
super().__init__()
self.vocab_size = vocab_size
self.max_length = max_length
        self.embed_dim = embed_dim
        self.pad_token_id = pad_token_id
# Token embedding
self.token_embedding = nn.Embedding(vocab_size, embed_dim)
# Positional encoding
self.pos_encoding = PositionalEncoding(embed_dim, max_length)
# Transformer layers
self.layers = nn.ModuleList([
TransformerEncoderLayer(
d_model=embed_dim,
num_heads=num_heads,
dim_feedforward=embed_dim * 4,
dropout=dropout,
)
for _ in range(num_layers)
])
# Final layer norm
self.final_norm = nn.LayerNorm(embed_dim)
# Initialize weights
self._init_weights()
def _init_weights(self):
"""Initialize weights"""
nn.init.normal_(self.token_embedding.weight, std=0.02)
def forward(
self,
tokens: torch.Tensor, # (B, seq_len)
return_pooled: bool = False,
) -> torch.Tensor:
"""
Forward pass
Args:
tokens: Token IDs (B, seq_len)
return_pooled: Whether to return pooled output (first token)
Returns:
Text embeddings (B, seq_len, embed_dim) or (B, embed_dim) if pooled
"""
# Token embedding
x = self.token_embedding(tokens) # (B, seq_len, embed_dim)
# Add positional encoding
x = self.pos_encoding(x)
        # Key padding mask: True marks padding positions to exclude from attention
        padding_mask = tokens == self.pad_token_id
# Transformer layers
for layer in self.layers:
x = layer(x, mask=padding_mask)
# Final norm
x = self.final_norm(x)
if return_pooled:
# Return first token embedding (like [CLS])
return x[:, 0]
return x
class FrozenCLIPTextEncoder(nn.Module):
"""
Wrapper for using pretrained CLIP text encoder (if available)
Falls back to custom TextEncoder if CLIP is not available
"""
def __init__(
self,
embed_dim: int = 512,
max_length: int = 77,
):
super().__init__()
self.embed_dim = embed_dim
self.max_length = max_length
try:
from transformers import CLIPTextModel, CLIPTokenizer
self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
self.model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
            # Freeze the model and keep it in eval mode (disables dropout)
            self.model.eval()
            for param in self.model.parameters():
                param.requires_grad = False
# Project to target dim if needed
clip_dim = self.model.config.hidden_size
if clip_dim != embed_dim:
self.proj = nn.Linear(clip_dim, embed_dim)
else:
self.proj = nn.Identity()
self.use_clip = True
print("Using pretrained CLIP text encoder")
except Exception as e:
print(f"CLIP not available ({e}), using custom text encoder")
self.model = TextEncoder(
embed_dim=embed_dim,
max_length=max_length,
)
            self.tokenizer = None
            self.proj = nn.Identity()
self.use_clip = False
def forward(
self,
tokens: torch.Tensor,
text: Optional[list] = None,
) -> torch.Tensor:
"""
Forward pass
Args:
            tokens: Pre-tokenized token IDs (B, seq_len); used when text is None
            text: List of text strings; tokenized internally when using CLIP
Returns:
Text embeddings (B, seq_len, embed_dim)
"""
if self.use_clip and text is not None:
# Tokenize with CLIP tokenizer
inputs = self.tokenizer(
text,
padding="max_length",
max_length=self.max_length,
truncation=True,
return_tensors="pt",
)
inputs = {k: v.to(next(self.model.parameters()).device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs)
hidden_states = outputs.last_hidden_state
return self.proj(hidden_states)
        elif self.use_clip:
            # tokens are assumed to be CLIP input IDs; the HF model returns a
            # ModelOutput, so take last_hidden_state before projecting
            with torch.no_grad():
                outputs = self.model(input_ids=tokens)
            return self.proj(outputs.last_hidden_state)
        else:
            return self.proj(self.model(tokens))
def create_text_encoder(config, use_clip: bool = True):
"""Create text encoder from config (default: pretrained CLIP)"""
if use_clip:
return FrozenCLIPTextEncoder(
embed_dim=config.text_embed_dim,
max_length=config.max_text_length,
)
else:
return TextEncoder(
vocab_size=config.vocab_size,
max_length=config.max_text_length,
embed_dim=config.text_embed_dim,
)
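# Illustrative usage of create_text_encoder with a minimal config stub; the
# attribute names below are assumptions about what the real config exposes:
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(text_embed_dim=512, max_text_length=77, vocab_size=49408)
#   text_encoder = create_text_encoder(cfg, use_clip=False)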
if __name__ == "__main__":
# Test the encoder
encoder = TextEncoder(
vocab_size=49408,
max_length=77,
embed_dim=512,
num_layers=6,
num_heads=8,
)
# Test input
tokens = torch.randint(0, 49408, (2, 77))
# Forward pass
output = encoder(tokens)
print(f"Input shape: {tokens.shape}")
print(f"Output shape: {output.shape}")
print(f"Parameters: {sum(p.numel() for p in encoder.parameters()):,}")