|
|
""" |
|
|
Text encoder for conditioning the diffusion model |
|
|
Uses a simple transformer architecture |
|
|
""" |
|
|
|
|
|
import math |
|
|
from typing import Optional |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to a sequence of embeddings.

    The encoding table is precomputed once for ``max_len`` positions and
    registered as a buffer so it follows the module across devices and
    dtypes without being a trainable parameter.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Geometric frequency progression over the even dimensions.
        frequencies = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * frequencies)
        table[:, 1::2] = torch.cos(positions * frequencies)

        # Shape (1, max_len, d_model) so it broadcasts across the batch dim.
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encodings for the first ``x.size(1)`` positions of x."""
        return x + self.pe[:, :x.size(1)]
|
|
|
|
|
|
|
|
class TransformerEncoderLayer(nn.Module):
    """One post-norm transformer encoder block.

    Structure: multi-head self-attention -> residual + LayerNorm, then a
    GELU feed-forward network -> residual + LayerNorm.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.self_attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )

        # Position-wise feed-forward network.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the block to x (B, seq, d_model); ``mask`` is a key-padding
        mask (B, seq) where True marks positions attention should ignore."""
        # Self-attention sub-layer: residual connection, then normalize.
        attn_out, _ = self.self_attn(x, x, x, key_padding_mask=mask)
        x = self.norm1(x + self.dropout1(attn_out))

        # Feed-forward sub-layer: residual connection, then normalize.
        ff_out = self.linear2(self.dropout(F.gelu(self.linear1(x))))
        x = self.norm2(x + self.dropout2(ff_out))

        return x
|
|
|
|
|
|
|
|
class TextEncoder(nn.Module):
    """
    Transformer-based text encoder for conditioning.

    Similar to CLIP's text encoder but simplified: learned token embeddings,
    fixed sinusoidal positional encodings, a stack of post-norm transformer
    layers, and a final LayerNorm.
    """

    def __init__(
        self,
        vocab_size: int = 49408,
        max_length: int = 77,
        embed_dim: int = 512,
        num_layers: int = 6,
        num_heads: int = 8,
        dropout: float = 0.1,
        pad_token_id: int = 2,
    ):
        """
        Args:
            vocab_size: Size of the token vocabulary.
            max_length: Maximum sequence length (positions precomputed up to this).
            embed_dim: Embedding / model dimension.
            num_layers: Number of transformer encoder layers.
            num_heads: Attention heads per layer.
            dropout: Dropout probability for attention and feed-forward.
            pad_token_id: Token id treated as padding when building the
                key-padding mask. Was previously hard-coded to 2; the default
                preserves the old behavior. Must match the tokenizer in use.
        """
        super().__init__()

        self.vocab_size = vocab_size
        self.max_length = max_length
        self.embed_dim = embed_dim
        self.pad_token_id = pad_token_id

        self.token_embedding = nn.Embedding(vocab_size, embed_dim)

        self.pos_encoding = PositionalEncoding(embed_dim, max_length)

        self.layers = nn.ModuleList([
            TransformerEncoderLayer(
                d_model=embed_dim,
                num_heads=num_heads,
                dim_feedforward=embed_dim * 4,
                dropout=dropout,
            )
            for _ in range(num_layers)
        ])

        self.final_norm = nn.LayerNorm(embed_dim)

        self._init_weights()

    def _init_weights(self):
        """Initialize token embeddings with a small normal (std=0.02)."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)

    def forward(
        self,
        tokens: torch.Tensor,
        return_pooled: bool = False,
    ) -> torch.Tensor:
        """
        Forward pass

        Args:
            tokens: Token IDs (B, seq_len)
            return_pooled: Whether to return pooled output (first token)

        Returns:
            Text embeddings (B, seq_len, embed_dim) or (B, embed_dim) if pooled
        """
        x = self.token_embedding(tokens)
        x = self.pos_encoding(x)

        # True at padding positions so attention ignores them.
        padding_mask = (tokens == self.pad_token_id)

        for layer in self.layers:
            x = layer(x, mask=padding_mask)

        x = self.final_norm(x)

        if return_pooled:
            # Pool by taking the first token's embedding.
            return x[:, 0]

        return x
