File size: 8,015 Bytes
2c914eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
"""
SAL Utilities Module

Helper functions for SAL operations.
Similarity measures, smoothing, seed loading.
"""

import torch
import json
from typing import Dict, Optional, Any, List, Union
from pathlib import Path


def cosine_similarity(
    a: torch.Tensor,
    b: torch.Tensor,
    dim: int = -1,
    eps: float = 1e-8,
) -> torch.Tensor:
    """
    Return the cosine similarity of two tensors along one dimension.
    
    Each input is divided by its L2 norm along ``dim`` (with ``eps`` added
    to the norm), then the element-wise product of the two scaled tensors
    is summed over ``dim``.
    
    Args:
        a: First tensor
        b: Second tensor
        dim: Dimension that is reduced when comparing
        eps: Added to each norm before dividing, so an all-zero vector
            does not cause a division by zero
        
    Returns:
        Tensor of similarities with ``dim`` reduced away
    """
    def _scaled(t: torch.Tensor) -> torch.Tensor:
        # keepdim=True keeps the norm broadcastable against ``t``.
        return t / (t.norm(dim=dim, keepdim=True) + eps)
    
    return torch.sum(_scaled(a) * _scaled(b), dim=dim)


def exponential_moving_average(
    current: torch.Tensor,
    previous: torch.Tensor,
    alpha: float = 0.1,
) -> torch.Tensor:
    """
    Blend a new observation into a running exponential moving average.
    
    The result is ``alpha * current + (1 - alpha) * previous``.
    
    Args:
        current: Newest value
        previous: EMA value from the previous step
        alpha: Weight given to ``current``; the remainder (1 - alpha)
            goes to ``previous``
        
    Returns:
        The updated EMA
    """
    fresh_part = alpha * current
    carried_part = (1 - alpha) * previous
    return fresh_part + carried_part


class EMA:
    """
    Stateful exponential moving average of a scalar stream.
    
    Holds the current smoothed value and the number of updates seen so far.
    """
    
    def __init__(self, alpha: float = 0.1, initial: Optional[float] = None):
        """
        Set up the tracker.
        
        Args:
            alpha: Weight applied to each new observation
            initial: Starting value; when None the first observation
                passed to ``update`` becomes the value as-is
        """
        self.alpha = alpha
        self.value = initial
        self.count = 0
    
    def update(self, new_value: float) -> float:
        """
        Fold one observation into the average.
        
        Args:
            new_value: Observation to incorporate
            
        Returns:
            The smoothed value after this update
        """
        self.count += 1
        
        previous = self.value
        if previous is None:
            # Nothing to blend with yet: adopt the observation directly.
            smoothed = new_value
        else:
            smoothed = self.alpha * new_value + (1 - self.alpha) * previous
        
        self.value = smoothed
        return smoothed
    
    def get(self) -> Optional[float]:
        """Return the current smoothed value (None if there is none yet)."""
        return self.value
    
    def reset(self) -> None:
        """Clear the smoothed value and zero the update counter."""
        self.value, self.count = None, 0


def load_seed(path: Union[str, Path]) -> Dict[str, Any]:
    """
    Read a semantic seed from a JSON file.
    
    The file is decoded as UTF-8 JSON. If the parsed object has an
    'embedding' entry that is a list, that entry is replaced by a
    ``torch.Tensor`` built from it.
    
    Args:
        path: Location of the seed JSON file
        
    Returns:
        The parsed seed, with a list-valued 'embedding' converted to a tensor
        
    Raises:
        FileNotFoundError: If ``path`` does not exist
    """
    seed_file = Path(path)
    
    if not seed_file.exists():
        raise FileNotFoundError(f"Seed file not found: {seed_file}")
    
    seed = json.loads(seed_file.read_text(encoding='utf-8'))
    
    # Only list-valued embeddings are converted; anything else is left alone.
    raw_embedding = seed['embedding'] if 'embedding' in seed else None
    if isinstance(raw_embedding, list):
        seed['embedding'] = torch.tensor(raw_embedding)
    
    return seed


def save_seed(
    seed: Dict[str, Any],
    path: Union[str, Path],
) -> None:
    """
    Save a semantic seed to JSON file.
    
    A tensor stored under 'embedding' is converted to a nested list on a
    shallow copy, so the caller's dictionary is not modified. The seed is
    serialized completely in memory before the destination is opened, so
    a seed that cannot be encoded as JSON does not truncate an existing
    file at ``path``.
    
    Args:
        seed: Seed dictionary
        path: Output path; missing parent directories are created
        
    Raises:
        TypeError: If the seed contains a value json cannot serialize
    """
    path = Path(path)
    
    # Convert tensor to list for JSON (on a copy, leaving the input as-is)
    seed_copy = seed.copy()
    if 'embedding' in seed_copy and isinstance(seed_copy['embedding'], torch.Tensor):
        seed_copy['embedding'] = seed_copy['embedding'].tolist()
    
    # Encode first: opening with 'w' truncates immediately, and streaming
    # json.dump into the open file would leave partial JSON behind if
    # encoding failed midway.
    payload = json.dumps(seed_copy, indent=2, ensure_ascii=False)
    
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, 'w', encoding='utf-8') as f:
        f.write(payload)


