File size: 8,197 Bytes
4757a21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List, Tuple, Dict, Any
import math


class AdvancedBPETokenizer:
    """Byte-level tokenizer with a BPE-style vocabulary and few-shot markers. 🤓

    The vocabulary is assembled from 256 byte tokens, a small set of common
    English subwords, and filler tokens; the last four ids are reserved for
    the few-shot special tokens (``<|support|>``, ``<|query|>``,
    ``<|adapt|>``, ``<|eos|>``).

    NOTE(review): ``encode`` currently byte-encodes the special-token
    *strings* rather than emitting their reserved ids — callers that need
    the reserved ids must map them explicitly. Preserved as-is here.
    """

    def __init__(self, vocab_size: int = 32000):
        """Build the vocabulary and the id<->token lookup tables.

        Args:
            vocab_size: Total vocabulary size, including the 4 reserved
                special tokens.
        """
        self.vocab_size = vocab_size
        self.vocab = self._build_advanced_vocab()
        # token string -> id and id -> token string lookups
        self.encode_dict = {tok: idx for idx, tok in enumerate(self.vocab)}
        self.decode_dict = dict(enumerate(self.vocab))

        # Special tokens occupy the last 4 ids of the vocabulary.
        self.special_tokens = {
            '<|support|>': vocab_size - 4,
            '<|query|>': vocab_size - 3,
            '<|adapt|>': vocab_size - 2,
            '<|eos|>': vocab_size - 1
        }

    def _build_advanced_vocab(self) -> List[str]:
        """Build the non-special part of the vocabulary (vocab_size - 4 entries)."""
        # Byte-level tokens cover every possible UTF-8 byte value.
        vocab = [f"<|byte_{i}|>" for i in range(256)]

        # Common subwords (simplified stand-in for learned BPE merges).
        vocab.extend([
            'ing', 'ed', 'er', 'est', 'ly', 'tion', 'ment', 'ness',
            'ful', 'less', 'able', 'ible', 'pre', 'un', 're', 'de'
        ])

        # Fill the remainder with placeholder tokens; the last 4 ids are
        # reserved for the special tokens registered in __init__.
        while len(vocab) < self.vocab_size - 4:
            vocab.append(f"<|token_{len(vocab)}|>")

        return vocab[:self.vocab_size - 4]

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        """Encode text to a list of byte-level token ids.

        Args:
            text: Input string.
            add_special_tokens: If True, wrap the text with the
                ``<|support|>``/``<|eos|>`` marker strings before encoding.

        Returns:
            One id (0-255) per UTF-8 byte of the (possibly wrapped) text.
        """
        if add_special_tokens:
            text = '<|support|>' + text + '<|eos|>'

        # Iterating over a bytes object yields ints in 0..255, so every
        # byte maps directly onto a byte-token id. (The original carried
        # an unreachable UNK branch for values >= 256 here.)
        return list(text.encode('utf-8'))

    def decode(self, tokens: List[int]) -> str:
        """Decode token ids back to text, dropping non-byte (special/filler) ids."""
        try:
            # Only byte-level ids carry text; ids outside 0..255 (special
            # tokens, filler, or bad values) are filtered out so bytes()
            # cannot raise on them.
            byte_ids = [t for t in tokens if 0 <= t < 256]
            return bytes(byte_ids).decode('utf-8', errors='ignore')
        except Exception:
            # Defensive fallback (e.g. non-int entries); was a bare except.
            return "".join(f"<{token}>" for token in tokens)


class ModelProfiler:
    """Inspection helpers for model size and inference speed. 🔍

    Pokes and prods a model to report parameter counts, an approximate
    memory footprint, and wall-clock inference throughput.
    """

    @staticmethod
    def get_model_stats(model) -> Dict[str, Any]:
        """Return parameter counts and descriptive metadata for *model*."""
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

        return {
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
            # 2 bytes per parameter, assuming FP16 storage.
            'model_size_mb': total_params * 2 / 1e6,
            'architecture': 'Hyper Mamba',
            'features': [
                'Meta-Learning',
                'Neuro-Symbolic',
                'Knowledge Distillation',
                'Progressive Learning',
                'Few-Shot Adaptation',
                'Continual Learning'
            ]
        }

    @staticmethod
    def benchmark_inference(model, input_ids: torch.Tensor, num_runs: int = 10):
        """Benchmark inference wall-clock time over *num_runs* forward passes.

        Args:
            model: Callable module; ``model(input_ids)`` is what gets timed.
            input_ids: 2-D tensor of shape (batch, seq_len).
            num_runs: Number of timed runs; must be >= 1.

        Returns:
            Dict with average latency (ms), token throughput, and the
            input's batch size / sequence length.

        Raises:
            ValueError: If ``num_runs`` is less than 1.
        """
        import time

        if num_runs < 1:
            # Guard against a zero-division when averaging the runs.
            raise ValueError("num_runs must be >= 1")

        model.eval()
        times = []

        # Warmup runs so one-time costs (allocations, lazy init) don't
        # skew the measurement.
        with torch.no_grad():
            for _ in range(3):
                _ = model(input_ids)

        # Timed runs. perf_counter is the monotonic, high-resolution clock
        # intended for benchmarking; time.time() (used originally) can jump
        # backwards on system clock adjustments.
        with torch.no_grad():
            for _ in range(num_runs):
                start = time.perf_counter()
                _ = model(input_ids)
                times.append(time.perf_counter() - start)

        avg_time = sum(times) / len(times)
        batch_size, seq_len = input_ids.shape

        return {
            'avg_time_ms': avg_time * 1000,
            'throughput_tokens_per_sec': batch_size * seq_len / avg_time,
            'batch_size': batch_size,
            'sequence_length': seq_len
        }


