File size: 8,400 Bytes
fd8c8b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
"""

Example: Using BitLinear as a drop-in replacement for nn.Linear in a Transformer.

This example demonstrates:
    1. Creating a simple Transformer block with standard nn.Linear layers
    2. Converting it to use BitLinear layers
    3. Running forward passes to verify compatibility
    4. Comparing output similarity and memory usage

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

from bitlinear import BitLinear, MultiTernaryLinear, convert_linear_to_bitlinear


class TransformerBlock(nn.Module):
    """
    Simplified pre-norm Transformer block for demonstration.

    Contains:
        - Multi-head self-attention with separate Q/K/V/output linear projections
        - Feed-forward network: Linear -> ReLU -> Dropout -> Linear
        - Layer normalization (applied before each sub-layer) and residual connections

    Args:
        d_model: Feature dimension of the input and output.
        nhead: Number of attention heads; must be positive and divide ``d_model``.
        dim_feedforward: Hidden width of the feed-forward network.
        dropout: Dropout probability used inside the FFN and after each sub-layer.

    Raises:
        ValueError: If ``nhead`` is not positive or does not divide ``d_model``.
    """

    def __init__(
        self,
        d_model: int = 512,
        nhead: int = 8,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ):
        super().__init__()

        # forward() splits d_model into exactly nhead chunks of size d_k via
        # .view(); validate here so a bad config fails with a clear message
        # instead of a shape error (or ZeroDivisionError) later.
        if nhead <= 0 or d_model % nhead != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive nhead ({nhead})"
            )

        # Multi-head attention components
        self.d_model = d_model
        self.nhead = nhead
        self.d_k = d_model // nhead

        # Linear projections for Q, K, V and the attention output
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
        )

        # Layer normalization (pre-norm: applied to each sub-layer's input)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout applied to each sub-layer's output before the residual add
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass through the Transformer block.

        Args:
            x: Input tensor of shape [batch_size, seq_len, d_model].
            mask: Optional attention mask, broadcastable to the score tensor
                of shape [batch_size, nhead, seq_len, seq_len]. Positions
                where the mask equals 0 are excluded from attention.

        Returns:
            Output tensor of shape [batch_size, seq_len, d_model].
        """
        batch_size, seq_len, _ = x.shape

        # --- Multi-head self-attention sub-layer ---
        residual = x
        x = self.norm1(x)

        # Project, then split d_model into (nhead, d_k) and move heads to dim 1:
        # [B, L, d_model] -> [B, nhead, L, d_k]
        q = self.q_proj(x).view(batch_size, seq_len, self.nhead, self.d_k).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, self.nhead, self.d_k).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, self.nhead, self.d_k).transpose(1, 2)

        # Scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            # Use the most negative finite value of the score dtype instead of a
            # hard-coded -1e9, which is not representable in float16 and makes
            # masked_fill raise for half-precision scores.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)

        # Merge heads back: [B, nhead, L, d_k] -> [B, L, d_model]
        attn_output = attn_output.transpose(1, 2).reshape(
            batch_size, seq_len, self.d_model
        )
        attn_output = self.dropout1(self.out_proj(attn_output))

        # First residual connection
        x = residual + attn_output

        # --- Feed-forward sub-layer ---
        residual = x
        x = self.dropout2(self.ffn(self.norm2(x)))

        # Second residual connection
        return residual + x


