"""
Architectural Analysis for RQ1 - Architectural Interpretability

Purpose: Extract and format raw architectural signals for transparency visualization
Focus: Internal mechanisms (NOT post-hoc feature attribution)

Key differences from post-hoc attribution methods (e.g., SHAP):
- Preserves per-head, per-layer granularity (no aggregation)
- Captures activation patterns and confidence metrics
- Supports causal intervention (ablation)
- Real-time architectural transparency

Based on PhD proposal RQ1:
"Transform opaque architectural mechanisms into interpretable visual representations"
"""

import torch
import numpy as np
from typing import Dict, List, Optional, Any
import logging

logger = logging.getLogger(__name__)


def compute_head_entropy(attention_weights: torch.Tensor) -> float:
    """
    Compute entropy of attention distribution for a single head.

    High entropy = diffuse attention (many tokens attended equally)
    Low entropy = focused attention (few tokens dominate)

    Args:
        attention_weights: [seq_len, seq_len] attention matrix for one head

    Returns:
        Entropy value (bits)
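
    Example (a doctest-style sketch; uniform attention over 4 keys gives
    log2(4) = 2 bits):

        >>> uniform = torch.full((4, 4), 0.25)
        >>> round(compute_head_entropy(uniform), 2)
        2.0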
    """
    # Average across query positions to get distribution
    avg_dist = attention_weights.mean(dim=0)

    # Add small epsilon to avoid log(0)
    eps = 1e-10
    avg_dist = avg_dist + eps

    # Compute entropy: -sum(p * log(p))
    entropy = -(avg_dist * torch.log2(avg_dist)).sum().item()

    # Ensure finite value
    entropy = float(np.clip(entropy, 0.0, 1e10))
    if not np.isfinite(entropy):
        entropy = 0.0

    return entropy


def identify_head_role(attention_weights: torch.Tensor, tokens: List[str]) -> str:
    """
    Classify attention head role based on attention patterns.

    Roles:
    - 'positional': Attends primarily to specific positions (diagonal, next-token, etc.)
    - 'delimiter': Focuses on delimiters/special tokens (braces, semicolons, etc.)
    - 'content': Attends to semantic content tokens (identifiers, keywords)
    - 'mixed': No clear specialization

    Args:
        attention_weights: [seq_len, seq_len]
        tokens: List of token strings

    Returns:
        Role classification string
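
    Example (a sketch; an identity matrix attends only to its own position,
    so the diagonal heuristic fires):

        >>> identify_head_role(torch.eye(4), ['a', 'b', 'c', 'd'])
        'positional'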
    """
    # Compute statistics
    diagonal_strength = torch.diag(attention_weights).mean().item()
    max_weight = attention_weights.max().item()

    # Simple heuristics (can be refined with more research)
    if diagonal_strength > 0.3:
        return 'positional'

    # Check if attends primarily to delimiters
    delimiter_tokens = {'{', '}', '(', ')', '[', ']', ';', ',', ':'}
    delimiter_indices = [i for i, tok in enumerate(tokens) if tok in delimiter_tokens]

    if delimiter_indices:
        delimiter_attention = attention_weights[:, delimiter_indices].mean().item()
        if delimiter_attention > 0.3:
            return 'delimiter'

    # Check for focused content attention
    if max_weight > 0.5:
        return 'content'

    return 'mixed'


def extract_per_head_attention(
    attention_tensor: torch.Tensor,
    layer_idx: int,
    tokens: List[str]
) -> List[Dict[str, Any]]:
    """
    Extract per-head attention data for a specific layer.

    Args:
        attention_tensor: [num_heads, seq_len, seq_len]
        layer_idx: Layer index (currently unused in the computation)
        tokens: Token strings

    Returns:
        List of dicts, one per head
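
    Example (a sketch with random softmax-normalized attention; only the
    result structure is deterministic):

        >>> attn = torch.softmax(torch.randn(2, 4, 4), dim=-1)
        >>> heads = extract_per_head_attention(attn, layer_idx=0,
        ...                                    tokens=['a', 'b', 'c', 'd'])
        >>> sorted(heads[0].keys())
        ['attention_matrix', 'entropy', 'head_idx', 'max_weight', 'role']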
    """
    num_heads = attention_tensor.shape[0]
    heads_data = []

    for head_idx in range(num_heads):
        head_attn = attention_tensor[head_idx]  # [seq_len, seq_len]

        # Clean attention matrix - replace NaN/Inf with 0
        head_attn_np = head_attn.cpu().numpy()
        head_attn_np = np.nan_to_num(head_attn_np, nan=0.0, posinf=1.0, neginf=0.0)
        head_attn_np = np.clip(head_attn_np, 0.0, 1.0)

        # Convert the cleaned matrix back to a tensor for entropy/role calculations
        head_attn_clean = torch.from_numpy(head_attn_np)

        entropy = compute_head_entropy(head_attn_clean)
        max_weight = float(head_attn_np.max())
        if not np.isfinite(max_weight):
            max_weight = 0.0

        role = identify_head_role(head_attn_clean, tokens)

        heads_data.append({
            "head_idx": head_idx,
            "attention_matrix": head_attn_np.tolist(),
            "entropy": entropy,
            "max_weight": max_weight,
            "role": role
        })

    return heads_data


def compute_activation_metrics(
    hidden_states: torch.Tensor,
    prev_hidden_states: Optional[torch.Tensor] = None
) -> Dict[str, float]:
    """
    Compute activation-related metrics for a layer.

    Args:
        hidden_states: [seq_len, hidden_dim] output of layer
        prev_hidden_states: Previous layer hidden states (for drift computation)

    Returns:
        Dict with activation magnitude, entropy, norm, drift
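
    Example (a sketch; drift is None when no previous layer is given):

        >>> metrics = compute_activation_metrics(torch.randn(8, 16))
        >>> metrics['hidden_state_drift'] is None
        True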
    """
    # Activation magnitude: L2 norm averaged across sequence
    activation_magnitude = torch.norm(hidden_states, dim=-1).mean().item()
    activation_magnitude = float(np.clip(activation_magnitude, -1e10, 1e10))
    if not np.isfinite(activation_magnitude):
        activation_magnitude = 0.0

    # Activation entropy: How varied are the activations?
    flat_activations = hidden_states.flatten()
    # Normalize to probability distribution
    probs = torch.softmax(flat_activations, dim=0)
    activation_entropy = -(probs * torch.log2(probs + 1e-10)).sum().item()
    activation_entropy = float(np.clip(activation_entropy, 0.0, 1e10))
    if not np.isfinite(activation_entropy):
        activation_entropy = 0.0

    # Hidden state norm
    hidden_state_norm = torch.norm(hidden_states).item()
    hidden_state_norm = float(np.clip(hidden_state_norm, -1e10, 1e10))
    if not np.isfinite(hidden_state_norm):
        hidden_state_norm = 0.0

    # Hidden state drift (if previous layer available)
    hidden_state_drift = None
    if prev_hidden_states is not None:
        drift = torch.norm(hidden_states - prev_hidden_states).item()
        drift = float(np.clip(drift, -1e10, 1e10))
        if np.isfinite(drift):
            hidden_state_drift = drift

    return {
        "activation_magnitude": activation_magnitude,
        "activation_entropy": activation_entropy,
        "hidden_state_norm": hidden_state_norm,
        "hidden_state_drift": hidden_state_drift
    }


def extract_architectural_data(
    model_outputs: Dict[str, Any],
    input_tokens: List[str],
    output_tokens: List[str],
    model_config: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
    """
    Extract complete architectural transparency data for visualization.

    This is the main function that formats all data needed for
    ArchitecturalAttentionExplorer component.

    Args:
        model_outputs: Dict containing 'attentions', 'hidden_states', etc.
        input_tokens: Input token strings
        output_tokens: Generated token strings
        model_config: Model configuration (num_layers, num_heads, etc.)

    Returns:
        Complete architectural data dict, or None if attention weights are missing
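
    Example (a sketch assuming a Hugging Face-style model called with
    output_attentions=True and output_hidden_states=True; the model,
    tokenizer, and config attribute names are illustrative assumptions):

        outputs = model(input_ids, output_attentions=True,
                        output_hidden_states=True)
        data = extract_architectural_data(
            {"attentions": outputs.attentions,
             "hidden_states": outputs.hidden_states},
            input_tokens=tokenizer.convert_ids_to_tokens(input_ids[0].tolist()),
            output_tokens=[],
            model_config={"num_heads": model.config.num_attention_heads},
        )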
    """
    # Extract attention from model outputs
    # Expected shape: attentions is tuple of [batch, num_heads, seq_len, seq_len]
    attentions = model_outputs.get('attentions', None)
    hidden_states = model_outputs.get('hidden_states', None)

    if attentions is None:
        logger.warning("No attention weights in model outputs")
        return None

    # Process each layer
    layers_data = []
    prev_hidden = None

    num_layers = len(attentions)

    for layer_idx in range(num_layers):
        layer_attn = attentions[layer_idx]  # [batch, num_heads, seq_len, seq_len]

        # Remove batch dimension (assuming batch_size=1)
        if layer_attn.dim() == 4:
            layer_attn = layer_attn[0]  # [num_heads, seq_len, seq_len]

        # Extract per-head attention
        all_tokens = input_tokens + output_tokens
        heads_data = extract_per_head_attention(layer_attn, layer_idx, all_tokens)

        # Compute activation metrics
        activation_metrics = {"activation_magnitude": 0.0, "activation_entropy": 0.0, "hidden_state_norm": 0.0}

        if hidden_states is not None:
            # HF-style hidden_states tuples include the embedding output at
            # index 0, so layer i's output is at index i + 1 when the tuple
            # has num_layers + 1 entries; otherwise fall back to index i.
            hs_idx = layer_idx + 1 if len(hidden_states) == num_layers + 1 else layer_idx
            if hs_idx < len(hidden_states):
                current_hidden = hidden_states[hs_idx]
                if current_hidden.dim() == 3:  # [batch, seq_len, hidden_dim]
                    current_hidden = current_hidden[0]  # Remove batch

                activation_metrics = compute_activation_metrics(current_hidden, prev_hidden)
                prev_hidden = current_hidden

        # Combine data for this layer
        layer_data = {
            "layer_idx": layer_idx,
            "attention_heads": heads_data,
            **activation_metrics
        }

        layers_data.append(layer_data)

    # Build complete response
    architectural_data = {
        "layers": layers_data,
        "model_info": {
            "num_layers": num_layers,
            "num_heads": model_config.get('num_heads', len(heads_data)),
            "hidden_size": model_config.get('hidden_size', 768),
            "model_name": model_config.get('model_name', 'unknown')
        },
        "input_tokens": input_tokens,
        "output_tokens": output_tokens
    }

    # Optional: Expert routing (for MoE models)
    expert_routing = model_outputs.get('router_logits', None)
    if expert_routing is not None:
        architectural_data["expert_routing"] = extract_expert_routing(expert_routing)

    return architectural_data


def extract_expert_routing(router_logits: torch.Tensor) -> List[Dict[str, Any]]:
    """
    Extract expert routing decisions for MoE models.

    Args:
        router_logits: Router logits from model
            Shape depends on model architecture

    Returns:
        List of routing decisions per layer/token
    """
    # Routing extraction is model-specific and needs to be adapted per
    # architecture (e.g., DeepSeek-MoE, Mixtral).

    # Placeholder implementation
    routing_data = []

    logger.info("Expert routing extraction not yet implemented for this model")

    return routing_data
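

# A minimal sketch of what top-k routing extraction could look like once the
# router output format is pinned down. It assumes router_logits is a single
# [seq_len, num_experts] tensor and keeps the top_k experts per token; both
# the shape and top_k=2 are assumptions, not any specific model's API.
def _example_topk_routing(router_logits: torch.Tensor, top_k: int = 2) -> List[Dict[str, Any]]:
    # Softmax over the expert dimension, then keep the k strongest experts.
    probs = torch.softmax(router_logits, dim=-1)
    top_probs, top_experts = probs.topk(top_k, dim=-1)
    return [
        {
            "token_idx": i,
            "experts": top_experts[i].tolist(),
            "weights": top_probs[i].tolist(),
        }
        for i in range(router_logits.shape[0])
    ]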


def format_for_study_endpoint(
    architectural_data: Dict[str, Any],
    generation_metadata: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Format architectural data for /api/study/analyze endpoint response.

    Args:
        architectural_data: Output from extract_architectural_data()
        generation_metadata: Generation stats (time, tokens, etc.)

    Returns:
        Complete response dict
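
    Example (a sketch; the metadata keys are whatever the caller collected):

        >>> resp = format_for_study_endpoint({"layers": []},
        ...                                  {"generation_time_s": 1.2})
        >>> resp["visualization_type"]
        'architectural_transparency'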
    """
    return {
        "architectural_data": architectural_data,
        "metadata": generation_metadata,
        "visualization_type": "architectural_transparency",
        "research_context": "RQ1: Architectural Interpretability"
    }
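

if __name__ == "__main__":
    # Smoke-test sketch on synthetic tensors, so no real model is required:
    # 2 layers, 4 heads, 5 tokens, hidden size 16. The values are random and
    # the output is illustrative only.
    seq_len, num_heads, hidden = 5, 4, 16
    fake_outputs = {
        "attentions": tuple(
            torch.softmax(torch.randn(1, num_heads, seq_len, seq_len), dim=-1)
            for _ in range(2)
        ),
        "hidden_states": tuple(torch.randn(1, seq_len, hidden) for _ in range(3)),
    }
    tokens = ["def", "foo", "(", ")", ":"]
    data = extract_architectural_data(fake_outputs, tokens, [], {"num_heads": num_heads})
    print(f"Extracted {len(data['layers'])} layers with "
          f"{len(data['layers'][0]['attention_heads'])} heads each")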