class FewShotDataLoader:
    """Data loader that sets up few-shot learning like a pro! 🎯

    Splits raw texts into support/query sets and pads them into tensors,
    ready for episodic (support -> adapt -> query) training.
    """

    def __init__(self, support_size: int = 5, query_size: int = 10):
        # Number of examples in the support (adaptation) and query
        # (evaluation) splits of each batch.
        self.support_size = support_size
        self.query_size = query_size

    def create_few_shot_batch(self, texts: List[str], tokenizer) -> Dict[str, torch.Tensor]:
        """Create a padded few-shot batch from raw texts.

        Args:
            texts: Raw examples; the first ``support_size`` become the
                support set, the next ``query_size`` the query set.
            tokenizer: Object exposing ``encode(text) -> List[int]``.

        Returns:
            Dict with 'support_set'/'query_set' integer tensors padded
            with 0 to a common length, plus the configured split sizes.
        """
        encoded = [tokenizer.encode(text) for text in texts]

        support_examples = encoded[:self.support_size]
        query_examples = encoded[self.support_size:self.support_size + self.query_size]

        # Pad everything to the longest sequence across both splits.
        # default=0 guards an empty split (e.g. len(texts) <= support_size),
        # which made the original max() raise ValueError.
        max_len = max((len(seq) for seq in support_examples + query_examples), default=0)

        def _pad(seq: List[int]) -> List[int]:
            # Right-pad with id 0 up to the common length.
            return seq + [0] * (max_len - len(seq))

        def _to_tensor(examples: List[List[int]]) -> torch.Tensor:
            if not examples:
                # Keep a well-formed 2-D tensor even for an empty split.
                return torch.empty((0, max_len), dtype=torch.long)
            return torch.tensor([_pad(seq) for seq in examples])

        return {
            'support_set': _to_tensor(support_examples),
            'query_set': _to_tensor(query_examples),
            'support_size': self.support_size,
            'query_size': self.query_size
        }


class VisualizationUtils:
    """Helper routines for visualizing and inspecting model internals."""

    @staticmethod
    def plot_attention_weights(attention_weights: torch.Tensor, tokens: List[str]):
        """Render an attention-weight heatmap; no-op if plotting libs are missing."""
        try:
            import matplotlib.pyplot as plt
            import seaborn as sns
        except ImportError:
            # Best-effort: plotting is optional, so just report and bail.
            print("Matplotlib/Seaborn not available for visualization")
            return

        plt.figure(figsize=(10, 8))
        sns.heatmap(
            attention_weights.cpu().numpy(),
            xticklabels=tokens,
            yticklabels=tokens,
            cmap='Blues',
            annot=True,
            fmt='.2f'
        )
        plt.title('Attention Weights Visualization')
        plt.xlabel('Key Tokens')
        plt.ylabel('Query Tokens')
        plt.tight_layout()
        plt.show()

    @staticmethod
    def analyze_layer_activations(model, input_ids: torch.Tensor):
        """Collect per-layer output statistics from a single forward pass.

        Hooks every module in ``model.layers``, runs ``model(input_ids)``
        under no_grad, then returns one dict per layer with 'layer',
        'mean', 'std', 'max', and 'min'.
        """
        captured = []

        def _capture(module, inputs, output):
            # Detach and move to CPU so stats don't hold GPU memory/graph.
            captured.append(output.detach().cpu())

        handles = [layer.register_forward_hook(_capture) for layer in model.layers]

        with torch.no_grad():
            model(input_ids)

        for handle in handles:
            handle.remove()

        return [
            {
                'layer': idx,
                'mean': act.mean().item(),
                'std': act.std().item(),
                'max': act.max().item(),
                'min': act.min().item()
            }
            for idx, act in enumerate(captured)
        ]


# Export all utilities: the public names this module exposes to
# `from module import *` (and documents as its supported API).
__all__ = [
    'AdvancedBPETokenizer',
    'ModelProfiler', 
    'FewShotDataLoader',
    'VisualizationUtils'
]