def count_parameters(model: nn.Module) -> int:
    """
    Return the number of trainable parameter elements in ``model``.

    Only parameters with ``requires_grad=True`` are counted; each contributes
    its element count (``numel()``), not 1.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue  # frozen parameters are not "trainable"
        total += param.numel()
    return total


def estimate_memory_mb(model: nn.Module) -> float:
    """
    Estimate the storage taken by ``model``'s parameters, in MB (1 MB = 1024**2 bytes).

    Each parameter contributes ``numel() * element_size()`` bytes. Only
    ``model.parameters()`` is inspected; registered buffers are not counted.
    """
    bytes_per_mb = 1024 ** 2
    total_bytes = 0
    for param in model.parameters():
        total_bytes += param.numel() * param.element_size()
    return total_bytes / bytes_per_mb


def compare_outputs(
    output1: torch.Tensor,
    output2: torch.Tensor,
) -> dict:
    """
    Compare two output tensors and compute similarity metrics.

    ``output1`` is treated as the reference: the relative error is measured
    against its norm.

    Args:
        output1: Reference tensor.
        output2: Tensor to compare; must have the same shape as ``output1``.

    Returns:
        Dictionary with keys:
            - "mse": mean squared error between the two tensors.
            - "cosine_similarity": cosine similarity of the flattened tensors.
            - "relative_error": ||output1 - output2|| / ||output1|| (2-norm over
              all elements). 0.0 when both tensors are all zeros; ``inf`` when
              only the reference is all zeros.

    Raises:
        ValueError: If the tensors have different shapes (``mse_loss`` would
            otherwise silently broadcast them).
    """
    if output1.shape != output2.shape:
        raise ValueError(
            f"Shape mismatch: {tuple(output1.shape)} vs {tuple(output2.shape)}"
        )

    mse = F.mse_loss(output1, output2).item()
    cosine_sim = F.cosine_similarity(
        output1.flatten(), output2.flatten(), dim=0
    ).item()

    diff_norm = torch.norm(output1 - output2).item()
    ref_norm = torch.norm(output1).item()
    if ref_norm == 0.0:
        # Avoid 0/0 -> nan: identical all-zero tensors have no error, and any
        # deviation from an all-zero reference is unboundedly large.
        relative_error = 0.0 if diff_norm == 0.0 else float("inf")
    else:
        relative_error = diff_norm / ref_norm

    return {
        "mse": mse,
        "cosine_similarity": cosine_sim,
        "relative_error": relative_error,
    }


def main():
    """
    Run the BitLinear-vs-nn.Linear Transformer demo and print a report.

    Builds a standard TransformerBlock, converts a copy of it with
    ``convert_linear_to_bitlinear(..., inplace=False)``, runs both on the same
    random input in eval mode, and prints output-similarity metrics,
    parameter memory estimates, and converted-layer counts.
    """
    print("=" * 80)
    print("BitLinear Transformer Example")
    print("=" * 80)

    # Fixed seed so the random weights, the input, and therefore the printed
    # metrics are reproducible from run to run.
    torch.manual_seed(0)

    # Configuration
    batch_size = 32
    seq_len = 128
    d_model = 512
    nhead = 8
    dim_feedforward = 2048

    # Create input
    x = torch.randn(batch_size, seq_len, d_model)
    print(f"\nInput shape: {x.shape}")

    # 1. Create standard Transformer block
    print("\n" + "-" * 80)
    print("1. Standard Transformer with nn.Linear")
    print("-" * 80)

    model_standard = TransformerBlock(
        d_model=d_model,
        nhead=nhead,
        dim_feedforward=dim_feedforward,
    )
    # A freshly built nn.Module is in training mode, so its Dropout layers are
    # active; torch.no_grad() does NOT disable them. Without eval(), the output
    # comparison below would mostly measure differing random dropout masks.
    model_standard.eval()

    mem_standard = estimate_memory_mb(model_standard)
    print(f"Parameters: {count_parameters(model_standard):,}")
    print(f"Memory: {mem_standard:.2f} MB")

    # Forward pass
    with torch.no_grad():
        output_standard = model_standard(x)
    print(f"Output shape: {output_standard.shape}")

    # 2. Convert to BitLinear
    print("\n" + "-" * 80)
    print("2. Transformer with BitLinear")
    print("-" * 80)

    model_bitlinear = convert_linear_to_bitlinear(model_standard, inplace=False)
    # Set eval mode explicitly rather than relying on the converter to
    # preserve the source model's training flag.
    model_bitlinear.eval()

    mem_bitlinear = estimate_memory_mb(model_bitlinear)
    print(f"Parameters: {count_parameters(model_bitlinear):,}")
    print(f"Memory: {mem_bitlinear:.2f} MB")

    # Forward pass
    with torch.no_grad():
        output_bitlinear = model_bitlinear(x)
    print(f"Output shape: {output_bitlinear.shape}")

    # 3. Compare outputs
    print("\n" + "-" * 80)
    print("3. Output Comparison")
    print("-" * 80)

    metrics = compare_outputs(output_standard, output_bitlinear)
    print(f"MSE: {metrics['mse']:.6f}")
    print(f"Cosine similarity: {metrics['cosine_similarity']:.6f}")
    print(f"Relative error: {metrics['relative_error']:.6f}")

    # 4. Memory savings (reuses the estimates computed above)
    print("\n" + "-" * 80)
    print("4. Memory Savings")
    print("-" * 80)

    savings = (mem_standard - mem_bitlinear) / mem_standard * 100

    print(f"Standard model: {mem_standard:.2f} MB")
    print(f"BitLinear model: {mem_bitlinear:.2f} MB")
    print(f"Memory savings: {savings:.1f}%")
    print(f"Compression ratio: {mem_standard / mem_bitlinear:.1f}x")

    # 5. Count Linear layers converted
    print("\n" + "-" * 80)
    print("5. Conversion Details")
    print("-" * 80)

    num_linear = sum(isinstance(m, nn.Linear) for m in model_standard.modules())
    num_bitlinear = sum(isinstance(m, BitLinear) for m in model_bitlinear.modules())

    print(f"Original Linear layers: {num_linear}")
    print(f"Converted BitLinear layers: {num_bitlinear}")

    print("\n" + "=" * 80)
    print("Example complete!")
    print("=" * 80)
    # NOTE(review): the takeaways below are static text, not derived from the
    # measurements printed above — confirm they match actual results.
    print("\nKey Takeaways:")
    print("- BitLinear is a drop-in replacement for nn.Linear")
    print("- Significant memory savings (~20x for weights)")
    print("- Output similarity is high (cosine sim > 0.99 typically)")
    print("- Slight accuracy trade-off due to ternary quantization")


if __name__ == "__main__":
    main()