xiaruize committed
Commit 234a70c · 1 Parent(s): fb59797
config.py CHANGED
@@ -1,16 +1,101 @@
1
- # Example config for Text2Sign model
3
  class ModelConfig:
4
- vocab_size = 30522
5
- max_text_length = 77
6
- use_clip_text_encoder = False
7
- # ... other model hyperparameters ...
8
9
  class GenerationConfig:
10
- num_inference_steps = 50
11
- guidance_scale = 7.5
12
- eta = 0.0
13
- fps = 8
14
- # ... other generation settings ...
15
 
16
- # Add any additional config as needed for your model
1
+ """
2
+ Configuration for Text-to-Sign Language DDIM Diffusion Model
3
+ """
4
 
5
+ from dataclasses import dataclass
6
+ from typing import Optional, Tuple
7
+ import torch
8
+
9
+
10
+ @dataclass
11
  class ModelConfig:
12
+ """Model architecture configuration"""
13
+ # Image/Video dimensions
14
+ image_size: int = 64 # Resize GIFs to 64x64
15
+ num_frames: int = 16 # Number of frames per video
16
+ in_channels: int = 3 # RGB channels
17
+
18
+ # UNet architecture (increased capacity for better quality)
19
+ model_channels: int = 96 # Increased from 64 for better quality
20
+ channel_mult: Tuple[int, ...] = (1, 2, 4) # Depth levels
21
+ num_res_blocks: int = 2
22
+ attention_resolutions: Tuple[int, ...] = (8, 16)
23
+ num_heads: int = 6 # Increased from 4 for better attention
24
+
25
+ # Transformer settings (DiT-style)
26
+ use_transformer: bool = True # Use enhanced DiT-style transformer blocks
27
+ transformer_depth: int = 2 # Increased from 1 for deeper transformers
28
+ use_gradient_checkpointing: bool = True # Enable gradient checkpointing for memory savings
29
+
30
+ # Text encoder
31
+ use_clip_text_encoder: bool = True # Default to frozen pretrained CLIP text encoder
32
+ text_embed_dim: int = 384 # Increased from 256 for richer text embeddings
33
+ max_text_length: int = 77
34
+ vocab_size: int = 49408 # CLIP vocab size
35
+
36
+ # Cross attention
37
+ context_dim: int = 384 # Increased from 256 for better cross-attention
38
+
39
+
40
+ @dataclass
41
+ class DDIMConfig:
42
+ """DDIM scheduler configuration"""
43
+ num_train_timesteps: int = 100
44
+ num_inference_steps: int = 100
45
+ beta_start: float = 0.0001
46
+ beta_end: float = 0.02
47
+ beta_schedule: str = "linear" # "linear" or "cosine"
48
+ clip_sample: bool = True
49
+ prediction_type: str = "epsilon" # "epsilon" or "v_prediction"
50
 
51
+
52
+ @dataclass
53
+ class TrainingConfig:
54
+ """Training configuration"""
55
+ # Data
56
+ data_dir: str = "text2sign/training_data"
57
+ batch_size: int = 2 # Reduced from 4 for memory
58
+ num_workers: int = 4
59
+
60
+ # Training
61
+ num_epochs: int = 150 # Increased for more training
62
+ learning_rate: float = 5e-5 # Reduced from 1e-4 for fine-tuning stability
63
+ weight_decay: float = 0.01
64
+ warmup_steps: int = 500 # Reduced warmup for fine-tuning
65
+ gradient_accumulation_steps: int = 8 # Effective batch size = 16
66
+ max_grad_norm: float = 1.0
67
+
68
+ # Mixed precision
69
+ use_amp: bool = True
70
+
71
+ # Checkpointing
72
+ checkpoint_dir: str = "text_to_sign/checkpoints"
73
+ save_every: int = 5 # Save every N epochs
74
+ log_every: int = 100 # Log every N steps
75
+ sample_every: int = 1000 # Generate samples every N steps
76
+
77
+ # TensorBoard
78
+ log_dir: str = "text_to_sign/logs"
79
+
80
+ # Device
81
+ device: str = "cuda" if torch.cuda.is_available() else "cpu"
82
+
83
+
84
+ @dataclass
85
  class GenerationConfig:
86
+ """Generation/Inference configuration"""
87
+ num_inference_steps: int = 50
88
+ guidance_scale: float = 7.5
89
+ eta: float = 0.0 # 0 for DDIM, 1 for DDPM
90
+ output_dir: str = "text_to_sign/generated"
91
+ fps: int = 8 # Output GIF frame rate
92
+
93
 
94
+ def get_config():
95
+ """Get all configurations"""
96
+ return {
97
+ "model": ModelConfig(),
98
+ "ddim": DDIMConfig(),
99
+ "training": TrainingConfig(),
100
+ "generation": GenerationConfig(),
101
+ }
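Usage note (not part of the commit): a minimal sketch of consuming get_config(), assuming config.py and the models package added below are importable. TrainingConfig's gradient accumulation works out to an effective batch size of batch_size × gradient_accumulation_steps = 2 × 8 = 16, matching the inline comment.

from config import get_config
from models import create_unet, create_text_encoder

cfg = get_config()
unet = create_unet(cfg["model"])                   # UNet3D sized by ModelConfig
text_encoder = create_text_encoder(cfg["model"])   # frozen CLIP by default, custom fallback otherwise
train_cfg = cfg["training"]
print(train_cfg.batch_size * train_cfg.gradient_accumulation_steps)  # 16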
dataset.py ADDED
@@ -0,0 +1,242 @@
1
+ """
2
+ Dataset for loading text-GIF pairs for sign language generation
3
+ """
4
+
5
+ import os
6
+ import glob
7
+ import random
8
+ from typing import Dict, List, Optional, Tuple
9
+
10
+ import torch
11
+ from torch.utils.data import Dataset, DataLoader
12
+ from PIL import Image
13
+ import numpy as np
14
+ from torchvision import transforms
15
+
16
+
17
+ class SignLanguageDataset(Dataset):
18
+ """Dataset for text-to-sign language video generation"""
19
+
20
+ def __init__(
21
+ self,
22
+ data_dir: str,
23
+ image_size: int = 64,
24
+ num_frames: int = 16,
25
+ train: bool = True,
26
+ train_ratio: float = 0.9,
27
+ ):
28
+ """
29
+ Args:
30
+ data_dir: Directory containing .gif and .txt files
31
+ image_size: Size to resize frames to
32
+ num_frames: Number of frames to sample from each GIF
33
+ train: Whether this is training set
34
+ train_ratio: Ratio of data to use for training
35
+ """
36
+ self.data_dir = data_dir
37
+ self.image_size = image_size
38
+ self.num_frames = num_frames
39
+ self.train = train
40
+
41
+ # Find all pairs
42
+ self.pairs = self._find_pairs()
43
+
44
+ # Split into train/val
45
+ random.seed(42)
46
+ indices = list(range(len(self.pairs)))
47
+ random.shuffle(indices)
48
+ split_idx = int(len(indices) * train_ratio)
49
+
50
+ if train:
51
+ self.indices = indices[:split_idx]
52
+ else:
53
+ self.indices = indices[split_idx:]
54
+
55
+ # Image transforms
56
+ self.transform = transforms.Compose([
57
+ transforms.Resize((image_size, image_size)),
58
+ transforms.ToTensor(),
59
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # [-1, 1]
60
+ ])
61
+
62
+ print(f"Loaded {len(self.indices)} {'training' if train else 'validation'} samples")
63
+
64
+ def _find_pairs(self) -> List[Tuple[str, str]]:
65
+ """Find all GIF-text pairs in the data directory"""
66
+ pairs = []
67
+
68
+ # Find all GIF files
69
+ gif_files = glob.glob(os.path.join(self.data_dir, "*.gif"))
70
+
71
+ for gif_path in gif_files:
72
+ # Find corresponding text file
73
+ txt_path = gif_path.replace(".gif", ".txt")
74
+
75
+ if os.path.exists(txt_path):
76
+ pairs.append((gif_path, txt_path))
77
+
78
+ return pairs
79
+
80
+ def _load_gif(self, gif_path: str) -> torch.Tensor:
81
+ """Load GIF and sample frames"""
82
+ try:
83
+ gif = Image.open(gif_path)
84
+
85
+ # Get all frames
86
+ frames = []
87
+ try:
88
+ while True:
89
+ # Convert to RGB
90
+ frame = gif.convert("RGB")
91
+ frame = self.transform(frame)
92
+ frames.append(frame)
93
+ gif.seek(gif.tell() + 1)
94
+ except EOFError:
95
+ pass
96
+
97
+ if len(frames) == 0:
98
+ raise ValueError(f"No frames found in {gif_path}")
99
+
100
+ # Sample or pad frames
101
+ if len(frames) >= self.num_frames:
102
+ # Uniform sampling
103
+ indices = np.linspace(0, len(frames) - 1, self.num_frames, dtype=int)
104
+ frames = [frames[i] for i in indices]
105
+ else:
106
+ # Pad by repeating last frame
107
+ while len(frames) < self.num_frames:
108
+ frames.append(frames[-1])
109
+
110
+ # Stack frames: (num_frames, C, H, W)
111
+ video = torch.stack(frames)
112
+
113
+ return video
114
+
115
+ except Exception as e:
116
+ print(f"Error loading {gif_path}: {e}")
117
+ # Return random noise as fallback
118
+ return torch.randn(self.num_frames, 3, self.image_size, self.image_size)
119
+
120
+ def _load_text(self, txt_path: str) -> str:
121
+ """Load text from file"""
122
+ try:
123
+ with open(txt_path, "r", encoding="utf-8") as f:
124
+ text = f.read().strip()
125
+ return text
126
+ except Exception as e:
127
+ print(f"Error loading {txt_path}: {e}")
128
+ return ""
129
+
130
+ def __len__(self) -> int:
131
+ return len(self.indices)
132
+
133
+ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
134
+ real_idx = self.indices[idx]
135
+ gif_path, txt_path = self.pairs[real_idx]
136
+
137
+ video = self._load_gif(gif_path) # (T, C, H, W)
138
+ text = self._load_text(txt_path)
139
+
140
+ return {
141
+ "video": video,
142
+ "text": text,
143
+ }
144
+
145
+
146
+ class SimpleTokenizer:
147
+ """Simple tokenizer for text encoding"""
148
+
149
+ def __init__(self, vocab_size: int = 49408, max_length: int = 77):
150
+ self.vocab_size = vocab_size
151
+ self.max_length = max_length
152
+
153
+ # Simple character-level tokenization with hash
154
+ self.bos_token_id = 0
155
+ self.eos_token_id = 1
156
+ self.pad_token_id = 2
157
+
158
+ def encode(self, text: str) -> torch.Tensor:
159
+ """Encode text to token IDs"""
160
+ # Simple hash-based encoding
161
+ tokens = [self.bos_token_id]
162
+
163
+ for char in text.lower():
164
+ # Hash character to token ID
165
+ token_id = (ord(char) % (self.vocab_size - 3)) + 3
166
+ tokens.append(token_id)
167
+
168
+ if len(tokens) >= self.max_length - 1:
169
+ break
170
+
171
+ tokens.append(self.eos_token_id)
172
+
173
+ # Pad to max_length
174
+ while len(tokens) < self.max_length:
175
+ tokens.append(self.pad_token_id)
176
+
177
+ return torch.tensor(tokens[:self.max_length], dtype=torch.long)
178
+
179
+ def __call__(self, texts: List[str]) -> torch.Tensor:
180
+ """Batch encode texts"""
181
+ return torch.stack([self.encode(text) for text in texts])
182
+
183
+
184
+ def collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
185
+ """Custom collate function for batching"""
186
+ tokenizer = SimpleTokenizer()
187
+
188
+ videos = torch.stack([item["video"] for item in batch])
189
+ texts = [item["text"] for item in batch]
190
+ tokens = tokenizer(texts)
191
+
192
+ return {
193
+ "video": videos, # (B, T, C, H, W)
194
+ "tokens": tokens, # (B, max_length)
195
+ "text": texts, # List of strings
196
+ }
197
+
198
+
199
+ def get_dataloader(
200
+ data_dir: str,
201
+ batch_size: int = 4,
202
+ image_size: int = 64,
203
+ num_frames: int = 16,
204
+ num_workers: int = 4,
205
+ train: bool = True,
206
+ ) -> DataLoader:
207
+ """Create dataloader for training or validation"""
208
+
209
+ dataset = SignLanguageDataset(
210
+ data_dir=data_dir,
211
+ image_size=image_size,
212
+ num_frames=num_frames,
213
+ train=train,
214
+ )
215
+
216
+ dataloader = DataLoader(
217
+ dataset,
218
+ batch_size=batch_size,
219
+ shuffle=train,
220
+ num_workers=num_workers,
221
+ collate_fn=collate_fn,
222
+ pin_memory=True,
223
+ drop_last=train,
224
+ )
225
+
226
+ return dataloader
227
+
228
+
229
+ if __name__ == "__main__":
230
+ # Test dataset
231
+ dataset = SignLanguageDataset(
232
+ data_dir="text2sign/training_data",
233
+ image_size=64,
234
+ num_frames=16,
235
+ train=True,
236
+ )
237
+
238
+ print(f"Dataset size: {len(dataset)}")
239
+
240
+ sample = dataset[0]
241
+ print(f"Video shape: {sample['video'].shape}")
242
+ print(f"Text: {sample['text']}")
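Usage note (not part of the commit): a minimal sketch of building the training loader and inspecting one batch, assuming .gif/.txt pairs exist under text2sign/training_data. collate_fn tokenizes captions with the hash-based SimpleTokenizer, so tokens has the fixed length 77.

from dataset import get_dataloader

loader = get_dataloader(
    data_dir="text2sign/training_data",
    batch_size=2,
    image_size=64,
    num_frames=16,
    train=True,
)
batch = next(iter(loader))
print(batch["video"].shape)   # torch.Size([2, 16, 3, 64, 64]) -> (B, T, C, H, W), values in [-1, 1]
print(batch["tokens"].shape)  # torch.Size([2, 77])
print(batch["text"][0])       # raw caption string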
inference.py CHANGED
@@ -2,11 +2,6 @@ import torch
2
  from PIL import Image
3
  import matplotlib.pyplot as plt
4
  import numpy as np
5
- import sys
6
- import os
7
-
8
- # Add model code to path if needed
9
- sys.path.append(os.path.join(os.path.dirname(__file__), "../text_to_sign"))
10
  from pipeline import Text2SignPipeline
11
 
12
  def generate_and_save(prompt, checkpoint_path, output_path, device="cuda"):
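Usage note (not part of the commit): with the sys.path patching removed, this script assumes pipeline.py is importable from the working directory. A call sketch with illustrative arguments (the checkpoint and output file names are hypothetical; only the directories come from the configs):

generate_and_save(
    prompt="hello",
    checkpoint_path="text_to_sign/checkpoints/latest.pt",  # hypothetical file name
    output_path="text_to_sign/generated/hello.gif",        # hypothetical file name
    device="cuda",
)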
models/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ """
2
+ Models package for text-to-sign language generation
3
+ """
4
+
5
+ from .unet3d import UNet3D, create_unet
6
+ from .text_encoder import TextEncoder, FrozenCLIPTextEncoder, create_text_encoder
7
+
8
+ __all__ = [
9
+ "UNet3D",
10
+ "create_unet",
11
+ "TextEncoder",
12
+ "FrozenCLIPTextEncoder",
13
+ "create_text_encoder",
14
+ ]
models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (445 Bytes)
 
models/__pycache__/text_encoder.cpython-310.pyc ADDED
Binary file (7.47 kB)
 
models/__pycache__/unet3d.cpython-310.pyc ADDED
Binary file (24.3 kB)
 
models/text_encoder.py ADDED
@@ -0,0 +1,268 @@
1
+ """
2
+ Text encoder for conditioning the diffusion model
3
+ Uses a simple transformer architecture
4
+ """
5
+
6
+ import math
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ class PositionalEncoding(nn.Module):
15
+ """Sinusoidal positional encoding"""
16
+ def __init__(self, d_model: int, max_len: int = 5000):
17
+ super().__init__()
18
+
19
+ pe = torch.zeros(max_len, d_model)
20
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
21
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
22
+
23
+ pe[:, 0::2] = torch.sin(position * div_term)
24
+ pe[:, 1::2] = torch.cos(position * div_term)
25
+ pe = pe.unsqueeze(0)
26
+
27
+ self.register_buffer('pe', pe)
28
+
29
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
30
+ return x + self.pe[:, :x.size(1)]
31
+
32
+
33
+ class TransformerEncoderLayer(nn.Module):
34
+ """Single transformer encoder layer"""
35
+ def __init__(
36
+ self,
37
+ d_model: int,
38
+ num_heads: int,
39
+ dim_feedforward: int = 2048,
40
+ dropout: float = 0.1,
41
+ ):
42
+ super().__init__()
43
+
44
+ self.self_attn = nn.MultiheadAttention(
45
+ d_model, num_heads, dropout=dropout, batch_first=True
46
+ )
47
+
48
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
49
+ self.dropout = nn.Dropout(dropout)
50
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
51
+
52
+ self.norm1 = nn.LayerNorm(d_model)
53
+ self.norm2 = nn.LayerNorm(d_model)
54
+
55
+ self.dropout1 = nn.Dropout(dropout)
56
+ self.dropout2 = nn.Dropout(dropout)
57
+
58
+ def forward(
59
+ self,
60
+ x: torch.Tensor,
61
+ mask: Optional[torch.Tensor] = None,
62
+ ) -> torch.Tensor:
63
+ # Self attention
64
+ x2, _ = self.self_attn(x, x, x, key_padding_mask=mask)
65
+ x = x + self.dropout1(x2)
66
+ x = self.norm1(x)
67
+
68
+ # Feed forward
69
+ x2 = self.linear2(self.dropout(F.gelu(self.linear1(x))))
70
+ x = x + self.dropout2(x2)
71
+ x = self.norm2(x)
72
+
73
+ return x
74
+
75
+
76
+ class TextEncoder(nn.Module):
77
+ """
78
+ Transformer-based text encoder for conditioning
79
+ Similar to CLIP text encoder but simplified
80
+ """
81
+ def __init__(
82
+ self,
83
+ vocab_size: int = 49408,
84
+ max_length: int = 77,
85
+ embed_dim: int = 512,
86
+ num_layers: int = 6,
87
+ num_heads: int = 8,
88
+ dropout: float = 0.1,
89
+ ):
90
+ super().__init__()
91
+
92
+ self.vocab_size = vocab_size
93
+ self.max_length = max_length
94
+ self.embed_dim = embed_dim
95
+
96
+ # Token embedding
97
+ self.token_embedding = nn.Embedding(vocab_size, embed_dim)
98
+
99
+ # Positional encoding
100
+ self.pos_encoding = PositionalEncoding(embed_dim, max_length)
101
+
102
+ # Transformer layers
103
+ self.layers = nn.ModuleList([
104
+ TransformerEncoderLayer(
105
+ d_model=embed_dim,
106
+ num_heads=num_heads,
107
+ dim_feedforward=embed_dim * 4,
108
+ dropout=dropout,
109
+ )
110
+ for _ in range(num_layers)
111
+ ])
112
+
113
+ # Final layer norm
114
+ self.final_norm = nn.LayerNorm(embed_dim)
115
+
116
+ # Initialize weights
117
+ self._init_weights()
118
+
119
+ def _init_weights(self):
120
+ """Initialize weights"""
121
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
122
+
123
+ def forward(
124
+ self,
125
+ tokens: torch.Tensor, # (B, seq_len)
126
+ return_pooled: bool = False,
127
+ ) -> torch.Tensor:
128
+ """
129
+ Forward pass
130
+ Args:
131
+ tokens: Token IDs (B, seq_len)
132
+ return_pooled: Whether to return pooled output (first token)
133
+ Returns:
134
+ Text embeddings (B, seq_len, embed_dim) or (B, embed_dim) if pooled
135
+ """
136
+ # Token embedding
137
+ x = self.token_embedding(tokens) # (B, seq_len, embed_dim)
138
+
139
+ # Add positional encoding
140
+ x = self.pos_encoding(x)
141
+
142
+ # Create attention mask for padding (token_id == 2)
143
+ padding_mask = (tokens == 2) # pad_token_id = 2
144
+
145
+ # Transformer layers
146
+ for layer in self.layers:
147
+ x = layer(x, mask=padding_mask)
148
+
149
+ # Final norm
150
+ x = self.final_norm(x)
151
+
152
+ if return_pooled:
153
+ # Return first token embedding (like [CLS])
154
+ return x[:, 0]
155
+
156
+ return x
157
+
158
+
159
+ class FrozenCLIPTextEncoder(nn.Module):
160
+ """
161
+ Wrapper for using pretrained CLIP text encoder (if available)
162
+ Falls back to custom TextEncoder if CLIP is not available
163
+ """
164
+ def __init__(
165
+ self,
166
+ embed_dim: int = 512,
167
+ max_length: int = 77,
168
+ ):
169
+ super().__init__()
170
+
171
+ self.embed_dim = embed_dim
172
+ self.max_length = max_length
173
+
174
+ try:
175
+ from transformers import CLIPTextModel, CLIPTokenizer
176
+
177
+ self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
178
+ self.model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
179
+
180
+ # Freeze the model
181
+ for param in self.model.parameters():
182
+ param.requires_grad = False
183
+
184
+ # Project to target dim if needed
185
+ clip_dim = self.model.config.hidden_size
186
+ if clip_dim != embed_dim:
187
+ self.proj = nn.Linear(clip_dim, embed_dim)
188
+ else:
189
+ self.proj = nn.Identity()
190
+
191
+ self.use_clip = True
192
+ print("Using pretrained CLIP text encoder")
193
+
194
+ except Exception as e:
195
+ print(f"CLIP not available ({e}), using custom text encoder")
196
+ self.model = TextEncoder(
197
+ embed_dim=embed_dim,
198
+ max_length=max_length,
199
+ )
200
+ self.proj = nn.Identity()
201
+ self.use_clip = False
202
+
203
+ def forward(
204
+ self,
205
+ tokens: torch.Tensor,
206
+ text: Optional[list] = None,
207
+ ) -> torch.Tensor:
208
+ """
209
+ Forward pass
210
+ Args:
211
+ tokens: Pre-tokenized token IDs (B, seq_len) - used if not using CLIP
212
+ text: List of text strings - used if using CLIP
213
+ Returns:
214
+ Text embeddings (B, seq_len, embed_dim)
215
+ """
216
+ if self.use_clip and text is not None:
217
+ # Tokenize with CLIP tokenizer
218
+ inputs = self.tokenizer(
219
+ text,
220
+ padding="max_length",
221
+ max_length=self.max_length,
222
+ truncation=True,
223
+ return_tensors="pt",
224
+ )
225
+ inputs = {k: v.to(next(self.model.parameters()).device) for k, v in inputs.items()}
226
+
227
+ with torch.no_grad():
228
+ outputs = self.model(**inputs)
229
+ hidden_states = outputs.last_hidden_state
230
+
231
+ return self.proj(hidden_states)
232
+ else:
233
+ return self.proj(self.model(tokens))
234
+
235
+
236
+ def create_text_encoder(config, use_clip: bool = True):
237
+ """Create text encoder from config (default: pretrained CLIP)"""
238
+ if use_clip:
239
+ return FrozenCLIPTextEncoder(
240
+ embed_dim=config.text_embed_dim,
241
+ max_length=config.max_text_length,
242
+ )
243
+ else:
244
+ return TextEncoder(
245
+ vocab_size=config.vocab_size,
246
+ max_length=config.max_text_length,
247
+ embed_dim=config.text_embed_dim,
248
+ )
249
+
250
+
251
+ if __name__ == "__main__":
252
+ # Test the encoder
253
+ encoder = TextEncoder(
254
+ vocab_size=49408,
255
+ max_length=77,
256
+ embed_dim=512,
257
+ num_layers=6,
258
+ num_heads=8,
259
+ )
260
+
261
+ # Test input
262
+ tokens = torch.randint(0, 49408, (2, 77))
263
+
264
+ # Forward pass
265
+ output = encoder(tokens)
266
+ print(f"Input shape: {tokens.shape}")
267
+ print(f"Output shape: {output.shape}")
268
+ print(f"Parameters: {sum(p.numel() for p in encoder.parameters()):,}")
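Usage note (not part of the commit): a minimal sketch exercising both encoder paths at the 384-dim embedding size from ModelConfig. The CLIP branch tokenizes raw strings internally; the fallback branch expects SimpleTokenizer ids.

from models.text_encoder import FrozenCLIPTextEncoder
from dataset import SimpleTokenizer

encoder = FrozenCLIPTextEncoder(embed_dim=384, max_length=77)
if encoder.use_clip:
    emb = encoder(tokens=None, text=["hello", "thank you"])         # CLIP path, projected 512 -> 384
else:
    tokens = SimpleTokenizer(max_length=77)(["hello", "thank you"])
    emb = encoder(tokens)                                           # custom TextEncoder path
print(emb.shape)  # torch.Size([2, 77, 384])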
models/unet3d.py ADDED
@@ -0,0 +1,961 @@
1
+ """
2
+ 3D UNet architecture for video diffusion with text conditioning
3
+ Enhanced with Transformer (DiT-style) blocks for better temporal modeling
4
+
5
+ Based on:
6
+ - Diffusion Transformers (DiT) - Peebles & Xie 2023
7
+ - Video diffusion models with temporal attention
8
+ """
9
+
10
+ import math
11
+ from typing import Optional, Tuple
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from einops import rearrange, repeat
17
+
18
+
19
+ def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int) -> torch.Tensor:
20
+ """
21
+ Create sinusoidal timestep embeddings.
22
+ """
23
+ assert len(timesteps.shape) == 1
24
+
25
+ half_dim = embedding_dim // 2
26
+ emb = math.log(10000) / (half_dim - 1)
27
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
28
+ emb = timesteps.float()[:, None] * emb[None, :]
29
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
30
+
31
+ if embedding_dim % 2 == 1: # zero pad
32
+ emb = F.pad(emb, (0, 1), mode='constant')
33
+
34
+ return emb
35
+
36
+
37
+ def get_3d_sincos_pos_embed(embed_dim: int, grid_size: Tuple[int, int, int]) -> torch.Tensor:
38
+ """
39
+ Generate 3D sinusoidal positional embeddings for video (T, H, W).
40
+ """
41
+ t, h, w = grid_size
42
+
43
+ grid_t = torch.arange(t, dtype=torch.float32)
44
+ grid_h = torch.arange(h, dtype=torch.float32)
45
+ grid_w = torch.arange(w, dtype=torch.float32)
46
+
47
+ grid = torch.meshgrid(grid_t, grid_h, grid_w, indexing='ij')
48
+ grid = torch.stack(grid, dim=0) # (3, T, H, W)
49
+ grid = grid.reshape(3, -1).T # (T*H*W, 3)
50
+
51
+ # Split embedding dim across 3 dimensions
52
+ dim_t = embed_dim // 3
53
+ dim_h = embed_dim // 3
54
+ dim_w = embed_dim - dim_t - dim_h
55
+
56
+ def get_1d_sincos(positions, dim):
57
+ omega = torch.arange(dim // 2, dtype=torch.float32)
58
+ omega = 1.0 / (10000 ** (omega / (dim // 2)))
59
+ out = positions[:, None] * omega[None, :]
60
+ return torch.cat([torch.sin(out), torch.cos(out)], dim=1)
61
+
62
+ emb_t = get_1d_sincos(grid[:, 0], dim_t)
63
+ emb_h = get_1d_sincos(grid[:, 1], dim_h)
64
+ emb_w = get_1d_sincos(grid[:, 2], dim_w)
65
+
66
+ return torch.cat([emb_t, emb_h, emb_w], dim=1) # (T*H*W, embed_dim)
67
+
68
+
69
+ class GroupNorm32(nn.GroupNorm):
70
+ """GroupNorm with float32 computation for stability"""
71
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
72
+ return super().forward(x.float()).type(x.dtype)
73
+
74
+
75
+ class RMSNorm(nn.Module):
76
+ """Root Mean Square Layer Normalization (more efficient than LayerNorm)"""
77
+ def __init__(self, dim: int, eps: float = 1e-6):
78
+ super().__init__()
79
+ self.eps = eps
80
+ self.weight = nn.Parameter(torch.ones(dim))
81
+
82
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
83
+ rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
84
+ return x / rms * self.weight
85
+
86
+
87
+ class AdaLayerNorm(nn.Module):
88
+ """Adaptive Layer Normalization conditioned on timestep (DiT-style)"""
89
+ def __init__(self, dim: int, time_embed_dim: int):
90
+ super().__init__()
91
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False)
92
+ self.proj = nn.Linear(time_embed_dim, dim * 2)
93
+
94
+ def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
95
+ # t_emb: (B, time_embed_dim)
96
+ scale_shift = self.proj(t_emb)
97
+ scale, shift = scale_shift.chunk(2, dim=-1)
98
+
99
+ # Handle different input shapes
100
+ if x.dim() == 3: # (B, N, C)
101
+ scale = scale.unsqueeze(1)
102
+ shift = shift.unsqueeze(1)
103
+ elif x.dim() == 5: # (B, C, T, H, W)
104
+ scale = scale[:, :, None, None, None]
105
+ shift = shift[:, :, None, None, None]
106
+
107
+ return self.norm(x) * (1 + scale) + shift
108
+
109
+
110
+ class AdaLayerNormZero(nn.Module):
111
+ """Adaptive Layer Normalization with zero-init (DiT-style)"""
112
+ def __init__(self, dim: int, time_embed_dim: int):
113
+ super().__init__()
114
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False)
115
+ self.proj = nn.Linear(time_embed_dim, dim * 6) # scale, shift, gate for both attn and ff
116
+ nn.init.zeros_(self.proj.weight)
117
+ nn.init.zeros_(self.proj.bias)
118
+
119
+ def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> Tuple[torch.Tensor, ...]:
120
+ params = self.proj(t_emb)
121
+ return self.norm(x), params.chunk(6, dim=-1)
122
+
123
+
124
+ class Upsample3D(nn.Module):
125
+ """3D Upsampling with convolution"""
126
+ def __init__(self, channels: int):
127
+ super().__init__()
128
+ self.conv = nn.Conv3d(channels, channels, 3, padding=1)
129
+
130
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
131
+ x = F.interpolate(x, scale_factor=(1, 2, 2), mode='nearest')
132
+ return self.conv(x)
133
+
134
+
135
+ class Downsample3D(nn.Module):
136
+ """3D Downsampling with convolution"""
137
+ def __init__(self, channels: int):
138
+ super().__init__()
139
+ self.conv = nn.Conv3d(channels, channels, 3, stride=(1, 2, 2), padding=1)
140
+
141
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
142
+ return self.conv(x)
143
+
144
+
145
+ class ResBlock3D(nn.Module):
146
+ """3D Residual block with time and context conditioning"""
147
+ def __init__(
148
+ self,
149
+ in_channels: int,
150
+ out_channels: int,
151
+ time_emb_dim: int,
152
+ dropout: float = 0.1,
153
+ ):
154
+ super().__init__()
155
+
156
+ self.in_layers = nn.Sequential(
157
+ GroupNorm32(32, in_channels),
158
+ nn.SiLU(),
159
+ nn.Conv3d(in_channels, out_channels, 3, padding=1),
160
+ )
161
+
162
+ self.time_emb_proj = nn.Sequential(
163
+ nn.SiLU(),
164
+ nn.Linear(time_emb_dim, out_channels),
165
+ )
166
+
167
+ self.out_layers = nn.Sequential(
168
+ GroupNorm32(32, out_channels),
169
+ nn.SiLU(),
170
+ nn.Dropout(dropout),
171
+ nn.Conv3d(out_channels, out_channels, 3, padding=1),
172
+ )
173
+
174
+ if in_channels != out_channels:
175
+ self.skip_connection = nn.Conv3d(in_channels, out_channels, 1)
176
+ else:
177
+ self.skip_connection = nn.Identity()
178
+
179
+ def forward(
180
+ self,
181
+ x: torch.Tensor,
182
+ time_emb: torch.Tensor,
183
+ ) -> torch.Tensor:
184
+ h = self.in_layers(x)
185
+
186
+ # Add time embedding
187
+ time_emb = self.time_emb_proj(time_emb)
188
+ h = h + time_emb[:, :, None, None, None]
189
+
190
+ h = self.out_layers(h)
191
+
192
+ return self.skip_connection(x) + h
193
+
194
+
195
+ class SpatialAttention(nn.Module):
196
+ """Self-attention over spatial dimensions"""
197
+ def __init__(self, channels: int, num_heads: int = 8):
198
+ super().__init__()
199
+ self.num_heads = num_heads
200
+ self.head_dim = channels // num_heads
201
+
202
+ self.norm = GroupNorm32(32, channels)
203
+ self.qkv = nn.Conv1d(channels, channels * 3, 1)
204
+ self.proj = nn.Conv1d(channels, channels, 1)
205
+
206
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
207
+ b, c, t, h, w = x.shape
208
+
209
+ # Reshape to (B*T, C, H*W)
210
+ x_flat = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h * w)
211
+
212
+ # Normalize
213
+ x_norm = self.norm(x_flat.view(b * t, c, h, w)).view(b * t, c, h * w)
214
+
215
+ # QKV projection
216
+ qkv = self.qkv(x_norm)
217
+ q, k, v = qkv.chunk(3, dim=1)
218
+
219
+ # Reshape for multi-head attention
220
+ q = q.view(b * t, self.num_heads, self.head_dim, h * w).permute(0, 1, 3, 2)
221
+ k = k.view(b * t, self.num_heads, self.head_dim, h * w).permute(0, 1, 3, 2)
222
+ v = v.view(b * t, self.num_heads, self.head_dim, h * w).permute(0, 1, 3, 2)
223
+
224
+ # Attention
225
+ scale = self.head_dim ** -0.5
226
+ attn = torch.matmul(q, k.transpose(-2, -1)) * scale
227
+ attn = F.softmax(attn, dim=-1)
228
+
229
+ out = torch.matmul(attn, v)
230
+ out = out.permute(0, 1, 3, 2).reshape(b * t, c, h * w)
231
+
232
+ out = self.proj(out)
233
+ out = out.view(b, t, c, h, w).permute(0, 2, 1, 3, 4)
234
+
235
+ return x + out
236
+
237
+
238
+ class CrossAttention(nn.Module):
239
+ """Cross-attention for text conditioning"""
240
+ def __init__(
241
+ self,
242
+ query_dim: int,
243
+ context_dim: int,
244
+ num_heads: int = 8,
245
+ head_dim: int = 64,
246
+ ):
247
+ super().__init__()
248
+ self.num_heads = num_heads
249
+ self.head_dim = head_dim
250
+ inner_dim = head_dim * num_heads
251
+
252
+ self.norm = GroupNorm32(32, query_dim)
253
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
254
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
255
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
256
+ self.to_out = nn.Sequential(
257
+ nn.Linear(inner_dim, query_dim),
258
+ nn.Dropout(0.1),
259
+ )
260
+
261
+ def forward(
262
+ self,
263
+ x: torch.Tensor,
264
+ context: torch.Tensor,
265
+ ) -> torch.Tensor:
266
+ b, c, t, h, w = x.shape
267
+
268
+ # Reshape to (B, T*H*W, C)
269
+ x_flat = x.permute(0, 2, 3, 4, 1).reshape(b, t * h * w, c)
270
+
271
+ # Normalize
272
+ x_norm = self.norm(x.view(b, c, -1)).permute(0, 2, 1)
273
+
274
+ # QKV
275
+ q = self.to_q(x_norm)
276
+ k = self.to_k(context)
277
+ v = self.to_v(context)
278
+
279
+ # Reshape for multi-head
280
+ q = q.view(b, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
281
+ k = k.view(b, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
282
+ v = v.view(b, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
283
+
284
+ # Attention
285
+ scale = self.head_dim ** -0.5
286
+ attn = torch.matmul(q, k.transpose(-2, -1)) * scale
287
+ attn = F.softmax(attn, dim=-1)
288
+
289
+ out = torch.matmul(attn, v)
290
+ out = out.permute(0, 2, 1, 3).reshape(b, t * h * w, -1)
291
+ out = self.to_out(out)
292
+
293
+ out = out.view(b, t, h, w, c).permute(0, 4, 1, 2, 3)
294
+
295
+ return x + out
296
+
297
+
298
+ class TemporalAttention(nn.Module):
299
+ """Self-attention over temporal dimension"""
300
+ def __init__(self, channels: int, num_heads: int = 8):
301
+ super().__init__()
302
+ self.num_heads = num_heads
303
+ self.head_dim = channels // num_heads
304
+
305
+ self.norm = GroupNorm32(32, channels)
306
+ self.qkv = nn.Linear(channels, channels * 3)
307
+ self.proj = nn.Linear(channels, channels)
308
+
309
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
310
+ b, c, t, h, w = x.shape
311
+
312
+ # Reshape to (B*H*W, T, C)
313
+ x_flat = x.permute(0, 3, 4, 2, 1).reshape(b * h * w, t, c)
314
+
315
+ # Normalize
316
+ x_norm = self.norm(x.view(b, c, -1)).view(b, c, t, h, w)
317
+ x_norm = x_norm.permute(0, 3, 4, 2, 1).reshape(b * h * w, t, c)
318
+
319
+ # QKV
320
+ qkv = self.qkv(x_norm)
321
+ q, k, v = qkv.chunk(3, dim=-1)
322
+
323
+ # Reshape for multi-head
324
+ q = q.view(b * h * w, t, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
325
+ k = k.view(b * h * w, t, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
326
+ v = v.view(b * h * w, t, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
327
+
328
+ # Attention
329
+ scale = self.head_dim ** -0.5
330
+ attn = torch.matmul(q, k.transpose(-2, -1)) * scale
331
+ attn = F.softmax(attn, dim=-1)
332
+
333
+ out = torch.matmul(attn, v)
334
+ out = out.permute(0, 2, 1, 3).reshape(b * h * w, t, c)
335
+ out = self.proj(out)
336
+
337
+ out = out.view(b, h, w, t, c).permute(0, 4, 3, 1, 2)
338
+
339
+ return x + out
340
+
341
+
342
+ # ============================================================================
343
+ # Transformer Components (DiT-style)
344
+ # ============================================================================
345
+
346
+ class MultiHeadAttention(nn.Module):
347
+ """
348
+ Multi-head attention with optional flash attention and rotary embeddings.
349
+ Supports both self-attention and cross-attention.
350
+ """
351
+ def __init__(
352
+ self,
353
+ dim: int,
354
+ num_heads: int = 8,
355
+ qkv_bias: bool = True,
356
+ attn_drop: float = 0.0,
357
+ proj_drop: float = 0.0,
358
+ is_cross_attention: bool = False,
359
+ context_dim: Optional[int] = None,
360
+ ):
361
+ super().__init__()
362
+ self.num_heads = num_heads
363
+ self.head_dim = dim // num_heads
364
+ self.scale = self.head_dim ** -0.5
365
+ self.is_cross_attention = is_cross_attention
366
+
367
+ if is_cross_attention:
368
+ self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
369
+ self.to_kv = nn.Linear(context_dim or dim, dim * 2, bias=qkv_bias)
370
+ else:
371
+ self.to_qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
372
+
373
+ self.attn_drop = nn.Dropout(attn_drop)
374
+ self.proj = nn.Linear(dim, dim)
375
+ self.proj_drop = nn.Dropout(proj_drop)
376
+
377
+ def forward(
378
+ self,
379
+ x: torch.Tensor,
380
+ context: Optional[torch.Tensor] = None,
381
+ ) -> torch.Tensor:
382
+ B, N, C = x.shape
383
+
384
+ if self.is_cross_attention and context is not None:
385
+ q = self.to_q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
386
+ kv = self.to_kv(context).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
387
+ k, v = kv[0], kv[1]
388
+ else:
389
+ qkv = self.to_qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
390
+ q, k, v = qkv[0], qkv[1], qkv[2]
391
+
392
+ # Scaled dot-product attention
393
+ attn = (q @ k.transpose(-2, -1)) * self.scale
394
+ attn = attn.softmax(dim=-1)
395
+ attn = self.attn_drop(attn)
396
+
397
+ out = (attn @ v).transpose(1, 2).reshape(B, N, C)
398
+ out = self.proj(out)
399
+ out = self.proj_drop(out)
400
+
401
+ return out
402
+
403
+
404
+ class FeedForward(nn.Module):
405
+ """Feed-forward network with GELU activation"""
406
+ def __init__(
407
+ self,
408
+ dim: int,
409
+ hidden_dim: Optional[int] = None,
410
+ dropout: float = 0.0,
411
+ ):
412
+ super().__init__()
413
+ hidden_dim = hidden_dim or dim * 4
414
+ self.net = nn.Sequential(
415
+ nn.Linear(dim, hidden_dim),
416
+ nn.GELU(),
417
+ nn.Dropout(dropout),
418
+ nn.Linear(hidden_dim, dim),
419
+ nn.Dropout(dropout),
420
+ )
421
+
422
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
423
+ return self.net(x)
424
+
425
+
426
+ class DiTBlock(nn.Module):
427
+ """
428
+ Diffusion Transformer Block (DiT-style).
429
+ Uses adaptive layer norm for timestep conditioning.
430
+ """
431
+ def __init__(
432
+ self,
433
+ dim: int,
434
+ num_heads: int,
435
+ time_embed_dim: int,
436
+ mlp_ratio: float = 4.0,
437
+ dropout: float = 0.0,
438
+ context_dim: Optional[int] = None,
439
+ ):
440
+ super().__init__()
441
+
442
+ # Self-attention with adaptive norm
443
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
444
+ self.attn = MultiHeadAttention(dim, num_heads, attn_drop=dropout, proj_drop=dropout)
445
+
446
+ # Cross-attention for text conditioning
447
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
448
+ self.cross_attn = MultiHeadAttention(
449
+ dim, num_heads,
450
+ attn_drop=dropout,
451
+ proj_drop=dropout,
452
+ is_cross_attention=True,
453
+ context_dim=context_dim,
454
+ )
455
+
456
+ # Feed-forward with adaptive norm
457
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=False)
458
+ self.ff = FeedForward(dim, int(dim * mlp_ratio), dropout)
459
+
460
+ # Adaptive parameters (DiT-style)
461
+ self.adaLN_modulation = nn.Sequential(
462
+ nn.SiLU(),
463
+ nn.Linear(time_embed_dim, dim * 9), # 3 params each for 3 blocks
464
+ )
465
+ nn.init.zeros_(self.adaLN_modulation[-1].weight)
466
+ nn.init.zeros_(self.adaLN_modulation[-1].bias)
467
+
468
+ def forward(
469
+ self,
470
+ x: torch.Tensor,
471
+ t_emb: torch.Tensor,
472
+ context: Optional[torch.Tensor] = None,
473
+ ) -> torch.Tensor:
474
+ # Get adaptive parameters
475
+ params = self.adaLN_modulation(t_emb)
476
+ (
477
+ scale1, shift1, gate1,
478
+ scale2, shift2, gate2,
479
+ scale3, shift3, gate3,
480
+ ) = params.unsqueeze(1).chunk(9, dim=-1)
481
+
482
+ # Self-attention
483
+ x_norm = self.norm1(x) * (1 + scale1) + shift1
484
+ x = x + gate1 * self.attn(x_norm)
485
+
486
+ # Cross-attention
487
+ if context is not None:
488
+ x_norm = self.norm2(x) * (1 + scale2) + shift2
489
+ x = x + gate2 * self.cross_attn(x_norm, context)
490
+
491
+ # Feed-forward
492
+ x_norm = self.norm3(x) * (1 + scale3) + shift3
493
+ x = x + gate3 * self.ff(x_norm)
494
+
495
+ return x
496
+
497
+
498
+ class TemporalTransformerBlock(nn.Module):
499
+ """
500
+ Transformer block specifically for temporal attention.
501
+ Processes video frames attending to other frames.
502
+ """
503
+ def __init__(
504
+ self,
505
+ dim: int,
506
+ num_heads: int,
507
+ time_embed_dim: int,
508
+ dropout: float = 0.0,
509
+ ):
510
+ super().__init__()
511
+
512
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False)
513
+ self.attn = MultiHeadAttention(dim, num_heads, attn_drop=dropout, proj_drop=dropout)
514
+
515
+ # Adaptive parameters
516
+ self.adaLN_modulation = nn.Sequential(
517
+ nn.SiLU(),
518
+ nn.Linear(time_embed_dim, dim * 3),
519
+ )
520
+ nn.init.zeros_(self.adaLN_modulation[-1].weight)
521
+ nn.init.zeros_(self.adaLN_modulation[-1].bias)
522
+
523
+ def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
524
+ """
525
+ Args:
526
+ x: (B, T, C) temporal sequence
527
+ t_emb: (B, time_embed_dim) timestep embedding
528
+ """
529
+ params = self.adaLN_modulation(t_emb)
530
+ scale, shift, gate = params.unsqueeze(1).chunk(3, dim=-1)
531
+
532
+ x_norm = self.norm(x) * (1 + scale) + shift
533
+ x = x + gate * self.attn(x_norm)
534
+
535
+ return x
536
+
537
+
538
+ class SpatioTemporalTransformer(nn.Module):
539
+ """
540
+ Combined spatial and temporal transformer for video understanding.
541
+ First applies spatial attention within each frame, then temporal attention across frames.
542
+ """
543
+ def __init__(
544
+ self,
545
+ dim: int,
546
+ num_heads: int,
547
+ time_embed_dim: int,
548
+ context_dim: int,
549
+ depth: int = 2,
550
+ dropout: float = 0.0,
551
+ ):
552
+ super().__init__()
553
+
554
+ self.spatial_blocks = nn.ModuleList([
555
+ DiTBlock(dim, num_heads, time_embed_dim, dropout=dropout, context_dim=context_dim)
556
+ for _ in range(depth)
557
+ ])
558
+
559
+ self.temporal_blocks = nn.ModuleList([
560
+ TemporalTransformerBlock(dim, num_heads, time_embed_dim, dropout)
561
+ for _ in range(depth)
562
+ ])
563
+
564
+ def forward(
565
+ self,
566
+ x: torch.Tensor, # (B, C, T, H, W)
567
+ t_emb: torch.Tensor, # (B, time_embed_dim)
568
+ context: torch.Tensor, # (B, seq_len, context_dim)
569
+ ) -> torch.Tensor:
570
+ B, C, T, H, W = x.shape
571
+
572
+ # Spatial attention: process each frame
573
+ # Reshape to (B*T, H*W, C)
574
+ x_spatial = rearrange(x, 'b c t h w -> (b t) (h w) c')
575
+ t_emb_spatial = repeat(t_emb, 'b d -> (b t) d', t=T)
576
+ context_spatial = repeat(context, 'b n d -> (b t) n d', t=T)
577
+
578
+ for block in self.spatial_blocks:
579
+ x_spatial = block(x_spatial, t_emb_spatial, context_spatial)
580
+
581
+ # Reshape back: (B, T, H*W, C)
582
+ x_spatial = rearrange(x_spatial, '(b t) n c -> b t n c', b=B, t=T)
583
+
584
+ # Temporal attention: process each spatial location
585
+ # Reshape to (B*H*W, T, C)
586
+ x_temporal = rearrange(x_spatial, 'b t n c -> (b n) t c', n=H*W)
587
+ t_emb_temporal = repeat(t_emb, 'b d -> (b n) d', n=H*W)
588
+
589
+ for block in self.temporal_blocks:
590
+ x_temporal = block(x_temporal, t_emb_temporal)
591
+
592
+ # Reshape back to (B, C, T, H, W)
593
+ x_out = rearrange(x_temporal, '(b h w) t c -> b c t h w', b=B, h=H, w=W)
594
+
595
+ return x_out
596
+
597
+
598
+ class TransformerBlock3D(nn.Module):
599
+ """
600
+ Enhanced Transformer block with spatial, temporal, and cross attention.
601
+ Uses DiT-style adaptive layer norm for better timestep conditioning.
602
+ """
603
+ def __init__(
604
+ self,
605
+ channels: int,
606
+ context_dim: int,
607
+ time_embed_dim: int,
608
+ num_heads: int = 8,
609
+ transformer_depth: int = 1,
610
+ use_spatio_temporal: bool = True,
611
+ ):
612
+ super().__init__()
613
+
614
+ self.use_spatio_temporal = use_spatio_temporal
615
+
616
+ if use_spatio_temporal:
617
+ # Use the new SpatioTemporalTransformer
618
+ self.transformer = SpatioTemporalTransformer(
619
+ dim=channels,
620
+ num_heads=num_heads,
621
+ time_embed_dim=time_embed_dim,
622
+ context_dim=context_dim,
623
+ depth=transformer_depth,
624
+ )
625
+ else:
626
+ # Fallback to simpler attention
627
+ self.spatial_attn = SpatialAttention(channels, num_heads)
628
+ self.temporal_attn = TemporalAttention(channels, num_heads)
629
+ self.cross_attn = CrossAttention(
630
+ query_dim=channels,
631
+ context_dim=context_dim,
632
+ num_heads=num_heads,
633
+ )
634
+
635
+ # Feed-forward (used in both cases)
636
+ self.ff = nn.Sequential(
637
+ GroupNorm32(32, channels),
638
+ nn.Conv3d(channels, channels * 4, 1),
639
+ nn.GELU(),
640
+ nn.Conv3d(channels * 4, channels, 1),
641
+ )
642
+
643
+ def forward(
644
+ self,
645
+ x: torch.Tensor,
646
+ context: torch.Tensor,
647
+ t_emb: Optional[torch.Tensor] = None,
648
+ ) -> torch.Tensor:
649
+ if self.use_spatio_temporal and t_emb is not None:
650
+ x = self.transformer(x, t_emb, context)
651
+ else:
652
+ x = self.spatial_attn(x)
653
+ x = self.temporal_attn(x)
654
+ x = self.cross_attn(x, context)
655
+
656
+ x = x + self.ff(x)
657
+ return x
658
+
659
+
660
+ class TemporalAttention(nn.Module):
661
+ """Self-attention over temporal dimension (legacy, for backward compatibility)"""
662
+ def __init__(self, channels: int, num_heads: int = 8):
663
+ super().__init__()
664
+ self.num_heads = num_heads
665
+ self.head_dim = channels // num_heads
666
+
667
+ self.norm = GroupNorm32(32, channels)
668
+ self.qkv = nn.Linear(channels, channels * 3)
669
+ self.proj = nn.Linear(channels, channels)
670
+
671
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
672
+ b, c, t, h, w = x.shape
673
+
674
+ # Reshape to (B*H*W, T, C)
675
+ x_flat = x.permute(0, 3, 4, 2, 1).reshape(b * h * w, t, c)
676
+
677
+ # Normalize
678
+ x_norm = self.norm(x.view(b, c, -1)).view(b, c, t, h, w)
679
+ x_norm = x_norm.permute(0, 3, 4, 2, 1).reshape(b * h * w, t, c)
680
+
681
+ # QKV
682
+ qkv = self.qkv(x_norm)
683
+ q, k, v = qkv.chunk(3, dim=-1)
684
+
685
+ # Reshape for multi-head
686
+ q = q.view(b * h * w, t, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
687
+ k = k.view(b * h * w, t, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
688
+ v = v.view(b * h * w, t, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
689
+
690
+ # Attention
691
+ scale = self.head_dim ** -0.5
692
+ attn = torch.matmul(q, k.transpose(-2, -1)) * scale
693
+ attn = F.softmax(attn, dim=-1)
694
+
695
+ out = torch.matmul(attn, v)
696
+ out = out.permute(0, 2, 1, 3).reshape(b * h * w, t, c)
697
+ out = self.proj(out)
698
+
699
+ out = out.view(b, h, w, t, c).permute(0, 4, 3, 1, 2)
700
+
701
+ return x + out
702
+
703
+
704
+ class UNet3D(nn.Module):
705
+ """
706
+ 3D UNet for video diffusion with text conditioning.
707
+ Enhanced with DiT-style transformer blocks for better temporal modeling.
708
+ """
709
+ def __init__(
710
+ self,
711
+ in_channels: int = 3,
712
+ model_channels: int = 128,
713
+ out_channels: int = 3,
714
+ num_res_blocks: int = 2,
715
+ attention_resolutions: Tuple[int, ...] = (8, 16),
716
+ channel_mult: Tuple[int, ...] = (1, 2, 4, 8),
717
+ num_heads: int = 8,
718
+ context_dim: int = 512,
719
+ dropout: float = 0.1,
720
+ use_transformer: bool = True, # Use enhanced transformer blocks
721
+ transformer_depth: int = 1, # Depth of transformer blocks
722
+ use_gradient_checkpointing: bool = False, # Enable gradient checkpointing for memory
723
+ ):
724
+ super().__init__()
725
+
726
+ self.in_channels = in_channels
727
+ self.model_channels = model_channels
728
+ self.out_channels = out_channels
729
+ self.num_res_blocks = num_res_blocks
730
+ self.attention_resolutions = attention_resolutions
731
+ self.channel_mult = channel_mult
732
+ self.num_heads = num_heads
733
+ self.use_transformer = use_transformer
734
+ self.use_gradient_checkpointing = use_gradient_checkpointing
735
+
736
+ time_embed_dim = model_channels * 4
737
+ self.time_embed_dim = time_embed_dim
738
+
739
+ # Time embedding
740
+ self.time_embed = nn.Sequential(
741
+ nn.Linear(model_channels, time_embed_dim),
742
+ nn.SiLU(),
743
+ nn.Linear(time_embed_dim, time_embed_dim),
744
+ )
745
+
746
+ # Input convolution
747
+ self.input_blocks = nn.ModuleList([
748
+ nn.Conv3d(in_channels, model_channels, 3, padding=1)
749
+ ])
750
+
751
+ # Downsampling
752
+ ch = model_channels
753
+ input_block_chans = [ch]
754
+ ds = 1
755
+
756
+ for level, mult in enumerate(channel_mult):
757
+ for _ in range(num_res_blocks):
758
+ layers = [
759
+ ResBlock3D(ch, mult * model_channels, time_embed_dim, dropout)
760
+ ]
761
+ ch = mult * model_channels
762
+
763
+ if ds in attention_resolutions:
764
+ layers.append(
765
+ TransformerBlock3D(
766
+ channels=ch,
767
+ context_dim=context_dim,
768
+ time_embed_dim=time_embed_dim,
769
+ num_heads=num_heads,
770
+ transformer_depth=transformer_depth,
771
+ use_spatio_temporal=use_transformer,
772
+ )
773
+ )
774
+
775
+ self.input_blocks.append(nn.ModuleList(layers))
776
+ input_block_chans.append(ch)
777
+
778
+ if level != len(channel_mult) - 1:
779
+ self.input_blocks.append(nn.ModuleList([Downsample3D(ch)]))
780
+ input_block_chans.append(ch)
781
+ ds *= 2
782
+
783
+ # Middle
784
+ self.middle_block = nn.ModuleList([
785
+ ResBlock3D(ch, ch, time_embed_dim, dropout),
786
+ TransformerBlock3D(
787
+ channels=ch,
788
+ context_dim=context_dim,
789
+ time_embed_dim=time_embed_dim,
790
+ num_heads=num_heads,
791
+ transformer_depth=transformer_depth,
792
+ use_spatio_temporal=use_transformer,
793
+ ),
794
+ ResBlock3D(ch, ch, time_embed_dim, dropout),
795
+ ])
796
+
797
+ # Upsampling
798
+ self.output_blocks = nn.ModuleList([])
799
+
800
+ for level, mult in list(enumerate(channel_mult))[::-1]:
801
+ for i in range(num_res_blocks + 1):
802
+ ich = input_block_chans.pop()
803
+ layers = [
804
+ ResBlock3D(ch + ich, mult * model_channels, time_embed_dim, dropout)
805
+ ]
806
+ ch = mult * model_channels
807
+
808
+ if ds in attention_resolutions:
809
+ layers.append(
810
+ TransformerBlock3D(
811
+ channels=ch,
812
+ context_dim=context_dim,
813
+ time_embed_dim=time_embed_dim,
814
+ num_heads=num_heads,
815
+ transformer_depth=transformer_depth,
816
+ use_spatio_temporal=use_transformer,
817
+ )
818
+ )
819
+
820
+ if level and i == num_res_blocks:
821
+ layers.append(Upsample3D(ch))
822
+ ds //= 2
823
+
824
+ self.output_blocks.append(nn.ModuleList(layers))
825
+
826
+ # Output
827
+ self.out = nn.Sequential(
828
+ GroupNorm32(32, ch),
829
+ nn.SiLU(),
830
+ nn.Conv3d(ch, out_channels, 3, padding=1),
831
+ )
832
+
833
+ def _checkpoint_forward(self, layer, h, t_emb, context=None):
834
+ """Helper for gradient checkpointing"""
835
+ if isinstance(layer, ResBlock3D):
836
+ return layer(h, t_emb)
837
+ elif isinstance(layer, TransformerBlock3D):
838
+ return layer(h, context, t_emb)
839
+ elif isinstance(layer, (Downsample3D, Upsample3D)):
840
+ return layer(h)
841
+ return h
842
+
843
+ def forward(
844
+ self,
845
+ x: torch.Tensor, # (B, C, T, H, W)
846
+ timesteps: torch.Tensor, # (B,)
847
+ context: torch.Tensor, # (B, seq_len, context_dim)
848
+ ) -> torch.Tensor:
849
+ """
850
+ Forward pass
851
+ Args:
852
+ x: Noisy video tensor (B, C, T, H, W)
853
+ timesteps: Diffusion timesteps (B,)
854
+ context: Text embeddings (B, seq_len, context_dim)
855
+ Returns:
856
+ Predicted noise (B, C, T, H, W)
857
+ """
858
+ from torch.utils.checkpoint import checkpoint
859
+
860
+ # Time embedding
861
+ t_emb = get_timestep_embedding(timesteps, self.model_channels)
862
+ t_emb = self.time_embed(t_emb)
863
+
864
+ # Downsampling path
865
+ hs = []
866
+ h = x
867
+
868
+ for module in self.input_blocks:
869
+ if isinstance(module, nn.Conv3d):
870
+ h = module(h)
871
+ elif isinstance(module, nn.ModuleList):
872
+ for layer in module:
873
+ if self.use_gradient_checkpointing and self.training:
874
+ h = checkpoint(self._checkpoint_forward, layer, h, t_emb, context, use_reentrant=False)
875
+ else:
876
+ h = self._checkpoint_forward(layer, h, t_emb, context)
877
+ hs.append(h)
878
+
879
+ # Middle
880
+ for layer in self.middle_block:
881
+ if self.use_gradient_checkpointing and self.training:
882
+ h = checkpoint(self._checkpoint_forward, layer, h, t_emb, context, use_reentrant=False)
883
+ else:
884
+ h = self._checkpoint_forward(layer, h, t_emb, context)
885
+
886
+ # Upsampling path
887
+ for module in self.output_blocks:
888
+ h = torch.cat([h, hs.pop()], dim=1)
889
+ for layer in module:
890
+ if self.use_gradient_checkpointing and self.training:
891
+ h = checkpoint(self._checkpoint_forward, layer, h, t_emb, context, use_reentrant=False)
892
+ else:
893
+ h = self._checkpoint_forward(layer, h, t_emb, context)
894
+
895
+ return self.out(h)
896
+
897
+
898
+ def create_unet(config) -> UNet3D:
899
+ """Create UNet model from config"""
900
+ return UNet3D(
901
+ in_channels=config.in_channels,
902
+ model_channels=config.model_channels,
903
+ out_channels=config.in_channels,
904
+ num_res_blocks=config.num_res_blocks,
905
+ attention_resolutions=config.attention_resolutions,
906
+ channel_mult=config.channel_mult,
907
+ num_heads=config.num_heads,
908
+ context_dim=config.context_dim,
909
+ use_transformer=getattr(config, 'use_transformer', True),
910
+ transformer_depth=getattr(config, 'transformer_depth', 1),
911
+ use_gradient_checkpointing=getattr(config, 'use_gradient_checkpointing', False),
912
+ )
913
+
914
+
915
+ if __name__ == "__main__":
916
+ # Test the enhanced model with transformer blocks
917
+ print("Testing UNet3D with DiT-style Transformer blocks...")
918
+
919
+ model = UNet3D(
920
+ in_channels=3,
921
+ model_channels=64,
922
+ channel_mult=(1, 2, 4),
923
+ attention_resolutions=(8, 16),
924
+ num_heads=4,
925
+ context_dim=256,
926
+ use_transformer=True,
927
+ transformer_depth=1,
928
+ )
929
+
930
+ # Test input
931
+ batch_size = 2
932
+ x = torch.randn(batch_size, 3, 16, 64, 64) # (B, C, T, H, W)
933
+ t = torch.randint(0, 1000, (batch_size,))
934
+ context = torch.randn(batch_size, 77, 256) # (B, seq_len, context_dim)
935
+
936
+ # Forward pass
937
+ out = model(x, t, context)
938
+ print(f"Input shape: {x.shape}")
939
+ print(f"Output shape: {out.shape}")
940
+ print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
941
+
942
+ # Test backward pass
943
+ loss = out.sum()
944
+ loss.backward()
945
+ print("Backward pass successful!")
946
+
947
+ # Test without transformer (legacy mode)
948
+ print("\nTesting UNet3D without transformer (legacy mode)...")
949
+ model_legacy = UNet3D(
950
+ in_channels=3,
951
+ model_channels=64,
952
+ channel_mult=(1, 2, 4),
953
+ attention_resolutions=(8, 16),
954
+ num_heads=4,
955
+ context_dim=256,
956
+ use_transformer=False,
957
+ )
958
+
959
+ out_legacy = model_legacy(x, t, context)
960
+ print(f"Legacy output shape: {out_legacy.shape}")
961
+ print(f"Legacy parameters: {sum(p.numel() for p in model_legacy.parameters()):,}")
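Training-step sketch (not part of the commit): shapes follow ModelConfig defaults (96 base channels, 16 frames at 64x64, context_dim 384). The noisy video and the target below are random placeholders; in training they would come from the DDIM scheduler's forward process, with the loss pairing the predicted epsilon against the actual added noise.

import torch
import torch.nn.functional as F
from config import ModelConfig
from models.unet3d import create_unet

cfg = ModelConfig()
unet = create_unet(cfg)

noisy_video = torch.randn(1, 3, cfg.num_frames, cfg.image_size, cfg.image_size)  # (B, C, T, H, W)
timesteps = torch.randint(0, 100, (1,))                          # DDIMConfig.num_train_timesteps = 100
context = torch.randn(1, cfg.max_text_length, cfg.context_dim)   # stand-in for text embeddings

pred_noise = unet(noisy_video, timesteps, context)               # same shape as noisy_video
loss = F.mse_loss(pred_noise, torch.randn_like(pred_noise))      # epsilon-prediction MSE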
pipeline.py ADDED
@@ -0,0 +1,416 @@
1
+ """
2
+ Pipeline for text-to-sign language GIF generation
3
+ End-to-end inference with a trained model
4
+ """
5
+
6
+ import os
7
+ from typing import List, Optional, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import numpy as np
12
+ from PIL import Image
13
+ from tqdm import tqdm
14
+
15
+ from config import ModelConfig, DDIMConfig, GenerationConfig
16
+ from models import UNet3D, TextEncoder, create_text_encoder
17
+ from schedulers import DDIMScheduler
18
+ from dataset import SimpleTokenizer
19
+
20
+
21
+ class Text2SignPipeline:
22
+ """
23
+ End-to-end pipeline for text-to-sign language GIF generation
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ model: UNet3D,
29
+ text_encoder: TextEncoder,
30
+ scheduler: DDIMScheduler,
31
+ model_config: ModelConfig,
32
+ generation_config: GenerationConfig,
33
+ device: Union[str, torch.device] = "cuda",
34
+ ):
35
+ self.model = model.to(device)
36
+ self.text_encoder = text_encoder.to(device)
37
+ self.scheduler = scheduler
38
+ self.model_config = model_config
39
+ self.generation_config = generation_config
40
+ self.device = device
41
+ self.use_clip_text_encoder = getattr(model_config, "use_clip_text_encoder", False) or getattr(text_encoder, "use_clip", False)
42
+
43
+ # Move scheduler tensors to device
44
+ self._move_scheduler_to_device()
45
+
46
+ # Tokenizer
47
+ self.tokenizer = None if self.use_clip_text_encoder else SimpleTokenizer(
48
+ vocab_size=model_config.vocab_size,
49
+ max_length=model_config.max_text_length,
50
+ )
51
+
52
+ # Set models to eval mode
53
+ self.model.eval()
54
+ self.text_encoder.eval()
55
+
56
+ def _move_scheduler_to_device(self):
57
+ """Move scheduler tensors to device"""
58
+ self.scheduler.betas = self.scheduler.betas.to(self.device)
59
+ self.scheduler.alphas = self.scheduler.alphas.to(self.device)
60
+ self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
61
+ self.scheduler.alphas_cumprod_prev = self.scheduler.alphas_cumprod_prev.to(self.device)
62
+ self.scheduler.sqrt_alphas_cumprod = self.scheduler.sqrt_alphas_cumprod.to(self.device)
63
+ self.scheduler.sqrt_one_minus_alphas_cumprod = self.scheduler.sqrt_one_minus_alphas_cumprod.to(self.device)
64
+
65
+ @classmethod
66
+ def from_pretrained(
67
+ cls,
68
+ checkpoint_path: str,
69
+ device: Union[str, torch.device] = "cuda",
70
+ ) -> "Text2SignPipeline":
71
+ """
72
+ Load pipeline from a saved checkpoint
73
+
74
+ Args:
75
+ checkpoint_path: Path to the checkpoint file
76
+ device: Device to load models on
77
+
78
+ Returns:
79
+ Text2SignPipeline instance
80
+ """
81
+ checkpoint = torch.load(checkpoint_path, map_location=device)
82
+
83
+ # Get configs from checkpoint
84
+ model_config = checkpoint.get("model_config", ModelConfig())
85
+ ddim_config = checkpoint.get("ddim_config", DDIMConfig())
86
+ generation_config = GenerationConfig()
87
+
88
+ # Handle dataclass or dict
89
+ if isinstance(model_config, dict):
90
+ model_config = ModelConfig(**model_config)
91
+ if isinstance(ddim_config, dict):
92
+ ddim_config = DDIMConfig(**ddim_config)
93
+
94
+ # Detect actual transformer_depth from model weights (config may be wrong)
95
+ state_dict = checkpoint["model_state_dict"]
96
+ actual_transformer_depth = 1
97
+ for key in state_dict.keys():
98
+ if 'spatial_blocks.' in key:
99
+ idx = int(key.split('spatial_blocks.')[1].split('.')[0])
100
+ actual_transformer_depth = max(actual_transformer_depth, idx + 1)
101
+
102
+ config_depth = getattr(model_config, 'transformer_depth', 1)
103
+ if config_depth != actual_transformer_depth:
104
+ print(f" Note: Config says transformer_depth={config_depth}, but weights have depth={actual_transformer_depth}")
105
+ print(f" Using actual depth from weights: {actual_transformer_depth}")
106
+
107
+ # Create models with all transformer parameters from config
108
+ model = UNet3D(
109
+ in_channels=model_config.in_channels,
110
+ model_channels=model_config.model_channels,
111
+ out_channels=model_config.in_channels,
112
+ num_res_blocks=model_config.num_res_blocks,
113
+ attention_resolutions=model_config.attention_resolutions,
114
+ channel_mult=model_config.channel_mult,
115
+ num_heads=model_config.num_heads,
116
+ context_dim=model_config.context_dim,
117
+ use_transformer=getattr(model_config, 'use_transformer', True),
118
+ transformer_depth=actual_transformer_depth, # Use detected depth from weights
119
+ use_gradient_checkpointing=getattr(model_config, 'use_gradient_checkpointing', False),
120
+ )
121
+
122
+ # Detect text encoder type from weights
123
+ text_encoder_state_dict = checkpoint["text_encoder_state_dict"]
124
+ use_clip = getattr(model_config, "use_clip_text_encoder", False)
125
+
126
+ # Check if weights match CLIP structure
127
+ has_clip_keys = any("model.text_model" in k for k in text_encoder_state_dict.keys())
128
+ has_custom_keys = any("token_embedding.weight" in k and "model.text_model" not in k for k in text_encoder_state_dict.keys())
129
+
130
+ if use_clip and not has_clip_keys and has_custom_keys:
131
+ print(" Note: Config says use_clip_text_encoder=True, but weights appear to be custom TextEncoder")
132
+ print(" Forcing use_clip=False")
133
+ use_clip = False
134
+ # Update config to match
135
+ model_config.use_clip_text_encoder = False
136
+
137
+ text_encoder = create_text_encoder(
138
+ model_config,
139
+ use_clip=use_clip,
140
+ )
141
+
142
+ scheduler = DDIMScheduler(
143
+ num_train_timesteps=ddim_config.num_train_timesteps,
144
+ beta_start=ddim_config.beta_start,
145
+ beta_end=ddim_config.beta_end,
146
+ beta_schedule=ddim_config.beta_schedule,
147
+ clip_sample=ddim_config.clip_sample,
148
+ prediction_type=ddim_config.prediction_type,
149
+ )
150
+
151
+ # Load weights
152
+ model.load_state_dict(checkpoint["model_state_dict"])
153
+ text_encoder.load_state_dict(checkpoint["text_encoder_state_dict"])
154
+
155
+ return cls(
156
+ model=model,
157
+ text_encoder=text_encoder,
158
+ scheduler=scheduler,
159
+ model_config=model_config,
160
+ generation_config=generation_config,
161
+ device=device,
162
+ )
163
+
164
+ @torch.no_grad()
165
+ def __call__(
166
+ self,
167
+ prompt: Union[str, List[str]],
168
+ num_inference_steps: Optional[int] = None,
169
+ guidance_scale: Optional[float] = None,
170
+ eta: Optional[float] = None,
171
+ generator: Optional[torch.Generator] = None,
172
+ output_type: str = "pil", # "pil", "tensor", "numpy"
173
+ ) -> Union[List[List[Image.Image]], torch.Tensor, np.ndarray]:
174
+ """
175
+ Generate sign language video from text prompt
176
+
177
+ Args:
178
+ prompt: Text prompt or list of prompts
179
+ num_inference_steps: Number of denoising steps
180
+ guidance_scale: Classifier-free guidance scale
181
+ eta: Stochasticity parameter (0 = deterministic DDIM)
182
+ generator: Random generator for reproducibility
183
+ output_type: Type of output ("pil", "tensor", "numpy")
184
+
185
+ Returns:
186
+ Generated videos in requested format
187
+ """
188
+ # Handle single prompt
189
+ if isinstance(prompt, str):
190
+ prompt = [prompt]
191
+
192
+ batch_size = len(prompt)
193
+
194
+ # Use default values if not specified
195
+ if num_inference_steps is None:
196
+ num_inference_steps = self.generation_config.num_inference_steps
197
+ if guidance_scale is None:
198
+ guidance_scale = self.generation_config.guidance_scale
199
+ if eta is None:
200
+ eta = self.generation_config.eta
201
+
202
+ # Tokenize prompts
203
+ if self.use_clip_text_encoder:
204
+ text_embeddings = self.text_encoder(tokens=None, text=prompt)
205
+ else:
206
+ tokens = self.tokenizer(prompt).to(self.device)
207
+ text_embeddings = self.text_encoder(tokens)
208
+
209
+ # For classifier-free guidance
210
+ if guidance_scale > 1.0:
211
+ if self.use_clip_text_encoder:
212
+ uncond_embeddings = self.text_encoder(tokens=None, text=[""] * batch_size)
213
+ else:
214
+ uncond_tokens = self.tokenizer([""] * batch_size).to(self.device)
215
+ uncond_embeddings = self.text_encoder(uncond_tokens)
216
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
217
+
218
+ # Set inference timesteps
219
+ self.scheduler.set_timesteps(num_inference_steps, device=self.device)
220
+
221
+ # Initialize latents
222
+ latents_shape = (
223
+ batch_size,
224
+ self.model_config.in_channels,
225
+ self.model_config.num_frames,
226
+ self.model_config.image_size,
227
+ self.model_config.image_size,
228
+ )
229
+
230
+ if generator is not None:
231
+ latents = torch.randn(latents_shape, generator=generator, device=self.device)
232
+ else:
233
+ latents = torch.randn(latents_shape, device=self.device)
234
+
235
+ # Denoising loop
236
+ for t in tqdm(self.scheduler.timesteps, desc="Generating sign language", leave=True):
237
+ latent_model_input = latents
238
+
239
+ if guidance_scale > 1.0:
240
+ latent_model_input = torch.cat([latents] * 2)
241
+
242
+ timestep = torch.tensor([t] * latent_model_input.shape[0], device=self.device)
243
+
244
+ # Predict noise
245
+ noise_pred = self.model(latent_model_input, timestep, text_embeddings)
246
+
247
+ # Apply classifier-free guidance
248
+ if guidance_scale > 1.0:
249
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
250
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
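+ # classifier-free guidance: eps = eps_uncond + s * (eps_text - eps_uncond); s = 1 recovers the purely conditional prediction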
251
+
252
+ # DDIM step
253
+ latents, _ = self.scheduler.step(noise_pred, t, latents, eta=eta, generator=generator)
254
+
255
+ # Denormalize
256
+ videos = (latents + 1) / 2
257
+ videos = videos.clamp(0, 1)
258
+
259
+ # Convert to output type
260
+ if output_type == "tensor":
261
+ return videos
262
+ elif output_type == "numpy":
263
+ return videos.cpu().numpy()
264
+ else: # "pil"
265
+ return self._tensor_to_pil(videos)
266
+
267
+ def _tensor_to_pil(self, videos: torch.Tensor) -> List[List[Image.Image]]:
268
+ """Convert tensor videos to PIL images"""
269
+ # videos: (B, C, T, H, W)
270
+ videos = videos.cpu().numpy()
271
+
272
+ all_videos = []
273
+ for video in videos:
274
+ # (C, T, H, W) -> (T, H, W, C)
275
+ frames = video.transpose(1, 2, 3, 0)
276
+ frames = (frames * 255).astype(np.uint8)
277
+
278
+ pil_frames = [Image.fromarray(frame) for frame in frames]
279
+ all_videos.append(pil_frames)
280
+
281
+ return all_videos
282
+
283
+ def save_gif(
284
+ self,
285
+ frames: List[Image.Image],
286
+ path: str,
287
+ fps: Optional[int] = None,
288
+ ):
289
+ """
290
+ Save frames as GIF
291
+
292
+ Args:
293
+ frames: List of PIL images
294
+ path: Output path
295
+ fps: Frames per second
296
+ """
297
+ if fps is None:
298
+ fps = self.generation_config.fps
299
+
300
+ duration = 1000 // fps
301
+
302
+ frames[0].save(
303
+ path,
304
+ save_all=True,
305
+ append_images=frames[1:],
306
+ duration=duration,
307
+ loop=0,
308
+ )
309
+
310
+ def generate_and_save(
311
+ self,
312
+ prompt: Union[str, List[str]],
313
+ output_dir: str,
314
+ prefix: str = "sign",
315
+ **kwargs,
316
+ ) -> List[str]:
317
+ """
318
+ Generate and save GIFs
319
+
320
+ Args:
321
+ prompt: Text prompt(s)
322
+ output_dir: Directory to save GIFs
323
+ prefix: Filename prefix
324
+ **kwargs: Arguments passed to __call__
325
+
326
+ Returns:
327
+ List of saved file paths
328
+ """
329
+ os.makedirs(output_dir, exist_ok=True)
330
+
331
+ if isinstance(prompt, str):
332
+ prompt = [prompt]
333
+
334
+ videos = self(prompt, **kwargs)
335
+
336
+ saved_paths = []
337
+ for i, (frames, text) in enumerate(zip(videos, prompt)):
338
+ # Create filename from prompt
339
+ safe_text = "".join(c if c.isalnum() else "_" for c in text[:30])
340
+ filename = f"{prefix}_{i}_{safe_text}.gif"
341
+ filepath = os.path.join(output_dir, filename)
342
+
343
+ self.save_gif(frames, filepath)
344
+ saved_paths.append(filepath)
345
+ print(f"Saved: {filepath}")
346
+
347
+ return saved_paths
348
+
349
+
350
+ def create_pipeline(
351
+ model_config: Optional[ModelConfig] = None,
352
+ ddim_config: Optional[DDIMConfig] = None,
353
+ generation_config: Optional[GenerationConfig] = None,
354
+ device: str = "cuda",
355
+ ) -> Text2SignPipeline:
356
+ """
357
+ Create a new pipeline with untrained models
358
+ (useful for testing)
359
+ """
360
+ if model_config is None:
361
+ model_config = ModelConfig()
362
+ if ddim_config is None:
363
+ ddim_config = DDIMConfig()
364
+ if generation_config is None:
365
+ generation_config = GenerationConfig()
366
+
367
+ model = UNet3D(
368
+ in_channels=model_config.in_channels,
369
+ model_channels=model_config.model_channels,
370
+ out_channels=model_config.in_channels,
371
+ num_res_blocks=model_config.num_res_blocks,
372
+ attention_resolutions=model_config.attention_resolutions,
373
+ channel_mult=model_config.channel_mult,
374
+ num_heads=model_config.num_heads,
375
+ context_dim=model_config.context_dim,
376
+ )
377
+
378
+ text_encoder = create_text_encoder(
379
+ model_config,
380
+ use_clip=getattr(model_config, "use_clip_text_encoder", False),
381
+ )
382
+
383
+ scheduler = DDIMScheduler(
384
+ num_train_timesteps=ddim_config.num_train_timesteps,
385
+ beta_start=ddim_config.beta_start,
386
+ beta_end=ddim_config.beta_end,
387
+ beta_schedule=ddim_config.beta_schedule,
388
+ clip_sample=ddim_config.clip_sample,
389
+ prediction_type=ddim_config.prediction_type,
390
+ )
391
+
392
+ return Text2SignPipeline(
393
+ model=model,
394
+ text_encoder=text_encoder,
395
+ scheduler=scheduler,
396
+ model_config=model_config,
397
+ generation_config=generation_config,
398
+ device=device,
399
+ )
400
+
401
+
402
+ if __name__ == "__main__":
403
+ # Test pipeline
404
+ print("Creating pipeline...")
405
+ pipeline = create_pipeline(device="cpu")
406
+
407
+ print("Testing generation...")
408
+ videos = pipeline(
409
+ ["Hello", "Thank you"],
410
+ num_inference_steps=5,
411
+ guidance_scale=3.0,
412
+ )
413
+
414
+ print(f"Generated {len(videos)} videos")
415
+ print(f"Each video has {len(videos[0])} frames")
416
+ print(f"Frame size: {videos[0][0].size}")
schedulers/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ """
2
+ Schedulers package for text-to-sign language generation
3
+ """
4
+
5
+ from .ddim import DDIMScheduler, get_ddim_scheduler
6
+
7
+ __all__ = [
8
+ "DDIMScheduler",
9
+ "get_ddim_scheduler",
10
+ ]
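A short sketch of how the package is wired to the project config (get_ddim_scheduler is defined in schedulers/ddim.py below):

from config import DDIMConfig
from schedulers import get_ddim_scheduler

cfg = DDIMConfig()
scheduler = get_ddim_scheduler(cfg)
scheduler.set_timesteps(cfg.num_inference_steps)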
schedulers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (342 Bytes)
 
schedulers/__pycache__/ddim.cpython-310.pyc ADDED
Binary file (7.92 kB)
 
schedulers/ddim.py ADDED
@@ -0,0 +1,298 @@
1
+ """
2
+ DDIM (Denoising Diffusion Implicit Models) Scheduler
3
+ Implements both training and sampling procedures
4
+ """
5
+
6
+ import math
7
+ from typing import Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import numpy as np
12
+
13
+
14
+ class DDIMScheduler:
15
+ """
16
+ DDIM Scheduler for diffusion models
17
+
18
+ Supports both DDPM training and DDIM deterministic/stochastic sampling
19
+ """
20
+
21
+ def __init__(
22
+ self,
23
+ num_train_timesteps: int = 1000,
24
+ beta_start: float = 0.0001,
25
+ beta_end: float = 0.02,
26
+ beta_schedule: str = "linear",
27
+ clip_sample: bool = True,
28
+ prediction_type: str = "epsilon",
29
+ thresholding: bool = False,
30
+ dynamic_thresholding_ratio: float = 0.995,
31
+ sample_max_value: float = 1.0,
32
+ ):
33
+ """
34
+ Args:
35
+ num_train_timesteps: Number of diffusion steps
36
+ beta_start: Starting beta value
37
+ beta_end: Ending beta value
38
+ beta_schedule: Type of beta schedule ("linear" or "cosine")
39
+ clip_sample: Whether to clip predicted samples
40
+ prediction_type: What the model predicts ("epsilon" or "v_prediction")
41
+ thresholding: Whether to use dynamic thresholding
42
+ dynamic_thresholding_ratio: Ratio for dynamic thresholding
43
+ sample_max_value: Max value for clipping
44
+ """
45
+ self.num_train_timesteps = num_train_timesteps
46
+ self.beta_start = beta_start
47
+ self.beta_end = beta_end
48
+ self.beta_schedule = beta_schedule
49
+ self.clip_sample = clip_sample
50
+ self.prediction_type = prediction_type
51
+ self.thresholding = thresholding
52
+ self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
53
+ self.sample_max_value = sample_max_value
54
+
55
+ # Compute betas
56
+ if beta_schedule == "linear":
57
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps)
58
+ elif beta_schedule == "cosine":
59
+ self.betas = self._cosine_beta_schedule(num_train_timesteps)
60
+ elif beta_schedule == "squaredcos_cap_v2":
61
+ self.betas = self._squaredcos_cap_v2_schedule(num_train_timesteps)
62
+ else:
63
+ raise ValueError(f"Unknown beta schedule: {beta_schedule}")
64
+
65
+ # Compute alphas
66
+ self.alphas = 1.0 - self.betas
67
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
68
+ self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
69
+
70
+ # Calculations for diffusion q(x_t | x_0)
71
+ self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
72
+ self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
73
+
74
+ # Calculations for posterior q(x_{t-1} | x_t, x_0)
75
+ self.posterior_variance = (
76
+ self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
77
+ )
78
+ self.posterior_log_variance_clipped = torch.log(
79
+ torch.cat([self.posterior_variance[1:2], self.posterior_variance[1:]])
80
+ )
81
+ self.posterior_mean_coef1 = (
82
+ self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
83
+ )
84
+ self.posterior_mean_coef2 = (
85
+ (1.0 - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1.0 - self.alphas_cumprod)
86
+ )
87
+
88
+ # For sampling
89
+ self.num_inference_steps = None
90
+ self.timesteps = None
91
+
92
+ def _cosine_beta_schedule(self, timesteps: int, s: float = 0.008) -> torch.Tensor:
93
+ """Cosine schedule as proposed in https://arxiv.org/abs/2102.09672"""
94
+ steps = timesteps + 1
95
+ x = torch.linspace(0, timesteps, steps)
96
+ alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
97
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
98
+ betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
99
+ return torch.clip(betas, 0.0001, 0.9999)
100
+
101
+ def _squaredcos_cap_v2_schedule(self, timesteps: int) -> torch.Tensor:
102
+ """Squared cosine schedule used in improved DDPM"""
103
+ return self._cosine_beta_schedule(timesteps)
104
+
105
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = "cpu"):
106
+ """
107
+ Set the timesteps for inference
108
+
109
+ Args:
110
+ num_inference_steps: Number of steps for inference
111
+ device: Device to put tensors on
112
+ """
113
+ self.num_inference_steps = num_inference_steps
114
+
115
+ # DDIM uses uniform spacing
116
+ step_ratio = self.num_train_timesteps // num_inference_steps
117
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
118
+ self.timesteps = torch.from_numpy(timesteps).to(device)
119
+
120
+ def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:
121
+ """Compute variance for given timestep"""
122
+ alpha_prod_t = self.alphas_cumprod[timestep]
123
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else torch.tensor(1.0)
124
+
125
+ beta_prod_t = 1 - alpha_prod_t
126
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
127
+
128
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
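+ # i.e. sigma_t^2 = (1 - abar_{t-1}) / (1 - abar_t) * (1 - abar_t / abar_{t-1}), the DDIM variance (Song et al., 2021)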
129
+
130
+ return variance
131
+
132
+ def add_noise(
133
+ self,
134
+ original_samples: torch.Tensor,
135
+ noise: torch.Tensor,
136
+ timesteps: torch.Tensor,
137
+ ) -> torch.Tensor:
138
+ """
139
+ Add noise to samples for training
140
+
141
+ Args:
142
+ original_samples: Clean samples x_0
143
+ noise: Noise to add
144
+ timesteps: Timesteps for each sample
145
+
146
+ Returns:
147
+ Noisy samples x_t
148
+ """
149
+ # Move coefficients to correct device and dtype
150
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(original_samples.device)
151
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(original_samples.device)
152
+
153
+ sqrt_alpha_prod = sqrt_alphas_cumprod[timesteps]
154
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alphas_cumprod[timesteps]
155
+
156
+ # Reshape for broadcasting
157
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
158
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
159
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
160
+
161
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
162
+
163
+ return noisy_samples
164
+
165
+ def step(
166
+ self,
167
+ model_output: torch.Tensor,
168
+ timestep: int,
169
+ sample: torch.Tensor,
170
+ eta: float = 0.0,
171
+ generator: Optional[torch.Generator] = None,
172
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
173
+ """
174
+ Perform one DDIM denoising step
175
+
176
+ Args:
177
+ model_output: Output from the model (predicted noise or v)
178
+ timestep: Current timestep
179
+ sample: Current noisy sample x_t
180
+ eta: Stochasticity factor (0 = deterministic DDIM, 1 = DDPM)
181
+ generator: Random generator for reproducibility
182
+
183
+ Returns:
184
+ Tuple of (predicted x_{t-1}, predicted x_0)
185
+ """
186
+ # Get previous timestep
187
+ prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps
188
+
189
+ # Get alpha values
190
+ alpha_prod_t = self.alphas_cumprod[timestep]
191
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else torch.tensor(1.0)
192
+
193
+ beta_prod_t = 1 - alpha_prod_t
194
+
195
+ # Compute predicted x_0
196
+ if self.prediction_type == "epsilon":
197
+ pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
198
+ elif self.prediction_type == "v_prediction":
199
+ pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output
200
+ else:
201
+ raise ValueError(f"Unknown prediction type: {self.prediction_type}")
202
+
203
+ # Clip predicted x_0
204
+ if self.clip_sample:
205
+ pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
206
+
207
+ # Compute variance
208
+ variance = self._get_variance(timestep, prev_timestep)
209
+ std_dev_t = eta * variance ** 0.5
210
+
211
+ # Compute direction pointing to x_t
212
+ pred_epsilon = (sample - alpha_prod_t ** 0.5 * pred_original_sample) / beta_prod_t ** 0.5  # re-derived from the (clipped) x_0 so this also holds for v-prediction
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** 0.5 * pred_epsilon
213
+
214
+ # Compute x_{t-1}
215
+ prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
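+ # DDIM update: x_{t-1} = sqrt(abar_{t-1}) * x0_hat + sqrt(1 - abar_{t-1} - sigma_t^2) * eps_hat (plus sigma_t * z below when eta > 0)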
216
+
217
+ # Add noise if eta > 0
218
+ if eta > 0:
219
+ device = model_output.device
220
+ noise = torch.randn(
221
+ model_output.shape,
222
+ generator=generator,
223
+ device=device,
224
+ dtype=model_output.dtype
225
+ )
226
+ prev_sample = prev_sample + std_dev_t * noise
227
+
228
+ return prev_sample, pred_original_sample
229
+
230
+ def get_velocity(
231
+ self,
232
+ sample: torch.Tensor,
233
+ noise: torch.Tensor,
234
+ timesteps: torch.Tensor,
235
+ ) -> torch.Tensor:
236
+ """
237
+ Compute velocity for v-prediction
238
+
239
+ v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample
240
+ """
241
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(sample.device)
242
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(sample.device)
243
+
244
+ sqrt_alpha_prod = sqrt_alphas_cumprod[timesteps]
245
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alphas_cumprod[timesteps]
246
+
247
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
248
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
249
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
250
+
251
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
252
+
253
+ return velocity
254
+
255
+
256
258
+
259
+
260
+ def get_ddim_scheduler(config) -> DDIMScheduler:
261
+ """Create DDIM scheduler from config"""
262
+ return DDIMScheduler(
263
+ num_train_timesteps=config.num_train_timesteps,
264
+ beta_start=config.beta_start,
265
+ beta_end=config.beta_end,
266
+ beta_schedule=config.beta_schedule,
267
+ clip_sample=config.clip_sample,
268
+ prediction_type=config.prediction_type,
269
+ )
270
+
271
+
272
+ if __name__ == "__main__":
273
+ # Test the scheduler
274
+ scheduler = DDIMScheduler(
275
+ num_train_timesteps=1000,
276
+ beta_start=0.0001,
277
+ beta_end=0.02,
278
+ beta_schedule="linear",
279
+ )
280
+
281
+ # Test adding noise
282
+ x = torch.randn(2, 3, 16, 64, 64)
283
+ noise = torch.randn_like(x)
284
+ timesteps = torch.tensor([100, 500])
285
+
286
+ noisy_x = scheduler.add_noise(x, noise, timesteps)
287
+ print(f"Original shape: {x.shape}")
288
+ print(f"Noisy shape: {noisy_x.shape}")
289
+
290
+ # Test sampling
291
+ scheduler.set_timesteps(50)
292
+ print(f"Inference timesteps: {scheduler.timesteps[:10]}...")
293
+
294
+ # Test step
295
+ model_output = torch.randn_like(x)
296
+ prev_sample, pred_x0 = scheduler.step(model_output, 500, noisy_x, eta=0.0)
297
+ print(f"Previous sample shape: {prev_sample.shape}")
298
+ print(f"Predicted x0 shape: {pred_x0.shape}")