def create_seed(
    name: str,
    dimension: int = 768,
    seed_type: str = "random",
    metadata: Optional[Dict] = None,
) -> Dict[str, Any]:
    """
    Build a new semantic seed dictionary.
    
    Supported ``seed_type`` values:
        "random": standard-normal sample divided by its own norm
        "zero":   all-zero vector
        "ones":   all-ones vector divided by sqrt(dimension)
    
    Args:
        name: Seed name
        dimension: Length of the embedding vector
        seed_type: How the embedding is initialized (see above)
        metadata: Extra metadata; an empty dict is stored when this is
            None or otherwise falsy
        
    Returns:
        Dict with keys 'name', 'dimension', 'type', 'embedding', 'metadata'
        
    Raises:
        ValueError: If ``seed_type`` is not one of the supported values
    """
    def _initial_embedding() -> torch.Tensor:
        if seed_type == "random":
            sample = torch.randn(dimension)
            return sample / sample.norm()  # Normalize
        if seed_type == "zero":
            return torch.zeros(dimension)
        if seed_type == "ones":
            return torch.ones(dimension) / (dimension ** 0.5)
        raise ValueError(f"Unknown seed type: {seed_type}")
    
    return {
        'name': name,
        'dimension': dimension,
        'type': seed_type,
        'embedding': _initial_embedding(),
        'metadata': metadata or {},
    }


def weight_distance(
    weights1: Dict[str, torch.Tensor],
    weights2: Dict[str, torch.Tensor],
    metric: str = "l2",
) -> float:
    """
    Compute distance between two sets of weights.
    
    Each tensor present under the same name in both dictionaries is
    flattened and cast to float, a per-tensor distance is computed, and
    the per-tensor distances are averaged. Names missing from ``weights2``
    and pairs whose flattened sizes differ are skipped.
    
    Args:
        weights1: First weight dictionary
        weights2: Second weight dictionary
        metric: Distance metric ("l2", "l1", "cosine")
        
    Returns:
        Mean per-tensor distance, or 0.0 when no pair could be compared
        
    Raises:
        ValueError: If ``metric`` is not a supported name. This is checked
            before any tensor is visited, so an invalid metric is reported
            even when there is nothing to compare.
    """
    # Validate once up front; checking only inside the loop would let a
    # bad metric slip through silently when no pair is comparable.
    if metric not in ("l2", "l1", "cosine"):
        raise ValueError(f"Unknown metric: {metric}")
    
    total_distance = 0.0
    count = 0
    
    for name, tensor1 in weights1.items():
        if name not in weights2:
            continue
        
        w1 = tensor1.flatten().float()
        w2 = weights2[name].flatten().float()
        
        # After flattening, this compares element counts.
        if w1.shape != w2.shape:
            continue
        
        if metric == "l2":
            dist = torch.norm(w1 - w2).item()
        elif metric == "l1":
            dist = torch.abs(w1 - w2).sum().item()
        else:  # "cosine"
            cos_sim = cosine_similarity(w1.unsqueeze(0), w2.unsqueeze(0)).item()
            dist = 1.0 - cos_sim
        
        total_distance += dist
        count += 1
    
    if count == 0:
        return 0.0
    
    return total_distance / count


def gradient_norm(model: torch.nn.Module) -> float:
    """
    Return the overall L2 norm of the gradients held by a model.
    
    The squared L2 norm of each parameter's gradient is accumulated and
    the square root of the total is returned. Parameters whose ``grad``
    is None are ignored.
    
    Args:
        model: Neural network
        
    Returns:
        Combined gradient L2 norm (0.0 when no parameter has a gradient)
    """
    squared_total = sum(
        (
            p.grad.data.norm(2).item() ** 2
            for p in model.parameters()
            if p.grad is not None
        ),
        0.0,
    )
    
    return squared_total ** 0.5


def parameter_count(
    model: torch.nn.Module,
    trainable_only: bool = True,
) -> int:
    """
    Return the number of scalar parameters in a model.
    
    Args:
        model: Neural network
        trainable_only: When true, only parameters with
            ``requires_grad`` set are counted; otherwise all are counted
        
    Returns:
        Total element count over the selected parameters
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )


def stability_summary(stability_scores: Dict[str, float]) -> Dict[str, float]:
    """
    Produce summary statistics for a set of stability scores.
    
    Scores above 0.7 are counted as protected, scores below 0.3 as
    volatile, and the remaining percentage is reported as neutral. The
    standard deviation is the population form (divides by n).
    
    Args:
        stability_scores: Mapping of parameter names to scores
        
    Returns:
        Dict with 'mean', 'std', 'min', 'max' and the three percentage
        fields 'protected_pct', 'neutral_pct', 'volatile_pct'; every
        field is 0.0 when the input is empty
    """
    if not stability_scores:
        fields = (
            'mean',
            'std',
            'min',
            'max',
            'protected_pct',
            'neutral_pct',
            'volatile_pct',
        )
        return dict.fromkeys(fields, 0.0)
    
    values = list(stability_scores.values())
    total = len(values)
    
    average = sum(values) / total
    squared_deviations = [(v - average) ** 2 for v in values]
    spread = (sum(squared_deviations) / total) ** 0.5
    
    high_count = len([v for v in values if v > 0.7])
    low_count = len([v for v in values if v < 0.3])
    protected_share = high_count / total * 100
    volatile_share = low_count / total * 100
    
    return {
        'mean': average,
        'std': spread,
        'min': min(values),
        'max': max(values),
        'protected_pct': protected_share,
        # Neutral is whatever is left after the two tails.
        'neutral_pct': 100 - protected_share - volatile_share,
        'volatile_pct': volatile_share,
    }