File size: 3,246 Bytes
234a70c
 
 
5e9417b
234a70c
 
 
 
 
 
5e9417b
234a70c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e9417b
234a70c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e9417b
234a70c
 
 
 
 
 
 
5e9417b
234a70c
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""
Configuration for Text-to-Sign Language DDIM Diffusion Model
"""

from dataclasses import dataclass
from typing import Optional, Tuple
import torch


@dataclass
class ModelConfig:
    """Hyperparameters for the video UNet/DiT backbone and its text encoder."""

    # --- Input dimensions ---
    image_size: int = 64  # frames are resized to image_size x image_size
    num_frames: int = 16  # frames sampled per video clip
    in_channels: int = 3  # RGB input

    # --- UNet backbone ---
    model_channels: int = 96  # base channel width (raised from 64 for quality)
    channel_mult: Tuple[int, ...] = (1, 2, 4)  # per-depth-level channel multipliers
    num_res_blocks: int = 2  # residual blocks per level
    attention_resolutions: Tuple[int, ...] = (8, 16)  # NOTE(review): with image_size=64 and 3 levels the coarsest map is 16px, so the 8px entry may never trigger -- confirm
    num_heads: int = 6  # attention heads per block (raised from 4)

    # --- DiT-style transformer blocks ---
    use_transformer: bool = True  # enable enhanced DiT-style transformer blocks
    transformer_depth: int = 2  # transformer blocks per attention site (raised from 1)
    use_gradient_checkpointing: bool = True  # trade recompute for activation memory

    # --- Text encoder ---
    use_clip_text_encoder: bool = True  # frozen pretrained CLIP text encoder by default
    text_embed_dim: int = 384  # text embedding width (raised from 256)
    max_text_length: int = 77  # CLIP's fixed token context length
    vocab_size: int = 49408  # CLIP BPE vocabulary size

    # --- Cross attention ---
    context_dim: int = 384  # kept equal to text_embed_dim so text tokens feed cross-attention directly


@dataclass
class DDIMConfig:
    """Noise-schedule and sampler settings for the DDIM scheduler."""

    num_train_timesteps: int = 100  # diffusion timesteps during training
    num_inference_steps: int = 100  # timesteps taken when sampling
    beta_start: float = 0.0001  # variance-schedule start value
    beta_end: float = 0.02  # variance-schedule end value
    beta_schedule: str = "linear"  # schedule shape; expected values: "linear" or "cosine"
    clip_sample: bool = True  # clamp predicted samples during denoising
    prediction_type: str = "epsilon"  # model target; "epsilon" (noise) or "v_prediction"


@dataclass
class TrainingConfig:
    """Data-loading, optimizer, and logging settings for a training run."""

    # Dataset
    data_dir: str = "text2sign/training_data"  # NOTE(review): prefix differs from checkpoint_dir/log_dir ("text_to_sign") -- confirm this is intentional
    batch_size: int = 2  # small per-step batch; see gradient_accumulation_steps
    num_workers: int = 4  # dataloader worker processes

    # Optimization
    num_epochs: int = 150  # total training epochs
    learning_rate: float = 5e-5  # lowered from 1e-4 for fine-tuning stability
    weight_decay: float = 0.01
    warmup_steps: int = 500  # LR warmup steps before full learning rate
    gradient_accumulation_steps: int = 8  # effective batch = batch_size * 8 = 16
    max_grad_norm: float = 1.0  # gradient-clipping threshold

    # Mixed precision
    use_amp: bool = True  # automatic mixed precision

    # Checkpointing and logging cadence
    checkpoint_dir: str = "text_to_sign/checkpoints"
    save_every: int = 5  # epochs between checkpoint saves
    log_every: int = 100  # steps between scalar logs
    sample_every: int = 1000  # steps between sample generations

    # TensorBoard
    log_dir: str = "text_to_sign/logs"

    # Hardware; resolved once at import time via torch.cuda.is_available()
    device: str = "cuda" if torch.cuda.is_available() else "cpu"


@dataclass
class GenerationConfig:
    """Inference-time settings for turning text prompts into output GIFs."""

    num_inference_steps: int = 50  # DDIM steps at generation time
    guidance_scale: float = 7.5  # classifier-free guidance strength
    eta: float = 0.0  # 0.0 = deterministic DDIM; 1.0 = stochastic (DDPM-like)
    output_dir: str = "text_to_sign/generated"  # destination for generated GIFs
    fps: int = 8  # frame rate of the written GIF


def get_config():
    """Build every configuration section and return them keyed by name.

    Returns:
        dict: fresh instances of each config dataclass under the keys
        "model", "ddim", "training", and "generation".
    """
    sections = {
        "model": ModelConfig,
        "ddim": DDIMConfig,
        "training": TrainingConfig,
        "generation": GenerationConfig,
    }
    # Instantiate lazily here so every call yields independent config objects.
    return {name: cls() for name, cls in sections.items()}