File size: 3,502 Bytes
388fd6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""Minimal example: running a Gamma SSM block on sine wave data (forward pass and gradient check)."""

import torch
import math
import numpy as np
from gamma_space_model import GammaSingleBlock


def generate_sine_wave(seq_len: int = 128, freq: float = 0.1, batch_size: int = 4) -> torch.Tensor:
    """Build a batch of identical sine-wave sequences.

    Each sequence samples ``sin(2 * pi * freq * t)`` at the integer time
    steps ``t = 0, 1, ..., seq_len - 1``.

    Args:
        seq_len: Number of time steps per sequence.
        freq: Frequency of the sine wave, in cycles per time step.
        batch_size: Number of (identical) sequences in the batch.

    Returns:
        A float32 tensor of shape ``(batch_size, seq_len, 1)``.
    """
    omega = 2 * math.pi * freq  # angular frequency, radians per time step
    steps = torch.arange(seq_len, dtype=torch.float32)[None, :, None]  # (1, seq_len, 1)
    single = torch.sin(omega * steps)  # (1, seq_len, 1)
    # Tile the single sequence along the batch axis; repeat() copies the data,
    # so each sample is its own storage rather than a broadcast view.
    return single.repeat(batch_size, 1, 1)


def main():
    """Run a forward pass and a gradient check through one ``GammaSingleBlock``.

    Builds the block and prints its configuration and trainable-parameter
    count. It then runs a no-grad forward pass on a batch of sine waves and
    prints the output / final-state shapes. Finally it runs a forward and
    backward pass on a fresh batch and reports whether gradients reached the
    input and the block's parameters.
    """
    # Device setup
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}\n")

    # Seed the global RNG before building the block so any random parameter
    # initialisation -- and hence the printed loss -- is the same on every run.
    torch.manual_seed(0)

    # Configuration
    d_model = 1           # Input dimension (sine wave is 1D)
    hidden_dim = 16       # SSM hidden state dimension
    seq_len = 128         # Sequence length
    batch_size = 4        # Batch size
    freq = 0.1            # Sine-wave frequency (cycles per time step), shared by both batches

    print("=" * 60)
    print("Gamma SSM Block - Minimal Example")
    print("=" * 60)
    print(f"Model dimension (d_model):     {d_model}")
    print(f"Hidden dimension (hidden_dim): {hidden_dim}")
    print(f"Sequence length:               {seq_len}")
    print(f"Batch size:                    {batch_size}")
    print()

    # Instantiate block with direct parameters (PyTorch style)
    block = GammaSingleBlock(
        d_model=d_model,
        hidden_dim=hidden_dim,
        delta_t=0.1,              # discretization step
        prenorm=True,             # layer norm before SSM
        residual_scale=1.0,       # residual connection scaling
        dropout=0.0,              # no dropout
    ).to(device)

    # Count parameters
    total_params = sum(p.numel() for p in block.parameters() if p.requires_grad)
    print(f"Trainable parameters: {total_params}\n")

    # Generate sine wave data
    print("Generating sine wave data...")
    x = generate_sine_wave(seq_len=seq_len, freq=freq, batch_size=batch_size).to(device)
    print(f"Input shape: {x.shape}\n")

    # Forward pass. no_grad() skips building the autograd graph, since this
    # pass only inspects output shapes.
    print("Running forward pass...")
    with torch.no_grad():
        output, final_state = block(x)

    print(f"Output shape:     {output.shape}")
    print(f"Final state shape: {final_state.shape}")
    print()

    # Gradient check: use a fresh input that requires grad, reduce the output
    # to a scalar and backpropagate once.
    print("Testing gradient flow...")
    x_train = generate_sine_wave(seq_len=seq_len, freq=freq, batch_size=batch_size).to(device)
    x_train.requires_grad_()  # leaf tensor, so backward() populates x_train.grad

    output, _ = block(x_train)
    loss = output.mean()  # arbitrary scalar; only used to drive backward()
    loss.backward()

    print(f"Loss: {loss.item():.6f}")
    print(f"Input gradient exists: {x_train.grad is not None}")
    print(f"Model has gradients: {any(p.grad is not None for p in block.parameters())}")
    print()

    print("=" * 60)
    print("Example complete! ✓")
    print("=" * 60)
    print("\nNext steps:")
    print("1. Modify block hyperparameters (d_model, hidden_dim, prenorm, etc.)")
    print("2. Train with loss.backward() and optimizer.step() in a loop")
    print("3. Stack multiple blocks for deeper models")
    print("4. Use .state_dict() / .load_state_dict() for model saving")


# Script entry point: run the demo only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()