File size: 18,957 Bytes
77bcbf1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
"""
Cascade Core - Symbiotic Adapter.

The heart of Cascade's system-agnostic design. The adapter uses Kleene fixed-point
convergence to interpret ANY signal format and convert it to Events.

"It doesn't hook into your system — it becomes part of it."
"""

import time
import json
import re
from typing import Any, Dict, List, Optional, Callable, Type
from dataclasses import dataclass

from cascade.core.event import Event


@dataclass
class SignalPattern:
    """A learned pattern for interpreting signals.

    Created by ``SymbioticAdapter._learn_pattern`` after a signal has been
    interpreted once, then cached so structurally similar signals can be
    mapped to the same component/event_type without re-analysis.
    """
    # Name of the Python type the signal arrived as.
    pattern_type: str  # 'dict', 'string', 'tensor', 'protobuf', etc.
    # Component assigned when the pattern was first learned.
    component: str
    # Event type assigned when the pattern was first learned.
    event_type: str
    # Optional callable that pulls a data payload out of a raw signal.
    # NOTE(review): nothing in this file ever sets or invokes it — confirm
    # other modules use it before removing.
    extractor: Optional[Callable[[Any], Dict[str, Any]]] = None
    # Interpretation confidence in [0, 1]. Defaults to 0.0 here;
    # _learn_pattern creates patterns with 0.5.
    confidence: float = 0.0
    # Number of signals this pattern has matched so far.
    match_count: int = 0


