"""
Cascade Core - Symbiotic Adapter.

The heart of Cascade's system-agnostic design. The adapter uses Kleene
fixed-point convergence to interpret ANY signal format and convert it to
Events.

"It doesn't hook into your system — it becomes part of it."
"""

import time
import json
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Callable, Type
from dataclasses import dataclass

from cascade.core.event import Event


# Numeric literal fragments shared by the metric regexes below. Unlike a loose
# character class such as ``[\d.e+-]+`` they only match text that ``float()``
# accepts, so a line like "loss: error" or "loss: 0.5." cannot raise ValueError.
_UNSIGNED_NUMBER = r'(?:\d+\.?\d*|\.\d+)(?:e[+-]?\d+)?'
_SIGNED_NUMBER = r'[+-]?' + _UNSIGNED_NUMBER

# nanoGPT format: "iter 0: loss=4.2176, time 46.76ms, mfu 0.62%"
_NANO_ITER_LOSS_RE = re.compile(
    r'iter\s+(\d+).*loss[=:]?\s*(' + _UNSIGNED_NUMBER + r')', re.I)
# Diffusers/tqdm format: "step_loss=0.1234" or "step_loss: 0.1234"
_STEP_LOSS_RE = re.compile(r'step_loss[=:]\s*(' + _SIGNED_NUMBER + r')', re.I)
# train_loss format from accelerator.log
_TRAIN_LOSS_RE = re.compile(r'train_loss[=:]\s*(' + _SIGNED_NUMBER + r')', re.I)
# tqdm progress format: " 5%|█ | 5/100 [00:30<09:30, step_loss=0.234, lr=1e-5]"
_TQDM_RE = re.compile(r'(\d+)%\|.*\|\s*(\d+)/(\d+)')
# Generic "loss=..." / "loss: ..."
_GENERIC_LOSS_RE = re.compile(r'\bloss[=:]\s*(' + _SIGNED_NUMBER + r')', re.I)
_MFU_RE = re.compile(r'mfu\s*[=:]?\s*(' + _UNSIGNED_NUMBER + r')%?', re.I)
_TIME_MS_RE = re.compile(r'time\s*[=:]?\s*(' + _UNSIGNED_NUMBER + r')\s*ms', re.I)
_LR_RE = re.compile(
    r'\b(?:lr|learning_rate)\s*[=:]\s*(' + _SIGNED_NUMBER + r')', re.I)
_EPOCH_RE = re.compile(r'\bepoch\s*[=:]\s*(\d+)', re.I)
_STEP_RE = re.compile(r'\bstep\s*[=:]\s*(\d+)', re.I)
# global_step from diffusers
_GLOBAL_STEP_RE = re.compile(r'global_step[=:]\s*(\d+)', re.I)

# "nan" / "inf" must be whole words: a plain substring test for "inf" also
# matches "info", which would classify ordinary INFO log lines as anomalies.
_NON_FINITE_WORD_RE = re.compile(r'\b(?:nan|inf|infinity)\b')


@dataclass
class SignalPattern:
    """A learned pattern for interpreting signals."""

    pattern_type: str  # 'dict', 'string', 'tensor', 'protobuf', etc.
    component: str
    event_type: str
    extractor: Optional[Callable[[Any], Dict[str, Any]]] = None
    confidence: float = 0.0
    match_count: int = 0