|
|
|
|
|
|
|
|
class FrozenCLIPTextEncoder(nn.Module):
    """
    Wrapper for using a pretrained, frozen CLIP text encoder (if available).

    Falls back to the custom TextEncoder if the ``transformers`` package or
    the pretrained weights cannot be loaded. When CLIP is used, its hidden
    states are projected to ``embed_dim`` if the dimensions differ.
    """

    def __init__(
        self,
        embed_dim: int = 512,
        max_length: int = 77,
    ):
        """
        Args:
            embed_dim: Output embedding dimension.
            max_length: Maximum token sequence length for tokenization.
        """
        super().__init__()

        self.embed_dim = embed_dim
        self.max_length = max_length

        try:
            from transformers import CLIPTextModel, CLIPTokenizer

            self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
            self.model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

            # Freeze CLIP: it is used purely as a fixed feature extractor.
            for param in self.model.parameters():
                param.requires_grad = False

            # Project CLIP's hidden size to the requested dim if needed.
            clip_dim = self.model.config.hidden_size
            if clip_dim != embed_dim:
                self.proj = nn.Linear(clip_dim, embed_dim)
            else:
                self.proj = nn.Identity()

            self.use_clip = True
            print("Using pretrained CLIP text encoder")

        except Exception as e:
            print(f"CLIP not available ({e}), using custom text encoder")
            self.model = TextEncoder(
                embed_dim=embed_dim,
                max_length=max_length,
            )
            self.proj = nn.Identity()
            self.use_clip = False

    def forward(
        self,
        tokens: torch.Tensor,
        text: Optional[list] = None,
    ) -> torch.Tensor:
        """
        Forward pass

        Args:
            tokens: Pre-tokenized token IDs (B, seq_len). Used with the
                custom encoder, or with CLIP when ``text`` is not given.
            text: List of text strings - tokenized and encoded by CLIP
                when CLIP is available.

        Returns:
            Text embeddings (B, seq_len, embed_dim)
        """
        if not self.use_clip:
            # Custom TextEncoder returns a tensor directly.
            return self.proj(self.model(tokens))

        if text is not None:
            # Tokenize raw strings with the CLIP tokenizer, padded/truncated
            # to the fixed context length, and move to the model's device.
            inputs = self.tokenizer(
                text,
                padding="max_length",
                max_length=self.max_length,
                truncation=True,
                return_tensors="pt",
            )
            inputs = {k: v.to(next(self.model.parameters()).device) for k, v in inputs.items()}
        else:
            # Bug fix: previously, CLIP with pre-tokenized ids fell through
            # to `self.proj(self.model(tokens))`, passing the transformers
            # output OBJECT (not a tensor) to the projection. Run CLIP on
            # the ids explicitly and extract the hidden states instead.
            inputs = {"input_ids": tokens}

        with torch.no_grad():  # encoder is frozen
            outputs = self.model(**inputs)
        return self.proj(outputs.last_hidden_state)
|
|
|
|
|
|
|
|
def create_text_encoder(config, use_clip: bool = True):
    """Build a text encoder from ``config`` (pretrained CLIP by default).

    Args:
        config: Object exposing ``text_embed_dim`` and ``max_text_length``
            (plus ``vocab_size`` for the custom encoder).
        use_clip: If True, return a FrozenCLIPTextEncoder, which itself
            falls back to the custom encoder when CLIP is unavailable.
    """
    if not use_clip:
        return TextEncoder(
            vocab_size=config.vocab_size,
            max_length=config.max_text_length,
            embed_dim=config.text_embed_dim,
        )
    return FrozenCLIPTextEncoder(
        embed_dim=config.text_embed_dim,
        max_length=config.max_text_length,
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: push a batch of random token ids through the custom encoder.
    encoder = TextEncoder(
        vocab_size=49408,
        max_length=77,
        embed_dim=512,
        num_layers=6,
        num_heads=8,
    )

    tokens = torch.randint(0, 49408, (2, 77))
    output = encoder(tokens)

    print(f"Input shape: {tokens.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Parameters: {sum(p.numel() for p in encoder.parameters()):,}")
|
|
|