class SymbioticAdapter:
    """
    Self-interpreting adapter that converges to any signal format.
    
    The adapter observes signals from the host system and learns how to
    interpret them through fixed-point iteration. It starts with naive
    interpretations and refines them until stable.
    
    This is the key to Cascade's system-agnostic design:
    - No framework-specific hooks required
    - No configuration needed
    - Feed it ANY signal format, it adapts
    
    Example:
        >>> adapter = SymbioticAdapter()
        >>> 
        >>> # Feed it different signal formats
        >>> adapter.interpret({"loss": 0.5, "epoch": 10})
        >>> adapter.interpret("2024-01-01 12:00:00 ERROR training failed")
        >>> adapter.interpret(torch.tensor([0.1, 0.2, 0.3]))
        >>> 
        >>> # It learns patterns and gets better at interpretation
        >>> print(adapter.learned_patterns)
    """
    
    def __init__(self):
        """Initialize the symbiotic adapter."""
        # Cross-module names in annotations are quoted (PEP 484 forward
        # references) so the class does not require them at definition time.
        self._patterns: List["SignalPattern"] = []
        self._signal_count = 0
        # Maps structural cache keys (see _get_cache_key) to learned patterns.
        self._interpretation_cache: Dict[str, "SignalPattern"] = {}
        
        # Built-in interpreters for common formats
        self._builtin_interpreters = {
            dict: self._interpret_dict,
            str: self._interpret_string,
            list: self._interpret_list,
        }
        
        # Regex patterns for log line parsing
        self._log_patterns = [
            # ISO timestamp with level: "2024-01-01 12:00:00 ERROR message"
            re.compile(r'^(\d{4}-\d{2}-\d{2}[T\s]\d{2}:\d{2}:\d{2}(?:\.\d+)?)\s+(\w+)\s+(.*)$'),
            # Simple timestamp: "12:00:00.123 component message"
            re.compile(r'^(\d{2}:\d{2}:\d{2}(?:\.\d+)?)\s+(\w+)\s+(.*)$'),
            # Pipe-delimited: "timestamp|level|component|key:value"
            re.compile(r'^([^|]+)\|(\w+)\|(\w+)\|(.*)$'),
        ]
        
        # Metric extraction patterns - ONLY extract real training metrics
        # Be strict to avoid extracting garbage from config lines
        # NOTE(review): _extract_metrics currently uses its own inline
        # regexes and never consults this list — confirm external users
        # before removing it.
        self._metric_patterns = [
            # Standard training metrics with = or : 
            re.compile(r'\b(loss|val_loss|train_loss|accuracy|acc|val_acc|lr|learning_rate|epoch|step|iter|iteration|mfu|tokens_per_sec|samples_per_sec|grad_norm|perplexity|ppl)[=:]\s*([+-]?\d+\.?\d*(?:e[+-]?\d+)?)', re.I),
            # "iter X: loss=Y" format from nanoGPT
            re.compile(r'iter\s+(\d+).*loss[=:]?\s*([+-]?\d+\.?\d*)', re.I),
            # "step X loss Y" format
            re.compile(r'step\s+(\d+).*loss\s*[=:]?\s*([+-]?\d+\.?\d*)', re.I),
        ]
    
    def interpret(self, signal: Any) -> "Event":
        """
        Interpret any signal into a Cascade Event.
        
        Uses Kleene fixed-point iteration to converge on the best interpretation.
        
        Args:
            signal: Any signal from the host system
            
        Returns:
            Event: The interpreted event
        """
        self._signal_count += 1
        
        # Get signal type
        signal_type = type(signal)
        
        # Try cached pattern first
        cache_key = self._get_cache_key(signal)
        if cache_key in self._interpretation_cache:
            pattern = self._interpretation_cache[cache_key]
            pattern.match_count += 1
            return self._apply_pattern(signal, pattern)
        
        # Try built-in interpreter
        if signal_type in self._builtin_interpreters:
            event = self._builtin_interpreters[signal_type](signal)
            self._learn_pattern(signal, event)
            return event
        
        # Try tensor-like objects (duck typing)
        if hasattr(signal, 'numpy') or hasattr(signal, 'detach'):
            event = self._interpret_tensor(signal)
            self._learn_pattern(signal, event)
            return event
        
        # Try protobuf-like objects
        if hasattr(signal, 'SerializeToString'):
            event = self._interpret_protobuf(signal)
            self._learn_pattern(signal, event)
            return event
        
        # Fallback: convert to string and interpret (no pattern is learned
        # for this path; stringified reprs are too unstable to cache)
        event = self._interpret_string(str(signal))
        return event
    
    @staticmethod
    def _safe_float(token: str) -> Optional[float]:
        """Convert a regex-captured token to float, or None if malformed.

        The loose character classes used for metric capture (e.g.
        ``[\\d.e+-]+``) can match degenerate tokens such as '-', '.', or
        'e'; float() would raise ValueError on those and crash
        interpretation of an otherwise valid log line.
        """
        try:
            return float(token)
        except ValueError:
            return None
    
    def _interpret_dict(self, signal: Dict[str, Any]) -> "Event":
        """Interpret a dictionary signal."""
        # Extract common fields
        timestamp = signal.get('timestamp', signal.get('time', time.time()))
        if isinstance(timestamp, str):
            try:
                from datetime import datetime
                timestamp = datetime.fromisoformat(timestamp).timestamp()
            except ValueError:
                # Not ISO-8601; fall back to "now" rather than crashing.
                timestamp = time.time()
        
        component = signal.get('component', signal.get('source', 'unknown'))
        event_type = signal.get('event_type', signal.get('type', 'state_change'))
        
        # Everything else goes in data
        reserved = {'timestamp', 'time', 'component', 'source', 'event_type', 'type'}
        data = {k: v for k, v in signal.items() if k not in reserved}
        
        return Event(
            timestamp=timestamp,
            component=component,
            event_type=event_type,
            data=data,
            source_signal=signal,
        )
    
    def _interpret_string(self, signal: str) -> "Event":
        """Interpret a string signal (log line, message, etc.)."""
        signal = signal.strip()
        
        # Try each log pattern
        for pattern in self._log_patterns:
            match = pattern.match(signal)
            if match:
                groups = match.groups()
                if len(groups) >= 3:
                    timestamp_str, level_or_component, rest = groups[0], groups[1], groups[-1]
                    
                    # Parse timestamp (bare "HH:MM:SS" prefixes won't parse;
                    # fall back to the current time)
                    try:
                        from datetime import datetime
                        timestamp = datetime.fromisoformat(timestamp_str.replace(' ', 'T')).timestamp()
                    except ValueError:
                        timestamp = time.time()
                    
                    # Extract metrics from the rest
                    data = self._extract_metrics(rest)
                    data['raw_message'] = rest
                    
                    # Determine event type from keywords
                    event_type = self._infer_event_type(signal)
                    
                    return Event(
                        timestamp=timestamp,
                        component=level_or_component.lower(),
                        event_type=event_type,
                        data=data,
                        source_signal=signal,
                    )
        
        # Fallback: extract what we can with smarter component detection
        data = self._extract_metrics(signal)
        data['raw_message'] = signal
        
        # Infer component from content
        component = self._infer_component(signal)
        
        return Event(
            timestamp=time.time(),
            component=component,
            event_type=self._infer_event_type(signal),
            data=data,
            source_signal=signal,
        )
    
    def _interpret_list(self, signal: List[Any]) -> "Event":
        """Interpret a list signal."""
        # Convert to dict with indices
        data = {f'item_{i}': v for i, v in enumerate(signal)}
        data['length'] = len(signal)
        
        # Check if it looks like numeric data (vacuously true for [],
        # hence the per-stat empty guards below)
        if all(isinstance(x, (int, float)) for x in signal):
            data['mean'] = sum(signal) / len(signal) if signal else 0
            data['min'] = min(signal) if signal else 0
            data['max'] = max(signal) if signal else 0
        
        return Event(
            timestamp=time.time(),
            component='data',
            event_type='list_signal',
            data=data,
            source_signal=signal,
        )
    
    def _interpret_tensor(self, signal: Any) -> "Event":
        """Interpret a tensor-like signal (PyTorch, NumPy, etc.)."""
        # Try to get numpy array
        try:
            if hasattr(signal, 'detach'):
                arr = signal.detach().cpu().numpy()
            elif hasattr(signal, 'numpy'):
                arr = signal.numpy()
            else:
                arr = signal
            
            data = {
                'shape': list(arr.shape) if hasattr(arr, 'shape') else [],
                'dtype': str(arr.dtype) if hasattr(arr, 'dtype') else 'unknown',
                'mean': float(arr.mean()) if hasattr(arr, 'mean') else 0,
                'std': float(arr.std()) if hasattr(arr, 'std') else 0,
                'min': float(arr.min()) if hasattr(arr, 'min') else 0,
                'max': float(arr.max()) if hasattr(arr, 'max') else 0,
            }
            
            # Check for NaN/Inf (common in gradient explosions). After the
            # conversions above `arr` is typically a NumPy ndarray, which has
            # no .isnan()/.isinf() *methods*, so use value identities
            # instead: NaN is the only value not equal to itself.
            try:
                data['has_nan'] = bool((arr != arr).any())
                data['has_inf'] = bool((abs(arr) == float('inf')).any())
            except (TypeError, AttributeError):
                pass  # non-numeric array-like: skip the check
            
        except Exception as e:
            data = {'error': str(e), 'type': str(type(signal))}
        
        return Event(
            timestamp=time.time(),
            component='tensor',
            event_type='tensor_signal',
            data=data,
            source_signal=None,  # Don't store tensor to save memory
        )
    
    def _interpret_protobuf(self, signal: Any) -> "Event":
        """Interpret a protobuf-like signal."""
        try:
            # Try to convert to dict
            if hasattr(signal, 'DESCRIPTOR'):
                from google.protobuf.json_format import MessageToDict
                data = MessageToDict(signal)
            else:
                data = {'raw': str(signal)}
        except Exception:
            # Deliberate best-effort: protobuf may be missing or the
            # message malformed; keep the raw repr instead of failing.
            data = {'raw': str(signal)}
        
        return Event(
            timestamp=time.time(),
            component='protobuf',
            event_type='protobuf_signal',
            data=data,
            source_signal=None,
        )
    
    def _extract_metrics(self, text: str) -> Dict[str, Any]:
        """Extract numeric metrics from text - STRICT, only real training metrics.

        Malformed numeric captures are silently skipped (see _safe_float)
        rather than raising.
        """
        metrics = {}
        
        # nanoGPT format: "iter 0: loss=4.2176, time 46.76ms, mfu 0.62%"
        nano_match = re.search(r'iter\s+(\d+).*loss[=:]?\s*([\d.]+)', text, re.I)
        if nano_match:
            metrics['iter'] = int(nano_match.group(1))
            loss = self._safe_float(nano_match.group(2))
            if loss is not None:
                metrics['loss'] = loss
        
        # Diffusers/tqdm format: "step_loss=0.1234" or "step_loss: 0.1234"
        step_loss_match = re.search(r'step_loss[=:]\s*([\d.e+-]+)', text, re.I)
        if step_loss_match:
            loss = self._safe_float(step_loss_match.group(1))
            if loss is not None:
                metrics['loss'] = loss
        
        # train_loss format from accelerator.log
        train_loss_match = re.search(r'train_loss[=:]\s*([\d.e+-]+)', text, re.I)
        if train_loss_match:
            loss = self._safe_float(train_loss_match.group(1))
            if loss is not None:
                metrics['loss'] = loss
        
        # tqdm progress format: "  5%|█         | 5/100 [00:30<09:30, step_loss=0.234, lr=1e-5]"
        tqdm_match = re.search(r'(\d+)%\|.*\|\s*(\d+)/(\d+)', text)
        if tqdm_match:
            metrics['progress_pct'] = int(tqdm_match.group(1))
            metrics['step'] = int(tqdm_match.group(2))
            metrics['total_steps'] = int(tqdm_match.group(3))
        
        # Generic loss patterns
        generic_loss = re.search(r'\bloss[=:]\s*([\d.e+-]+)', text, re.I)
        if generic_loss and 'loss' not in metrics:
            loss = self._safe_float(generic_loss.group(1))
            if loss is not None:
                metrics['loss'] = loss
        
        # mfu extraction
        mfu_match = re.search(r'mfu\s*[=:]?\s*([\d.]+)%?', text, re.I)
        if mfu_match:
            mfu = self._safe_float(mfu_match.group(1))
            if mfu is not None:
                metrics['mfu'] = mfu
        
        # time extraction (ms)
        time_match = re.search(r'time\s*[=:]?\s*([\d.]+)\s*ms', text, re.I)
        if time_match:
            time_ms = self._safe_float(time_match.group(1))
            if time_ms is not None:
                metrics['time_ms'] = time_ms
        
        # learning rate - multiple formats
        lr_match = re.search(r'\b(?:lr|learning_rate)\s*[=:]\s*([\d.e+-]+)', text, re.I)
        if lr_match:
            lr = self._safe_float(lr_match.group(1))
            if lr is not None:
                metrics['lr'] = lr
        
        # epoch/step for other frameworks (\d+ captures always parse as int)
        epoch_match = re.search(r'\bepoch\s*[=:]\s*(\d+)', text, re.I)
        if epoch_match:
            metrics['epoch'] = int(epoch_match.group(1))
        
        step_match = re.search(r'\bstep\s*[=:]\s*(\d+)', text, re.I)
        if step_match and 'step' not in metrics:
            metrics['step'] = int(step_match.group(1))
        
        # global_step from diffusers
        global_step_match = re.search(r'global_step[=:]\s*(\d+)', text, re.I)
        if global_step_match:
            metrics['step'] = int(global_step_match.group(1))
        
        return metrics
    
    def _infer_event_type(self, text: str) -> str:
        """Infer event type from text content."""
        text_lower = text.lower()
        
        # Training iteration logs (highest priority)
        if re.search(r'iter\s+\d+.*loss', text_lower):
            return 'training_step'
        if re.search(r'step\s+\d+.*loss', text_lower):
            return 'training_step'
        
        if any(kw in text_lower for kw in ['error', 'exception', 'failed', 'crash']):
            return 'error'
        if any(kw in text_lower for kw in ['warning', 'warn']):
            return 'warning'
        if any(kw in text_lower for kw in ['gradient', 'backward']):
            return 'training'
        if 'loss' in text_lower and 'val' in text_lower:
            return 'validation'
        if any(kw in text_lower for kw in ['inference', 'predict', 'forward']):
            return 'inference'
        if any(kw in text_lower for kw in ['epoch', 'step', 'iteration', 'iter']):
            return 'progress'
        if any(kw in text_lower for kw in ['nan', 'inf', 'explode', 'overflow']):
            return 'anomaly'
        if any(kw in text_lower for kw in ['save', 'checkpoint', 'load', 'saving']):
            return 'checkpoint'
        if any(kw in text_lower for kw in ['config', 'setting', 'parameter', 'device', 'gpu', 'cuda']):
            return 'config'
        if any(kw in text_lower for kw in ['initializ', 'loading model', 'compiling']):
            return 'init'
        
        return 'state_change'
    
    def _infer_component(self, text: str) -> str:
        """Infer component from text content - NO MORE 'unknown'."""
        text_lower = text.lower()
        
        # Training/optimizer related (keyword order defines priority)
        if any(kw in text_lower for kw in ['iter', 'step', 'epoch', 'batch']):
            return 'trainer'
        if any(kw in text_lower for kw in ['loss', 'backward', 'gradient']):
            return 'loss'
        if any(kw in text_lower for kw in ['optim', 'adam', 'sgd', 'lr', 'learning']):
            return 'optimizer'
        if any(kw in text_lower for kw in ['model', 'layer', 'param', 'weight']):
            return 'model'
        if any(kw in text_lower for kw in ['data', 'batch', 'loader', 'dataset']):
            return 'data'
        if any(kw in text_lower for kw in ['cuda', 'gpu', 'device', 'memory']):
            return 'device'
        if any(kw in text_lower for kw in ['checkpoint', 'save', 'load']):
            return 'checkpoint'
        if any(kw in text_lower for kw in ['config', 'setting', 'override']):
            return 'config'
        if any(kw in text_lower for kw in ['eval', 'valid', 'test']):
            return 'evaluator'
        if any(kw in text_lower for kw in ['token', 'vocab', 'embed']):
            return 'tokenizer'
        
        return 'system'  # Generic fallback, not "unknown"
    
    def _get_cache_key(self, signal: Any) -> str:
        """Generate a cache key for a signal's structure."""
        if isinstance(signal, dict):
            # Key based on dict keys
            return f"dict:{':'.join(sorted(signal.keys()))}"
        elif isinstance(signal, str):
            # Key based on first word (split once; empty string -> "str:")
            words = signal.split()
            first_word = words[0] if words else ''
            return f"str:{first_word[:20]}"
        else:
            return f"type:{type(signal).__name__}"
    
    def _learn_pattern(self, signal: Any, event: "Event") -> None:
        """Learn a pattern from a successful interpretation."""
        cache_key = self._get_cache_key(signal)
        pattern = SignalPattern(
            pattern_type=type(signal).__name__,
            component=event.component,
            event_type=event.event_type,
            confidence=0.5,
            match_count=1,
        )
        self._interpretation_cache[cache_key] = pattern
        # NOTE(review): _patterns and _interpretation_cache grow without
        # bound over a long-running session — consider capping.
        self._patterns.append(pattern)
    
    def _apply_pattern(self, signal: Any, pattern: "SignalPattern") -> "Event":
        """Apply a learned pattern to interpret a signal."""
        # Re-interpret with learned hints - use direct interpreters to avoid recursion
        if isinstance(signal, dict):
            event = self._interpret_dict(signal)
            # Apply learned component/type if more confident.
            # NOTE(review): nothing in this file ever raises confidence
            # above the initial 0.5, so this branch is currently
            # unreachable unless another module adjusts it.
            if pattern.confidence > 0.7:
                return Event(
                    timestamp=event.timestamp,
                    component=pattern.component,
                    event_type=pattern.event_type,
                    data=event.data,
                    source_signal=signal,
                )
            return event
        elif isinstance(signal, str):
            return self._interpret_string(signal)
        elif isinstance(signal, list):
            return self._interpret_list(signal)
        else:
            # Fallback: interpret as string without recursion
            return self._interpret_string(str(signal))
    
    @property
    def learned_patterns(self) -> List["SignalPattern"]:
        """Get all learned signal patterns, most-matched first."""
        return sorted(self._patterns, key=lambda p: p.match_count, reverse=True)
    
    @property 
    def signal_count(self) -> int:
        """Total number of signals interpreted."""
        return self._signal_count
    
    def __repr__(self) -> str:
        return f"<SymbioticAdapter | {self._signal_count} signals, {len(self._patterns)} patterns>"