class SymbioticAdapter:
    """
    Self-interpreting adapter that converges to any signal format.

    The adapter observes signals from the host system and learns how to
    interpret them through fixed-point iteration. It starts with naive
    interpretations and refines them until stable.

    This is the key to Cascade's system-agnostic design:
    - No framework-specific hooks required
    - No configuration needed
    - Feed it ANY signal format, it adapts

    NOTE: patterns are currently recorded with a fixed confidence of 0.5 and
    nothing in this class raises it, so the confidence-gated override in
    ``_apply_pattern`` (threshold 0.7) is inactive unless a caller adjusts
    ``SignalPattern.confidence`` itself.

    Example:
        >>> adapter = SymbioticAdapter()
        >>>
        >>> # Feed it different signal formats
        >>> adapter.interpret({"loss": 0.5, "epoch": 10})
        >>> adapter.interpret("2024-01-01 12:00:00 ERROR training failed")
        >>> adapter.interpret(torch.tensor([0.1, 0.2, 0.3]))
        >>>
        >>> # It learns patterns and gets better at interpretation
        >>> print(adapter.learned_patterns)
    """

    def __init__(self):
        """Initialize the symbiotic adapter."""
        self._patterns: List[SignalPattern] = []
        self._signal_count = 0
        self._interpretation_cache: Dict[str, SignalPattern] = {}

        # Built-in interpreters for common formats
        self._builtin_interpreters = {
            dict: self._interpret_dict,
            str: self._interpret_string,
            list: self._interpret_list,
        }

        # Regex patterns for log line parsing
        self._log_patterns = [
            # ISO timestamp with level: "2024-01-01 12:00:00 ERROR message"
            re.compile(r'^(\d{4}-\d{2}-\d{2}[T\s]\d{2}:\d{2}:\d{2}(?:\.\d+)?)\s+(\w+)\s+(.*)$'),
            # Simple timestamp: "12:00:00.123 component message"
            re.compile(r'^(\d{2}:\d{2}:\d{2}(?:\.\d+)?)\s+(\w+)\s+(.*)$'),
            # Pipe-delimited: "timestamp|level|component|key:value"
            re.compile(r'^([^|]+)\|(\w+)\|(\w+)\|(.*)$'),
        ]

        # Metric extraction patterns - ONLY extract real training metrics
        # Be strict to avoid extracting garbage from config lines
        # (Kept for compatibility; not referenced elsewhere in this file.)
        self._metric_patterns = [
            # Standard training metrics with = or :
            re.compile(r'\b(loss|val_loss|train_loss|accuracy|acc|val_acc|lr|learning_rate|epoch|step|iter|iteration|mfu|tokens_per_sec|samples_per_sec|grad_norm|perplexity|ppl)[=:]\s*([+-]?\d+\.?\d*(?:e[+-]?\d+)?)', re.I),
            # "iter X: loss=Y" format from nanoGPT
            re.compile(r'iter\s+(\d+).*loss[=:]?\s*([+-]?\d+\.?\d*)', re.I),
            # "step X loss Y" format
            re.compile(r'step\s+(\d+).*loss\s*[=:]?\s*([+-]?\d+\.?\d*)', re.I),
        ]

    def interpret(self, signal: Any) -> Event:
        """
        Interpret any signal into a Cascade Event.

        A signal whose structural cache key has been seen before is handled
        by ``_apply_pattern``. Otherwise the first matching interpreter
        (dict / str / list, then tensor-like, then protobuf-like) is used and
        the result is recorded as a learned pattern. Anything else is
        stringified and parsed as text without being recorded.

        Args:
            signal: Any signal from the host system

        Returns:
            Event: The interpreted event
        """
        self._signal_count += 1

        # Try cached pattern first
        cache_key = self._get_cache_key(signal)
        pattern = self._interpretation_cache.get(cache_key)
        if pattern is not None:
            pattern.match_count += 1
            return self._apply_pattern(signal, pattern)

        interpreter = self._select_interpreter(signal)
        if interpreter is None:
            # Fallback: convert to string and interpret
            return self._interpret_string(str(signal))

        event = interpreter(signal)
        self._learn_pattern(signal, event)
        return event

    def _select_interpreter(self, signal: Any) -> Optional[Callable[[Any], Event]]:
        """Return the interpreter for ``signal``, or None if only the string fallback applies."""
        # isinstance (not an exact type lookup) so that subclasses such as
        # OrderedDict are handled like their base type.
        for signal_type, interpreter in self._builtin_interpreters.items():
            if isinstance(signal, signal_type):
                return interpreter

        # Tensor-like objects (duck typing)
        if hasattr(signal, 'numpy') or hasattr(signal, 'detach'):
            return self._interpret_tensor

        # Protobuf-like objects
        if hasattr(signal, 'SerializeToString'):
            return self._interpret_protobuf

        return None

    @staticmethod
    def _parse_timestamp(text: str) -> float:
        """Parse an ISO-8601 string into epoch seconds, falling back to the current time."""
        try:
            return datetime.fromisoformat(text.replace(' ', 'T')).timestamp()
        except (ValueError, OverflowError, OSError):
            # Unparseable (e.g. time-only "12:00:00") or out-of-range value.
            return time.time()

    def _interpret_dict(self, signal: Dict[str, Any]) -> Event:
        """
        Interpret a dictionary signal.

        ``timestamp``/``time``, ``component``/``source`` and
        ``event_type``/``type`` are lifted into the event's fields (with
        defaults of now, 'unknown' and 'state_change'); every other key goes
        into the event data.
        """
        # Extract common fields
        timestamp = signal.get('timestamp', signal.get('time', time.time()))
        if isinstance(timestamp, str):
            timestamp = self._parse_timestamp(timestamp)

        component = signal.get('component', signal.get('source', 'unknown'))
        event_type = signal.get('event_type', signal.get('type', 'state_change'))

        # Everything else goes in data
        reserved = {'timestamp', 'time', 'component', 'source', 'event_type', 'type'}
        data = {k: v for k, v in signal.items() if k not in reserved}

        return Event(
            timestamp=timestamp,
            component=component,
            event_type=event_type,
            data=data,
            source_signal=signal,
        )

    def _interpret_string(self, signal: str) -> Event:
        """
        Interpret a string signal (log line, message, etc.).

        The line is matched against the known log layouts first. If none
        matches, metrics, component and event type are inferred from the raw
        text and the current time is used as the timestamp.
        """
        signal = signal.strip()

        # Try each log pattern
        for log_pattern in self._log_patterns:
            match = log_pattern.match(signal)
            if match is None:
                continue
            groups = match.groups()
            if len(groups) < 3:
                continue

            timestamp_str, rest = groups[0], groups[-1]

            # Extract metrics from the rest
            data = self._extract_metrics(rest)
            data['raw_message'] = rest

            if len(groups) >= 4:
                # Pipe-delimited layout carries both a level and a component;
                # use the real component and keep the level in the data.
                component = groups[2]
                data['level'] = groups[1]
            else:
                # Three-field layouts: the second field is a level or a
                # component depending on the format; it is used as component.
                component = groups[1]

            return Event(
                timestamp=self._parse_timestamp(timestamp_str),
                component=component.lower(),
                # Determine event type from keywords in the whole line
                event_type=self._infer_event_type(signal),
                data=data,
                source_signal=signal,
            )

        # Fallback: extract what we can with smarter component detection
        data = self._extract_metrics(signal)
        data['raw_message'] = signal

        return Event(
            timestamp=time.time(),
            component=self._infer_component(signal),
            event_type=self._infer_event_type(signal),
            data=data,
            source_signal=signal,
        )

    def _interpret_list(self, signal: List[Any]) -> Event:
        """
        Interpret a list signal.

        Each element is stored as ``item_<index>`` alongside ``length``. When
        every element is an int or float (vacuously true for an empty list),
        ``mean``/``min``/``max`` are added, all 0 for an empty list.
        """
        # Convert to dict with indices
        data = {f'item_{i}': v for i, v in enumerate(signal)}
        data['length'] = len(signal)

        # Check if it looks like numeric data
        if all(isinstance(x, (int, float)) for x in signal):
            if signal:
                data['mean'] = sum(signal) / len(signal)
                data['min'] = min(signal)
                data['max'] = max(signal)
            else:
                data['mean'] = 0
                data['min'] = 0
                data['max'] = 0

        return Event(
            timestamp=time.time(),
            component='data',
            event_type='list_signal',
            data=data,
            source_signal=signal,
        )

    def _interpret_tensor(self, signal: Any) -> Event:
        """
        Interpret a tensor-like signal (PyTorch, NumPy, etc.).

        Summarises shape, dtype, mean/std/min/max and NaN/Inf presence. Any
        failure while summarising is reported in the event data instead of
        being raised. The tensor itself is not stored on the event.
        """
        try:
            # Try to get numpy array
            if hasattr(signal, 'detach'):
                arr = signal.detach().cpu().numpy()
            elif hasattr(signal, 'numpy'):
                arr = signal.numpy()
            else:
                arr = signal

            data = {
                'shape': list(arr.shape) if hasattr(arr, 'shape') else [],
                'dtype': str(arr.dtype) if hasattr(arr, 'dtype') else 'unknown',
                'mean': float(arr.mean()) if hasattr(arr, 'mean') else 0,
                'std': float(arr.std()) if hasattr(arr, 'std') else 0,
                'min': float(arr.min()) if hasattr(arr, 'min') else 0,
                'max': float(arr.max()) if hasattr(arr, 'max') else 0,
            }

            # Check for NaN/Inf (common in gradient explosions)
            data.update(self._non_finite_flags(arr))
        except Exception as e:
            data = {'error': str(e), 'type': str(type(signal))}

        return Event(
            timestamp=time.time(),
            component='tensor',
            event_type='tensor_signal',
            data=data,
            source_signal=None,  # Don't store tensor to save memory
        )

    @staticmethod
    def _non_finite_flags(arr: Any) -> Dict[str, bool]:
        """Return ``has_nan`` / ``has_inf`` flags for ``arr`` where they can be determined."""
        flags: Dict[str, bool] = {}

        # Objects exposing isnan()/isinf() methods (torch-style tensors).
        for key, method_name in (('has_nan', 'isnan'), ('has_inf', 'isinf')):
            method = getattr(arr, method_name, None)
            if callable(method):
                flags[key] = bool(method().any())
        if len(flags) == 2:
            return flags

        # NumPy arrays have no such methods; use the module-level functions.
        try:
            import numpy as np
        except ImportError:
            return flags
        try:
            if 'has_nan' not in flags:
                flags['has_nan'] = bool(np.isnan(arr).any())
            if 'has_inf' not in flags:
                flags['has_inf'] = bool(np.isinf(arr).any())
        except (TypeError, ValueError):
            # Non-numeric dtype: leave the flags undetermined.
            pass
        return flags

    def _interpret_protobuf(self, signal: Any) -> Event:
        """
        Interpret a protobuf-like signal.

        Messages with a ``DESCRIPTOR`` are converted with ``MessageToDict``;
        anything else, or any conversion failure, is stored as
        ``{'raw': str(signal)}``.
        """
        try:
            # Try to convert to dict
            if hasattr(signal, 'DESCRIPTOR'):
                from google.protobuf.json_format import MessageToDict
                data = MessageToDict(signal)
            else:
                data = {'raw': str(signal)}
        except Exception:
            # Best effort: protobuf may be missing or the message malformed.
            data = {'raw': str(signal)}

        return Event(
            timestamp=time.time(),
            component='protobuf',
            event_type='protobuf_signal',
            data=data,
            source_signal=None,
        )

    def _extract_metrics(self, text: str) -> Dict[str, Any]:
        """
        Extract numeric metrics from text - STRICT, only real training metrics.

        Later, more specific matches overwrite earlier ones for ``loss``
        (step_loss, then train_loss) and ``step`` (global_step); the generic
        ``loss`` and ``step`` patterns only fill in when the key is absent.

        Returns:
            Dict with any of: iter, loss, progress_pct, step, total_steps,
            mfu, time_ms, lr, epoch.
        """
        metrics: Dict[str, Any] = {}

        nano_match = _NANO_ITER_LOSS_RE.search(text)
        if nano_match:
            metrics['iter'] = int(nano_match.group(1))
            metrics['loss'] = float(nano_match.group(2))

        step_loss_match = _STEP_LOSS_RE.search(text)
        if step_loss_match:
            metrics['loss'] = float(step_loss_match.group(1))

        train_loss_match = _TRAIN_LOSS_RE.search(text)
        if train_loss_match:
            metrics['loss'] = float(train_loss_match.group(1))

        tqdm_match = _TQDM_RE.search(text)
        if tqdm_match:
            metrics['progress_pct'] = int(tqdm_match.group(1))
            metrics['step'] = int(tqdm_match.group(2))
            metrics['total_steps'] = int(tqdm_match.group(3))

        # Generic loss patterns
        if 'loss' not in metrics:
            generic_loss = _GENERIC_LOSS_RE.search(text)
            if generic_loss:
                metrics['loss'] = float(generic_loss.group(1))

        # mfu extraction
        mfu_match = _MFU_RE.search(text)
        if mfu_match:
            metrics['mfu'] = float(mfu_match.group(1))

        # time extraction (ms)
        time_match = _TIME_MS_RE.search(text)
        if time_match:
            metrics['time_ms'] = float(time_match.group(1))

        # learning rate - multiple formats
        lr_match = _LR_RE.search(text)
        if lr_match:
            metrics['lr'] = float(lr_match.group(1))

        # epoch/step for other frameworks
        epoch_match = _EPOCH_RE.search(text)
        if epoch_match:
            metrics['epoch'] = int(epoch_match.group(1))

        if 'step' not in metrics:
            step_match = _STEP_RE.search(text)
            if step_match:
                metrics['step'] = int(step_match.group(1))

        global_step_match = _GLOBAL_STEP_RE.search(text)
        if global_step_match:
            metrics['step'] = int(global_step_match.group(1))

        return metrics

    def _infer_event_type(self, text: str) -> str:
        """
        Infer event type from text content.

        Rules are checked in priority order and the first hit wins; the
        default is 'state_change'.
        """
        text_lower = text.lower()

        def has_any(*keywords: str) -> bool:
            return any(kw in text_lower for kw in keywords)

        # Training iteration logs (highest priority)
        if re.search(r'iter\s+\d+.*loss', text_lower):
            return 'training_step'
        if re.search(r'step\s+\d+.*loss', text_lower):
            return 'training_step'

        if has_any('error', 'exception', 'failed', 'crash'):
            return 'error'
        if has_any('warning', 'warn'):
            return 'warning'
        if has_any('gradient', 'backward'):
            return 'training'
        if 'loss' in text_lower and 'val' in text_lower:
            return 'validation'
        if has_any('inference', 'predict', 'forward'):
            return 'inference'
        if has_any('epoch', 'step', 'iteration', 'iter'):
            return 'progress'
        # Whole-word nan/inf so that "info" is not treated as "inf".
        if _NON_FINITE_WORD_RE.search(text_lower) or has_any('explode', 'overflow'):
            return 'anomaly'
        if has_any('save', 'checkpoint', 'load', 'saving'):
            return 'checkpoint'
        if has_any('config', 'setting', 'parameter', 'device', 'gpu', 'cuda'):
            return 'config'
        if has_any('initializ', 'loading model', 'compiling'):
            return 'init'

        return 'state_change'

    def _infer_component(self, text: str) -> str:
        """
        Infer component from text content - NO MORE 'unknown'.

        Keyword groups are checked in order by substring; the first group
        with a hit names the component, and 'system' is the fallback.
        """
        text_lower = text.lower()

        # Ordered (component, keywords) table; earlier rows take priority.
        rules = (
            ('trainer', ('iter', 'step', 'epoch', 'batch')),
            ('loss', ('loss', 'backward', 'gradient')),
            ('optimizer', ('optim', 'adam', 'sgd', 'lr', 'learning')),
            ('model', ('model', 'layer', 'param', 'weight')),
            ('data', ('data', 'batch', 'loader', 'dataset')),
            ('device', ('cuda', 'gpu', 'device', 'memory')),
            ('checkpoint', ('checkpoint', 'save', 'load')),
            ('config', ('config', 'setting', 'override')),
            ('evaluator', ('eval', 'valid', 'test')),
            ('tokenizer', ('token', 'vocab', 'embed')),
        )
        for component, keywords in rules:
            if any(kw in text_lower for kw in keywords):
                return component

        return 'system'  # Generic fallback, not "unknown"

    def _get_cache_key(self, signal: Any) -> str:
        """
        Generate a cache key for a signal's structure.

        Dicts are keyed by their sorted key names, strings by the shape of
        their first word, everything else by type name.
        """
        if isinstance(signal, dict):
            # Key based on dict keys; str() so non-string or mixed-type keys
            # can be sorted and joined without raising TypeError.
            return f"dict:{':'.join(sorted(str(k) for k in signal))}"
        if isinstance(signal, str):
            # Key based on first word. Digits are normalised so that lines
            # starting with a timestamp or counter share one entry instead of
            # adding a new cache entry per line.
            words = signal.split()
            first_word = words[0] if words else ''
            shape = re.sub(r'\d', '0', first_word[:20])
            return f"str:{shape}"
        return f"type:{type(signal).__name__}"

    def _learn_pattern(self, signal: Any, event: Event) -> None:
        """Learn a pattern from a successful interpretation."""
        learned = SignalPattern(
            pattern_type=type(signal).__name__,
            component=event.component,
            event_type=event.event_type,
            confidence=0.5,
            match_count=1,
        )
        self._interpretation_cache[self._get_cache_key(signal)] = learned
        self._patterns.append(learned)

    def _apply_pattern(self, signal: Any, pattern: SignalPattern) -> Event:
        """
        Apply a learned pattern to interpret a signal.

        The signal is re-interpreted with its direct interpreter (never via
        ``interpret``, to avoid recursion). For dicts, the learned component
        and event type replace the interpreted ones once the pattern's
        confidence exceeds 0.7.
        """
        if isinstance(signal, dict):
            event = self._interpret_dict(signal)
            # Apply learned component/type if more confident
            if pattern.confidence > 0.7:
                return Event(
                    timestamp=event.timestamp,
                    component=pattern.component,
                    event_type=pattern.event_type,
                    data=event.data,
                    source_signal=signal,
                )
            return event

        # Use the same selection as the uncached path so tensor-like and
        # protobuf-like signals are not stringified on their second sighting.
        interpreter = self._select_interpreter(signal)
        if interpreter is None:
            # Fallback: interpret as string without recursion
            return self._interpret_string(str(signal))
        return interpreter(signal)

    @property
    def learned_patterns(self) -> List[SignalPattern]:
        """Get all learned signal patterns, most frequently matched first."""
        return sorted(self._patterns, key=lambda p: p.match_count, reverse=True)

    @property
    def signal_count(self) -> int:
        """Total number of signals interpreted."""
        return self._signal_count

    def __repr__(self) -> str:
        """Return a short summary with the signal and pattern counts."""
        return (
            f"SymbioticAdapter(signals={self._signal_count}, "
            f"patterns={len(self._patterns)})"
        )