diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000000000000000000000000000000000000..ef26598d9c1da5f671da513463a6c76e13b3df7a --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,31 @@ +name: Publish to PyPI + +on: + push: + tags: + - 'v*' + +jobs: + publish: + runs-on: ubuntu-latest + permissions: + id-token: write # For trusted publishing (optional) + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install build tools + run: pip install build + + - name: Build package + run: python -m build + + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..bc9cd26d1ea7def15f658af69ecc25b5996a4bc6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,35 @@ +# Byte-compiled +__pycache__/ +*.py[cod] +*$py.class + +# Distribution / packaging +dist/ +build/ +*.egg-info/ +*.egg +*.whl + +# Virtual environments +venv/ +.venv/ +env/ + +# IDE +.vscode/ +.idea/ +*.swp +*.swo + +# Testing +.pytest_cache/ +.coverage +htmlcov/ + +# Logs +*.log +logs/ + +# OS +.DS_Store +Thumbs.db diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..13e314aaddabb8cbed8502f05014e30e34f5172c --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024-2026 Jeff Towers + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright 
notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9afea6f6c33d5221ae3edb445bcaedf418ac9fea --- /dev/null +++ b/README.md @@ -0,0 +1,70 @@ +# Cascade Lattice + +**Universal AI provenance layer — cryptographic receipts for every call, with HOLD inference halt protocol** + +[![PyPI version](https://badge.fury.io/py/cascade-lattice.svg)](https://pypi.org/project/cascade-lattice/) +[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) + +## Installation + +```bash +pip install cascade-lattice +``` + +With optional dependencies: +```bash +pip install cascade-lattice[torch] # PyTorch integration +pip install cascade-lattice[all] # All integrations +``` + +## Quick Start + +```python +from cascade import Monitor + +# Create a monitor for your component +monitor = Monitor("training_loop") + +# Observe events (parses logs, extracts metrics) +event = monitor.observe("Epoch 5: loss=0.0234, accuracy=0.9812") +print(event.data) # {'loss': 0.0234, 'accuracy': 0.9812, ...} + +# Get metrics summary +print(monitor.metrics.summary()) +``` + +## Features + +- **Universal Observation** — Monitor training, inference, system logs, API calls +- **Cryptographic Receipts** — Every observation gets a verifiable hash chain +- **HOLD Protocol** — Inference halt capability for safety-critical applications +- **Tape 
Storage** — JSONL event streams for replay and analysis +- **Provider Patches** — Drop-in monitoring for OpenAI, Anthropic, LiteLLM, Ollama + +## CLI Usage + +```bash +cascade --help # Show all commands +cascade stats # Lattice statistics +cascade list -n 20 # Recent observations +cascade watch # Live observation feed +cascade fingerprint model/ # Fingerprint a model +cascade pii scan.log # Scan for PII +``` + +## Tape Utilities + +```python +from cascade.viz import load_tape_file, find_latest_tape, list_tape_files + +# Find and load tape files +latest = find_latest_tape("./logs") +events = load_tape_file(latest) + +for event in events: + print(event['event']['event_type'], event['event']['data']) +``` + +## License + +MIT diff --git a/cascade/__init__.py b/cascade/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7c7126c6c854f27ac4c379667176b584a743341 --- /dev/null +++ b/cascade/__init__.py @@ -0,0 +1,290 @@ +""" +╔═══════════════════════════════════════════════════════════════════════════════╗ +║ ║ +║ ██████╗ █████╗ ███████╗ ██████╗ █████╗ ██████╗ ███████╗ ║ +║ ██╔════╝██╔══██╗██╔════╝██╔════╝██╔══██╗██╔══██╗██╔════╝ ║ +║ ██║ ███████║███████╗██║ ███████║██║ ██║█████╗ ║ +║ ██║ ██╔══██║╚════██║██║ ██╔══██║██║ ██║██╔══╝ ║ +║ ╚██████╗██║ ██║███████║╚██████╗██║ ██║██████╔╝███████╗ ║ +║ ╚═════╝╚═╝ ╚═╝╚══════╝ ╚═════╝╚═╝ ╚═╝╚═════╝ ╚══════╝ ║ +║ ║ +║ Symbiotic Causation Monitoring for Neural Networks ║ +║ ║ +║ "even still, i grow, and yet, I grow still" ║ +║ ║ +╚═══════════════════════════════════════════════════════════════════════════════╝ + +Cascade is a self-interpreting causation monitor that symbiotically adapts to +any system architecture through Kleene fixed-point convergence. + +Feed it ANY signal format. It learns your system's patterns. It traces cause +and effect bidirectionally through time. It predicts cascading failures before +they complete. 
+ +Quick Start: + >>> import cascade + >>> monitor = cascade.Monitor() + >>> monitor.observe({"loss": 0.5, "epoch": 10}) + >>> monitor.observe("ERROR: gradient exploded at layer 5") + >>> + >>> # What caused this? + >>> monitor.trace_backwards("gradient_explosion") + >>> + >>> # What will this cause? + >>> monitor.trace_forwards("learning_rate_spike") +""" + +__version__ = "0.5.4" +__author__ = "Cascade Team" +__license__ = "MIT" + +from cascade.core.event import Event, CausationLink +from cascade.core.graph import CausationGraph +from cascade.core.adapter import SymbioticAdapter +from cascade.analysis.tracer import Tracer +from cascade.analysis.metrics import MetricsEngine + +# Primary API +class Monitor: + """ + The main entry point for Cascade monitoring. + + A symbiotic observer that acclimates to any system architecture. + Feed it signals in any format — it adapts and builds a causation graph. + + Example: + >>> monitor = cascade.Monitor() + >>> + >>> # Feed it anything - dicts, strings, tensors, whatever + >>> monitor.observe({"loss": 0.5, "epoch": 10}) + >>> monitor.observe("2024-01-01 12:00:00 INFO training started") + >>> monitor.observe(torch.tensor([0.1, 0.2, 0.3])) + >>> + >>> # Trace causation backwards (what caused this?) + >>> causes = monitor.trace_backwards(event_id) + >>> + >>> # Trace causation forwards (what will this cause?) + >>> effects = monitor.trace_forwards(event_id) + >>> + >>> # Get the full causation graph + >>> graph = monitor.graph + """ + + def __init__(self, name: str = "default"): + """ + Initialize a new Cascade monitor. + + Args: + name: Optional name for this monitor instance + """ + self.name = name + self.adapter = SymbioticAdapter() + self.graph = CausationGraph() + self.tracer = Tracer(self.graph) + self.metrics = MetricsEngine(self.graph) + self._event_count = 0 + + def observe(self, signal) -> Event: + """ + Observe a signal from the host system. 
+ + The signal can be in ANY format: + - dict: {"loss": 0.5, "epoch": 10} + - str: "ERROR: gradient exploded" + - tensor: torch.tensor([...]) + - protobuf, JSON, log line, etc. + + Cascade will automatically adapt to your signal format. + + Args: + signal: Any signal from the host system + + Returns: + Event: The interpreted event added to the causation graph + """ + event = self.adapter.interpret(signal) + self.graph.add_event(event) + self.metrics.ingest(event) + self._event_count += 1 + return event + + def trace_backwards(self, event_id: str, max_depth: int = 10): + """ + Trace causation backwards: what caused this event? + + Args: + event_id: ID of the event to trace from + max_depth: Maximum depth to trace (default: 10) + + Returns: + List of CausationChain objects showing the causal history + """ + return self.tracer.trace_backwards(event_id, max_depth) + + def trace_forwards(self, event_id: str, max_depth: int = 10): + """ + Trace causation forwards: what did this event cause? + + Args: + event_id: ID of the event to trace from + max_depth: Maximum depth to trace (default: 10) + + Returns: + List of CausationChain objects showing the effects + """ + return self.tracer.trace_forwards(event_id, max_depth) + + def find_root_causes(self, event_id: str): + """ + Find the ultimate root causes of an event. + + Goes all the way back to find the origin points. + + Args: + event_id: ID of the event to analyze + + Returns: + List of root cause events with their causal chains + """ + return self.tracer.find_root_causes(event_id) + + def analyze_impact(self, event_id: str, max_depth: int = 20): + """ + Analyze the downstream impact of an event. + + Traces forward to find everything this event set in motion. 
+ + Args: + event_id: ID of the event to analyze + max_depth: Maximum depth to search + + Returns: + ImpactAnalysis with effects and severity score + """ + return self.tracer.analyze_impact(event_id, max_depth) + + def predict_cascade(self, event_id: str): + """ + Predict the likely future cascade from this event. + + Uses learned patterns to forecast effects before they happen. + + Args: + event_id: ID of the event to predict from + + Returns: + CascadePrediction with risk scores and intervention points + """ + return self.tracer.predict_cascade(event_id) + + def __repr__(self): + return f"<Monitor name='{self.name}' events={self._event_count}>" + + +# Convenience function for quick setup +def observe() -> Monitor: + """ + Create a new Cascade monitor ready for observation. + + This is the simplest way to get started: + + >>> import cascade + >>> monitor = cascade.observe() + >>> monitor.observe({"loss": 0.5}) + + Returns: + Monitor: A new monitor instance + """ + return Monitor() + + +# Tape utilities for event storage +from cascade.viz.tape import ( + load_tape_file, + find_latest_tape, + list_tape_files, + PlaybackBuffer, +) + +# SDK - Universal AI Observation Layer +from cascade.sdk import init, observe as sdk_observe, shutdown + +# Store - Simple observe/query with HuggingFace sync +from cascade.store import ( + observe as store_observe, + query as store_query, + get as store_get, + stats as store_stats, + sync_all, + pull_from_hf, + Receipt, + # Discovery - find other users' lattices + discover_models, + discover_datasets, + discover_live, + dataset_info, +) + +# Convenience aliases +auto_observe = init # cascade.auto_observe() is clearer for some users + +# HOLD - Inference-Level Halt Protocol +from cascade import hold as hold_module +from cascade.hold import ( + Hold, + HoldPoint, + HoldResolution, + HoldState, + HoldAwareMixin, + CausationHold, + InferenceStep, + HoldSession, + ArcadeFeedback, +) + + +__all__ = [ + # SDK - Primary Interface + "init", + "auto_observe", + "shutdown", + # Store - 
HuggingFace-backed storage + "store_observe", + "store_query", + "store_get", + "store_stats", + "sync_all", + "pull_from_hf", + "Receipt", + # Discovery + "discover_models", + "discover_datasets", + "discover_live", + "dataset_info", + # Monitor (causation tracking) + "Monitor", + "observe", + "Event", + "CausationLink", + "CausationGraph", + "SymbioticAdapter", + "Tracer", + "MetricsEngine", + # Tape playback + "load_tape_file", + "find_latest_tape", + "list_tape_files", + "PlaybackBuffer", + # HOLD - Inference Halt Protocol + "Hold", + "HoldPoint", + "HoldResolution", + "HoldState", + "HoldAwareMixin", + "CausationHold", + "InferenceStep", + "HoldSession", + "ArcadeFeedback", + "hold_module", + "__version__", +] diff --git a/cascade/analysis/__init__.py b/cascade/analysis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8cdac516e57fd68c2faaf3b49db85c106e7d48f8 --- /dev/null +++ b/cascade/analysis/__init__.py @@ -0,0 +1,37 @@ +"""Cascade Analysis module - tracing, prediction, and intervention.""" + +from cascade.analysis.tracer import ( + Tracer, + RootCauseAnalysis, + ImpactAnalysis, + CascadePrediction, +) +from cascade.analysis.metrics import ( + MetricsEngine, + MetricSeries, + MetricCategory, + MetricHealthSpec, + Anomaly, + Correlation, + ThresholdCrossing, + classify_metric, + METRIC_TAXONOMY, + HEALTH_SPECS, +) + +__all__ = [ + "Tracer", + "RootCauseAnalysis", + "ImpactAnalysis", + "CascadePrediction", + "MetricsEngine", + "MetricSeries", + "MetricCategory", + "MetricHealthSpec", + "Anomaly", + "Correlation", + "ThresholdCrossing", + "classify_metric", + "METRIC_TAXONOMY", + "HEALTH_SPECS", +] diff --git a/cascade/analysis/metrics.py b/cascade/analysis/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..116bad0d95b891874973c4b34cbb37155f66e1f9 --- /dev/null +++ b/cascade/analysis/metrics.py @@ -0,0 +1,1168 @@ +""" +Cascade Analysis - Metrics Engine. + +The quantification layer. 
Extracts, tracks, and correlates numeric data +from the event stream. Provides the WHAT with enough depth that the WHY +becomes self-evident to the observer. + +This module does NOT interpret or explain. It quantifies. + +Industry-Standard Neural Network Observability Taxonomy: +========================================================= + +CATEGORY 1: TRAINING_DYNAMICS + Core training loop metrics - loss, accuracy, learning rate, throughput + +CATEGORY 2: GRADIENT_HEALTH + Gradient flow diagnostics - norms, clipping, vanishing/exploding + +CATEGORY 3: WEIGHT_DYNAMICS + Parameter evolution - norms, update ratios, dead neurons + +CATEGORY 4: ACTIVATION_FLOW + Forward pass health - magnitudes, saturation, dead ReLUs + +CATEGORY 5: ATTENTION_MECHANICS + Transformer-specific - entropy, sparsity, head importance + +CATEGORY 6: MEMORY_COMPUTE + Resource utilization - GPU/CPU memory, MFU, throughput + +CATEGORY 7: OPTIMIZATION_STATE + Optimizer internals - Adam moments, momentum, weight decay + +CATEGORY 8: CONVERGENCE_SIGNALS + Training health indicators - plateau, overfitting, noise scale + +CATEGORY 9: DATA_PIPELINE + Data loading metrics - batch time, queue depth, prefetch + +CATEGORY 10: REGULARIZATION + Regularization effects - dropout, batch norm, layer norm stats +""" + +from typing import Dict, List, Any, Optional, Tuple, Set +from dataclasses import dataclass, field +from collections import defaultdict +from enum import Enum, auto +import math +import re + +from cascade.core.event import Event +from cascade.core.graph import CausationGraph + + +# ============================================================================= +# METRIC CATEGORY TAXONOMY +# ============================================================================= + +class MetricCategory(Enum): + """Industry-standard neural network metric categories.""" + TRAINING_DYNAMICS = auto() # Loss, accuracy, LR, throughput + GRADIENT_HEALTH = auto() # Grad norms, clipping, flow + WEIGHT_DYNAMICS = auto() # 
Weight norms, updates, dead neurons + ACTIVATION_FLOW = auto() # Activation stats, saturation + ATTENTION_MECHANICS = auto() # Attention entropy, sparsity, heads + MEMORY_COMPUTE = auto() # GPU/CPU mem, MFU, FLOPS + OPTIMIZATION_STATE = auto() # Adam moments, momentum, decay + CONVERGENCE_SIGNALS = auto() # Plateau, overfit, noise scale + DATA_PIPELINE = auto() # Batch time, queue, prefetch + REGULARIZATION = auto() # Dropout, norm layer stats + SYSTEM = auto() # Iteration, epoch, timestamps + UNKNOWN = auto() # Uncategorized metrics + + +# Comprehensive metric-to-category mapping +# This is the "knowledge base" of neural network metric taxonomy +METRIC_TAXONOMY: Dict[str, MetricCategory] = { + # TRAINING_DYNAMICS + "loss": MetricCategory.TRAINING_DYNAMICS, + "train_loss": MetricCategory.TRAINING_DYNAMICS, + "val_loss": MetricCategory.TRAINING_DYNAMICS, + "test_loss": MetricCategory.TRAINING_DYNAMICS, + "eval_loss": MetricCategory.TRAINING_DYNAMICS, + "nll_loss": MetricCategory.TRAINING_DYNAMICS, + "ce_loss": MetricCategory.TRAINING_DYNAMICS, + "cross_entropy": MetricCategory.TRAINING_DYNAMICS, + "mse_loss": MetricCategory.TRAINING_DYNAMICS, + "mae_loss": MetricCategory.TRAINING_DYNAMICS, + "perplexity": MetricCategory.TRAINING_DYNAMICS, + "ppl": MetricCategory.TRAINING_DYNAMICS, + "accuracy": MetricCategory.TRAINING_DYNAMICS, + "acc": MetricCategory.TRAINING_DYNAMICS, + "top1_acc": MetricCategory.TRAINING_DYNAMICS, + "top5_acc": MetricCategory.TRAINING_DYNAMICS, + "precision": MetricCategory.TRAINING_DYNAMICS, + "recall": MetricCategory.TRAINING_DYNAMICS, + "f1": MetricCategory.TRAINING_DYNAMICS, + "f1_score": MetricCategory.TRAINING_DYNAMICS, + "auc": MetricCategory.TRAINING_DYNAMICS, + "auroc": MetricCategory.TRAINING_DYNAMICS, + "bleu": MetricCategory.TRAINING_DYNAMICS, + "rouge": MetricCategory.TRAINING_DYNAMICS, + "lr": MetricCategory.TRAINING_DYNAMICS, + "learning_rate": MetricCategory.TRAINING_DYNAMICS, + "samples_per_sec": MetricCategory.TRAINING_DYNAMICS, 
+ "tokens_per_sec": MetricCategory.TRAINING_DYNAMICS, + "throughput": MetricCategory.TRAINING_DYNAMICS, + "steps_per_sec": MetricCategory.TRAINING_DYNAMICS, + + # GRADIENT_HEALTH + "grad_norm": MetricCategory.GRADIENT_HEALTH, + "gradient_norm": MetricCategory.GRADIENT_HEALTH, + "global_grad_norm": MetricCategory.GRADIENT_HEALTH, + "grad_norm_clipped": MetricCategory.GRADIENT_HEALTH, + "grad_clip_rate": MetricCategory.GRADIENT_HEALTH, + "grad_scale": MetricCategory.GRADIENT_HEALTH, + "grad_mean": MetricCategory.GRADIENT_HEALTH, + "grad_std": MetricCategory.GRADIENT_HEALTH, + "grad_max": MetricCategory.GRADIENT_HEALTH, + "grad_min": MetricCategory.GRADIENT_HEALTH, + "grad_sparsity": MetricCategory.GRADIENT_HEALTH, + "vanishing_grad": MetricCategory.GRADIENT_HEALTH, + "exploding_grad": MetricCategory.GRADIENT_HEALTH, + + # WEIGHT_DYNAMICS + "weight_norm": MetricCategory.WEIGHT_DYNAMICS, + "param_norm": MetricCategory.WEIGHT_DYNAMICS, + "weight_mean": MetricCategory.WEIGHT_DYNAMICS, + "weight_std": MetricCategory.WEIGHT_DYNAMICS, + "update_ratio": MetricCategory.WEIGHT_DYNAMICS, + "weight_update": MetricCategory.WEIGHT_DYNAMICS, + "dead_neurons": MetricCategory.WEIGHT_DYNAMICS, + "dead_neuron_pct": MetricCategory.WEIGHT_DYNAMICS, + "param_count": MetricCategory.WEIGHT_DYNAMICS, + "num_params": MetricCategory.WEIGHT_DYNAMICS, + "trainable_params": MetricCategory.WEIGHT_DYNAMICS, + + # ACTIVATION_FLOW + "activation_mean": MetricCategory.ACTIVATION_FLOW, + "activation_std": MetricCategory.ACTIVATION_FLOW, + "activation_norm": MetricCategory.ACTIVATION_FLOW, + "activation_max": MetricCategory.ACTIVATION_FLOW, + "saturation": MetricCategory.ACTIVATION_FLOW, + "saturation_pct": MetricCategory.ACTIVATION_FLOW, + "dead_relu": MetricCategory.ACTIVATION_FLOW, + "dead_relu_pct": MetricCategory.ACTIVATION_FLOW, + "activation_sparsity": MetricCategory.ACTIVATION_FLOW, + # Generic activation stats from layer hooks + "mean": MetricCategory.ACTIVATION_FLOW, + "std": 
MetricCategory.ACTIVATION_FLOW, + "min": MetricCategory.ACTIVATION_FLOW, + "max": MetricCategory.ACTIVATION_FLOW, + "sparsity": MetricCategory.ACTIVATION_FLOW, + "layer_idx": MetricCategory.SYSTEM, + + # ATTENTION_MECHANICS + "attention_entropy": MetricCategory.ATTENTION_MECHANICS, + "attn_entropy": MetricCategory.ATTENTION_MECHANICS, + "attention_sparsity": MetricCategory.ATTENTION_MECHANICS, + "head_importance": MetricCategory.ATTENTION_MECHANICS, + "attention_weight_norm": MetricCategory.ATTENTION_MECHANICS, + "position_bias": MetricCategory.ATTENTION_MECHANICS, + "attention_score_mean": MetricCategory.ATTENTION_MECHANICS, + "attention_score_std": MetricCategory.ATTENTION_MECHANICS, + + # MEMORY_COMPUTE + "gpu_memory": MetricCategory.MEMORY_COMPUTE, + "gpu_mem": MetricCategory.MEMORY_COMPUTE, + "gpu_memory_allocated": MetricCategory.MEMORY_COMPUTE, + "gpu_memory_cached": MetricCategory.MEMORY_COMPUTE, + "gpu_memory_peak": MetricCategory.MEMORY_COMPUTE, + "cpu_memory": MetricCategory.MEMORY_COMPUTE, + "memory_usage": MetricCategory.MEMORY_COMPUTE, + "mfu": MetricCategory.MEMORY_COMPUTE, + "model_flops_utilization": MetricCategory.MEMORY_COMPUTE, + "flops": MetricCategory.MEMORY_COMPUTE, + "tflops": MetricCategory.MEMORY_COMPUTE, + "gpu_utilization": MetricCategory.MEMORY_COMPUTE, + "gpu_util": MetricCategory.MEMORY_COMPUTE, + + # OPTIMIZATION_STATE + "adam_m_norm": MetricCategory.OPTIMIZATION_STATE, + "adam_v_norm": MetricCategory.OPTIMIZATION_STATE, + "momentum": MetricCategory.OPTIMIZATION_STATE, + "beta1": MetricCategory.OPTIMIZATION_STATE, + "beta2": MetricCategory.OPTIMIZATION_STATE, + "weight_decay": MetricCategory.OPTIMIZATION_STATE, + "effective_weight_decay": MetricCategory.OPTIMIZATION_STATE, + "warmup_progress": MetricCategory.OPTIMIZATION_STATE, + "lr_schedule_progress": MetricCategory.OPTIMIZATION_STATE, + + # CONVERGENCE_SIGNALS + "train_val_gap": MetricCategory.CONVERGENCE_SIGNALS, + "overfit_ratio": MetricCategory.CONVERGENCE_SIGNALS, + 
"loss_plateau": MetricCategory.CONVERGENCE_SIGNALS, + "gradient_noise_scale": MetricCategory.CONVERGENCE_SIGNALS, + "critical_batch_size": MetricCategory.CONVERGENCE_SIGNALS, + "effective_batch_size": MetricCategory.CONVERGENCE_SIGNALS, + "early_stop_score": MetricCategory.CONVERGENCE_SIGNALS, + "best_val_loss": MetricCategory.CONVERGENCE_SIGNALS, + "improvement_rate": MetricCategory.CONVERGENCE_SIGNALS, + + # DATA_PIPELINE + "data_time": MetricCategory.DATA_PIPELINE, + "batch_time": MetricCategory.DATA_PIPELINE, + "load_time": MetricCategory.DATA_PIPELINE, + "preprocessing_time": MetricCategory.DATA_PIPELINE, + "augmentation_time": MetricCategory.DATA_PIPELINE, + "queue_depth": MetricCategory.DATA_PIPELINE, + "prefetch_factor": MetricCategory.DATA_PIPELINE, + "num_workers": MetricCategory.DATA_PIPELINE, + + # REGULARIZATION + "dropout_rate": MetricCategory.REGULARIZATION, + "dropout": MetricCategory.REGULARIZATION, + "bn_mean": MetricCategory.REGULARIZATION, + "bn_var": MetricCategory.REGULARIZATION, + "bn_running_mean": MetricCategory.REGULARIZATION, + "bn_running_var": MetricCategory.REGULARIZATION, + "ln_mean": MetricCategory.REGULARIZATION, + "ln_var": MetricCategory.REGULARIZATION, + "l1_penalty": MetricCategory.REGULARIZATION, + "l2_penalty": MetricCategory.REGULARIZATION, + + # SYSTEM + "iter": MetricCategory.SYSTEM, + "iteration": MetricCategory.SYSTEM, + "step": MetricCategory.SYSTEM, + "total": MetricCategory.SYSTEM, + "epoch": MetricCategory.SYSTEM, + "batch": MetricCategory.SYSTEM, + "batch_idx": MetricCategory.SYSTEM, + "global_step": MetricCategory.SYSTEM, + "time": MetricCategory.SYSTEM, + "dt": MetricCategory.SYSTEM, + "elapsed": MetricCategory.SYSTEM, + "wall_time": MetricCategory.SYSTEM, + "timestamp": MetricCategory.SYSTEM, + "hooked_layers": MetricCategory.SYSTEM, + "input_tokens": MetricCategory.SYSTEM, + "predicted_class": MetricCategory.TRAINING_DYNAMICS, + + # MODEL INFO + "params": MetricCategory.WEIGHT_DYNAMICS, + "num_params": 
MetricCategory.WEIGHT_DYNAMICS, + "total_params": MetricCategory.WEIGHT_DYNAMICS, + "trainable_params": MetricCategory.WEIGHT_DYNAMICS, + "parameters": MetricCategory.WEIGHT_DYNAMICS, + "model_size": MetricCategory.WEIGHT_DYNAMICS, + + # INFERENCE METRICS + "confidence": MetricCategory.TRAINING_DYNAMICS, + "similarity": MetricCategory.TRAINING_DYNAMICS, + "score": MetricCategory.TRAINING_DYNAMICS, + "prob": MetricCategory.TRAINING_DYNAMICS, + "probability": MetricCategory.TRAINING_DYNAMICS, + "entropy": MetricCategory.ATTENTION_MECHANICS, + "latency": MetricCategory.MEMORY_COMPUTE, + "inference_time": MetricCategory.MEMORY_COMPUTE, + "input_len": MetricCategory.DATA_PIPELINE, + "output_len": MetricCategory.DATA_PIPELINE, + + # OBSERVATION SYSTEM METRICS + "hooked_modules": MetricCategory.SYSTEM, + "total_layers": MetricCategory.SYSTEM, + "sample_rate": MetricCategory.SYSTEM, + "layer_num": MetricCategory.SYSTEM, + "max_depth": MetricCategory.SYSTEM, + "return_code": MetricCategory.SYSTEM, + "pid": MetricCategory.SYSTEM, + "max_iterations": MetricCategory.SYSTEM, + "total_iterations": MetricCategory.SYSTEM, + "iterations": MetricCategory.SYSTEM, + + # GPU/VRAM + "vram_gb": MetricCategory.MEMORY_COMPUTE, + "gpu_count": MetricCategory.MEMORY_COMPUTE, + "gpu_memory_gb": MetricCategory.MEMORY_COMPUTE, +} + +# Patterns for dynamic metric name matching +METRIC_PATTERNS: List[Tuple[str, MetricCategory]] = [ + (r".*loss.*", MetricCategory.TRAINING_DYNAMICS), + (r".*acc.*", MetricCategory.TRAINING_DYNAMICS), + (r".*accuracy.*", MetricCategory.TRAINING_DYNAMICS), + (r".*perplexity.*", MetricCategory.TRAINING_DYNAMICS), + (r".*lr.*", MetricCategory.TRAINING_DYNAMICS), + (r".*learning_rate.*", MetricCategory.TRAINING_DYNAMICS), + (r".*grad.*norm.*", MetricCategory.GRADIENT_HEALTH), + (r".*gradient.*", MetricCategory.GRADIENT_HEALTH), + (r".*weight.*norm.*", MetricCategory.WEIGHT_DYNAMICS), + (r".*param.*norm.*", MetricCategory.WEIGHT_DYNAMICS), + (r".*activation.*", 
MetricCategory.ACTIVATION_FLOW), + (r".*attention.*", MetricCategory.ATTENTION_MECHANICS), + (r".*attn.*", MetricCategory.ATTENTION_MECHANICS), + (r".*memory.*", MetricCategory.MEMORY_COMPUTE), + (r".*gpu.*", MetricCategory.MEMORY_COMPUTE), + (r".*mfu.*", MetricCategory.MEMORY_COMPUTE), + (r".*adam.*", MetricCategory.OPTIMIZATION_STATE), + (r".*momentum.*", MetricCategory.OPTIMIZATION_STATE), + (r".*overfit.*", MetricCategory.CONVERGENCE_SIGNALS), + (r".*plateau.*", MetricCategory.CONVERGENCE_SIGNALS), + (r".*data.*time.*", MetricCategory.DATA_PIPELINE), + (r".*batch.*time.*", MetricCategory.DATA_PIPELINE), + (r".*dropout.*", MetricCategory.REGULARIZATION), + (r".*bn_.*", MetricCategory.REGULARIZATION), + (r".*ln_.*", MetricCategory.REGULARIZATION), + (r".*iter.*", MetricCategory.SYSTEM), + (r".*epoch.*", MetricCategory.SYSTEM), + (r".*step.*", MetricCategory.SYSTEM), + (r".*time.*", MetricCategory.SYSTEM), + (r".*_ms$", MetricCategory.SYSTEM), + (r".*duration.*", MetricCategory.SYSTEM), +] + + +def classify_metric(name: str) -> MetricCategory: + """Classify a metric name into its category.""" + name_lower = name.lower() + + # Direct lookup + if name_lower in METRIC_TAXONOMY: + return METRIC_TAXONOMY[name_lower] + + # Pattern matching + for pattern, category in METRIC_PATTERNS: + if re.match(pattern, name_lower): + return category + + return MetricCategory.UNKNOWN + + +# ============================================================================= +# METRIC HEALTH THRESHOLDS (Industry Standards) +# ============================================================================= + +@dataclass +class MetricHealthSpec: + """Specification for healthy metric ranges.""" + name: str + category: MetricCategory + healthy_min: Optional[float] = None + healthy_max: Optional[float] = None + critical_min: Optional[float] = None + critical_max: Optional[float] = None + expected_trend: Optional[str] = None # 'falling', 'rising', 'stable' + + def is_healthy(self, value: float) -> 
bool: + if self.healthy_min is not None and value < self.healthy_min: + return False + if self.healthy_max is not None and value > self.healthy_max: + return False + return True + + def is_critical(self, value: float) -> bool: + if self.critical_min is not None and value < self.critical_min: + return True + if self.critical_max is not None and value > self.critical_max: + return True + return False + + +# Industry-standard health thresholds +HEALTH_SPECS: Dict[str, MetricHealthSpec] = { + "loss": MetricHealthSpec( + name="loss", + category=MetricCategory.TRAINING_DYNAMICS, + healthy_max=10.0, + critical_max=100.0, + expected_trend="falling", + ), + "grad_norm": MetricHealthSpec( + name="grad_norm", + category=MetricCategory.GRADIENT_HEALTH, + healthy_min=1e-7, + healthy_max=10.0, + critical_min=1e-10, # Vanishing + critical_max=1000.0, # Exploding + ), + "lr": MetricHealthSpec( + name="lr", + category=MetricCategory.TRAINING_DYNAMICS, + healthy_min=1e-8, + healthy_max=1.0, + critical_max=10.0, + ), + "mfu": MetricHealthSpec( + name="mfu", + category=MetricCategory.MEMORY_COMPUTE, + healthy_min=0.1, # 10% utilization minimum + healthy_max=1.0, + ), + "dead_relu_pct": MetricHealthSpec( + name="dead_relu_pct", + category=MetricCategory.ACTIVATION_FLOW, + healthy_max=0.3, # 30% dead is concerning + critical_max=0.7, # 70% dead is critical + ), + "train_val_gap": MetricHealthSpec( + name="train_val_gap", + category=MetricCategory.CONVERGENCE_SIGNALS, + healthy_max=0.5, # Gap shouldn't exceed 50% of train loss + critical_max=2.0, # Severe overfitting + ), +} + + +@dataclass +class MetricSeries: + """A time series of a single metric with category awareness.""" + name: str + category: MetricCategory = field(default=MetricCategory.UNKNOWN) + values: List[float] = field(default_factory=list) + timestamps: List[float] = field(default_factory=list) + event_ids: List[str] = field(default_factory=list) + + def __post_init__(self): + if self.category == MetricCategory.UNKNOWN: + 
self.category = classify_metric(self.name) + + @property + def count(self) -> int: + return len(self.values) + + @property + def current(self) -> Optional[float]: + return self.values[-1] if self.values else None + + @property + def previous(self) -> Optional[float]: + return self.values[-2] if len(self.values) >= 2 else None + + @property + def delta(self) -> Optional[float]: + """Change from previous to current.""" + if len(self.values) >= 2: + return self.values[-1] - self.values[-2] + return None + + @property + def delta_pct(self) -> Optional[float]: + """Percentage change from previous to current.""" + if len(self.values) >= 2 and self.values[-2] != 0: + return (self.values[-1] - self.values[-2]) / abs(self.values[-2]) + return None + + @property + def mean(self) -> Optional[float]: + return sum(self.values) / len(self.values) if self.values else None + + @property + def std(self) -> Optional[float]: + if len(self.values) < 2: + return None + mean = self.mean + variance = sum((x - mean) ** 2 for x in self.values) / len(self.values) + return math.sqrt(variance) + + @property + def min(self) -> Optional[float]: + return min(self.values) if self.values else None + + @property + def max(self) -> Optional[float]: + return max(self.values) if self.values else None + + @property + def range(self) -> Optional[float]: + if self.values: + return self.max - self.min + return None + + def moving_average(self, window: int = 5) -> Optional[float]: + """Compute moving average over last N values.""" + if len(self.values) < window: + return self.mean + return sum(self.values[-window:]) / window + + def rate_of_change(self, window: int = 5) -> Optional[float]: + """Average rate of change over last N values.""" + if len(self.values) < 2: + return None + window = min(window, len(self.values)) + recent = self.values[-window:] + deltas = [recent[i] - recent[i-1] for i in range(1, len(recent))] + return sum(deltas) / len(deltas) if deltas else None + + def is_anomaly(self, 
threshold_std: float = 2.0) -> bool: + """Is current value anomalous (outside N standard deviations)?""" + if len(self.values) < 5 or self.std is None or self.std == 0: + return False + return abs(self.values[-1] - self.mean) > threshold_std * self.std + + def trend(self, window: int = 10) -> str: + """Determine trend: 'rising', 'falling', 'stable', 'volatile'.""" + if len(self.values) < 3: + return "unknown" + + window = min(window, len(self.values)) + recent = self.values[-window:] + deltas = [recent[i] - recent[i-1] for i in range(1, len(recent))] + + positive = sum(1 for d in deltas if d > 0) + negative = sum(1 for d in deltas if d < 0) + + if positive > 0.7 * len(deltas): + return "rising" + elif negative > 0.7 * len(deltas): + return "falling" + elif self.std and self.mean and self.std > 0.1 * abs(self.mean): + return "volatile" + else: + return "stable" + + def health_status(self) -> str: + """Check health against industry standards. Returns 'healthy', 'warning', 'critical', 'unknown'.""" + if self.current is None: + return "unknown" + + name_lower = self.name.lower() + if name_lower in HEALTH_SPECS: + spec = HEALTH_SPECS[name_lower] + if spec.is_critical(self.current): + return "critical" + if not spec.is_healthy(self.current): + return "warning" + return "healthy" + + # Default heuristics for unknown metrics + if self.is_anomaly(threshold_std=3.0): + return "critical" + if self.is_anomaly(threshold_std=2.0): + return "warning" + return "healthy" + + def to_dict(self) -> Dict[str, Any]: + return { + "name": self.name, + "category": self.category.name, + "count": self.count, + "current": self.current, + "delta": self.delta, + "delta_pct": self.delta_pct, + "mean": self.mean, + "std": self.std, + "min": self.min, + "max": self.max, + "trend": self.trend(), + "health": self.health_status(), + "is_anomaly": self.is_anomaly(), + "rate_of_change": self.rate_of_change(), + } + + +@dataclass +class Anomaly: + """A detected anomaly in the metric stream.""" + 
metric_name: str + category: MetricCategory + event_id: str + timestamp: float + value: float + expected_range: Tuple[float, float] # (low, high) + deviation_std: float + severity: str # 'minor', 'major', 'critical' + + +@dataclass +class Correlation: + """A detected correlation between two metrics.""" + metric_a: str + metric_b: str + category_a: MetricCategory + category_b: MetricCategory + coefficient: float # -1 to 1 + strength: str # 'weak', 'moderate', 'strong' + direction: str # 'positive', 'negative' + + +@dataclass +class ThresholdCrossing: + """A metric crossing a significant threshold.""" + metric_name: str + category: MetricCategory + event_id: str + timestamp: float + old_value: float + new_value: float + threshold: float + direction: str # 'above', 'below' + + +class MetricsEngine: + """ + Quantification engine for the event stream. + + Extracts numeric metrics from events, tracks them over time, + detects anomalies, correlations, and threshold crossings. + + Does NOT interpret or explain. Provides raw quantified data + for human or AI observers to divine meaning from. + + Example: + >>> engine = MetricsEngine(graph) + >>> engine.ingest(event) + >>> + >>> # Get metric statistics + >>> loss = engine.get_metric("loss") + >>> print(f"Loss: {loss.current} (delta: {loss.delta}, trend: {loss.trend()})") + >>> + >>> # Get anomalies + >>> for anomaly in engine.anomalies: + ... print(f"ANOMALY: {anomaly.metric_name} = {anomaly.value}") + >>> + >>> # Get correlations + >>> for corr in engine.get_correlations(): + ... 
print(f"{corr.metric_a} ~ {corr.metric_b}: {corr.coefficient:.2f}") + """ + + def __init__(self, graph: Optional[CausationGraph] = None): + self.graph = graph + self._metrics: Dict[str, MetricSeries] = {} + self._anomalies: List[Anomaly] = [] + self._threshold_crossings: List[ThresholdCrossing] = [] + self._event_count = 0 + + # Configurable thresholds + self.anomaly_std_threshold = 2.5 + self.correlation_min_samples = 10 + + # Known significant thresholds for ML metrics + self._known_thresholds = { + "loss": [0.1, 0.5, 1.0, 2.0, 5.0, 10.0], + "accuracy": [0.5, 0.8, 0.9, 0.95, 0.99], + "lr": [1e-5, 1e-4, 1e-3, 1e-2, 0.1], + "learning_rate": [1e-5, 1e-4, 1e-3, 1e-2, 0.1], + "grad_norm": [0.1, 1.0, 10.0, 100.0], + "gradient_norm": [0.1, 1.0, 10.0, 100.0], + } + + def ingest(self, event: Event) -> Dict[str, MetricSeries]: + """ + Ingest an event and extract/track all numeric metrics. + + Returns dict of updated metric series. + """ + self._event_count += 1 + updated = {} + + for key, value in event.data.items(): + if isinstance(value, (int, float)) and not isinstance(value, bool): + category = classify_metric(key) + + if math.isnan(value) or math.isinf(value): + # Track NaN/Inf as anomalies but don't add to series + self._anomalies.append(Anomaly( + metric_name=key, + category=category, + event_id=event.event_id, + timestamp=event.timestamp, + value=value, + expected_range=(0, 0), + deviation_std=float('inf'), + severity='critical', + )) + continue + + # Get or create metric series with proper category + if key not in self._metrics: + self._metrics[key] = MetricSeries(name=key, category=category) + + series = self._metrics[key] + old_value = series.current + + # Add new value + series.values.append(float(value)) + series.timestamps.append(event.timestamp) + series.event_ids.append(event.event_id) + + # Check for anomaly + if series.is_anomaly(self.anomaly_std_threshold): + deviation = abs(value - series.mean) / series.std if series.std else 0 + severity = 'critical' 
if deviation > 4 else 'major' if deviation > 3 else 'minor' + self._anomalies.append(Anomaly( + metric_name=key, + category=category, + event_id=event.event_id, + timestamp=event.timestamp, + value=value, + expected_range=( + series.mean - 2*series.std, + series.mean + 2*series.std + ), + deviation_std=deviation, + severity=severity, + )) + + # Check for threshold crossing + if old_value is not None: + self._check_threshold_crossing( + key, event.event_id, event.timestamp, old_value, value + ) + + updated[key] = series + + return updated + + def _check_threshold_crossing( + self, + metric: str, + event_id: str, + timestamp: float, + old_value: float, + new_value: float + ): + """Check if a metric crossed a known threshold.""" + thresholds = self._known_thresholds.get(metric, []) + category = classify_metric(metric) + + for threshold in thresholds: + # Crossed upward + if old_value < threshold <= new_value: + self._threshold_crossings.append(ThresholdCrossing( + metric_name=metric, + category=category, + event_id=event_id, + timestamp=timestamp, + old_value=old_value, + new_value=new_value, + threshold=threshold, + direction='above', + )) + # Crossed downward + elif old_value > threshold >= new_value: + self._threshold_crossings.append(ThresholdCrossing( + metric_name=metric, + category=category, + event_id=event_id, + timestamp=timestamp, + old_value=old_value, + new_value=new_value, + threshold=threshold, + direction='below', + )) + + def get_metric(self, name: str) -> Optional[MetricSeries]: + """Get a metric series by name.""" + return self._metrics.get(name) + + @property + def metrics(self) -> Dict[str, MetricSeries]: + """All tracked metrics.""" + return self._metrics + + @property + def metric_names(self) -> List[str]: + """Names of all tracked metrics.""" + return list(self._metrics.keys()) + + @property + def anomalies(self) -> List[Anomaly]: + """All detected anomalies.""" + return self._anomalies + + @property + def recent_anomalies(self) -> 
List[Anomaly]: + """Anomalies from last 10 events.""" + if not self._anomalies: + return [] + recent_ids = set() + for series in self._metrics.values(): + recent_ids.update(series.event_ids[-10:]) + return [a for a in self._anomalies if a.event_id in recent_ids] + + @property + def threshold_crossings(self) -> List[ThresholdCrossing]: + """All threshold crossings.""" + return self._threshold_crossings + + def get_correlations(self, min_coefficient: float = 0.5) -> List[Correlation]: + """ + Compute correlations between all metric pairs. + + Returns correlations with |coefficient| >= min_coefficient. + """ + correlations = [] + metric_names = list(self._metrics.keys()) + + for i, name_a in enumerate(metric_names): + series_a = self._metrics[name_a] + for name_b in metric_names[i+1:]: + series_b = self._metrics[name_b] + coef = self._pearson_correlation(name_a, name_b) + if coef is not None and abs(coef) >= min_coefficient: + strength = 'strong' if abs(coef) > 0.8 else 'moderate' if abs(coef) > 0.5 else 'weak' + direction = 'positive' if coef > 0 else 'negative' + correlations.append(Correlation( + metric_a=name_a, + metric_b=name_b, + category_a=series_a.category, + category_b=series_b.category, + coefficient=coef, + strength=strength, + direction=direction, + )) + + return sorted(correlations, key=lambda c: abs(c.coefficient), reverse=True) + + def _pearson_correlation(self, name_a: str, name_b: str) -> Optional[float]: + """Compute Pearson correlation between two metrics.""" + series_a = self._metrics.get(name_a) + series_b = self._metrics.get(name_b) + + if not series_a or not series_b: + return None + + # Need enough samples + if series_a.count < self.correlation_min_samples or series_b.count < self.correlation_min_samples: + return None + + # Align by taking min length + n = min(series_a.count, series_b.count) + a = series_a.values[-n:] + b = series_b.values[-n:] + + # Compute correlation + mean_a = sum(a) / n + mean_b = sum(b) / n + + numerator = sum((a[i] - 
mean_a) * (b[i] - mean_b) for i in range(n)) + + var_a = sum((x - mean_a) ** 2 for x in a) + var_b = sum((x - mean_b) ** 2 for x in b) + + denominator = math.sqrt(var_a * var_b) + + if denominator == 0: + return None + + return numerator / denominator + + def summary(self) -> Dict[str, Any]: + """Get a summary of all metrics and detections.""" + return { + "event_count": self._event_count, + "metric_count": len(self._metrics), + "metrics": {name: series.to_dict() for name, series in self._metrics.items()}, + "metrics_by_category": self.metrics_by_category_summary(), + "anomaly_count": len(self._anomalies), + "recent_anomalies": [ + {"metric": a.metric_name, "category": a.category.name, "value": a.value, "severity": a.severity} + for a in self.recent_anomalies + ], + "threshold_crossings": len(self._threshold_crossings), + "correlations": [ + {"a": c.metric_a, "b": c.metric_b, "r": c.coefficient, + "cat_a": c.category_a.name, "cat_b": c.category_b.name} + for c in self.get_correlations()[:5] # Top 5 + ], + "health_status": self.health_summary(), + } + + # ========================================================================= + # CATEGORY-AWARE QUERIES + # ========================================================================= + + def get_metrics_by_category(self, category: MetricCategory) -> Dict[str, MetricSeries]: + """Get all metrics in a specific category.""" + return { + name: series for name, series in self._metrics.items() + if series.category == category + } + + def metrics_by_category_summary(self) -> Dict[str, Dict[str, Any]]: + """Get metric count and names grouped by category.""" + by_cat: Dict[str, Dict[str, Any]] = {} + for name, series in self._metrics.items(): + cat_name = series.category.name + if cat_name not in by_cat: + by_cat[cat_name] = {"count": 0, "metrics": [], "health": []} + by_cat[cat_name]["count"] += 1 + by_cat[cat_name]["metrics"].append(name) + by_cat[cat_name]["health"].append(series.health_status()) + return by_cat + + def 
get_training_metrics(self) -> Dict[str, MetricSeries]: + """Convenience: get all TRAINING_DYNAMICS metrics.""" + return self.get_metrics_by_category(MetricCategory.TRAINING_DYNAMICS) + + def get_gradient_metrics(self) -> Dict[str, MetricSeries]: + """Convenience: get all GRADIENT_HEALTH metrics.""" + return self.get_metrics_by_category(MetricCategory.GRADIENT_HEALTH) + + def get_memory_metrics(self) -> Dict[str, MetricSeries]: + """Convenience: get all MEMORY_COMPUTE metrics.""" + return self.get_metrics_by_category(MetricCategory.MEMORY_COMPUTE) + + def get_convergence_metrics(self) -> Dict[str, MetricSeries]: + """Convenience: get all CONVERGENCE_SIGNALS metrics.""" + return self.get_metrics_by_category(MetricCategory.CONVERGENCE_SIGNALS) + + def health_summary(self) -> Dict[str, Any]: + """Get overall health status of all metrics.""" + statuses = {"healthy": 0, "warning": 0, "critical": 0, "unknown": 0} + issues = [] + + for name, series in self._metrics.items(): + status = series.health_status() + statuses[status] += 1 + if status in ("warning", "critical"): + issues.append({ + "metric": name, + "category": series.category.name, + "status": status, + "value": series.current, + "trend": series.trend(), + }) + + overall = "critical" if statuses["critical"] > 0 else \ + "warning" if statuses["warning"] > 0 else "healthy" + + return { + "overall": overall, + "counts": statuses, + "issues": issues, + } + + def get_cross_category_correlations(self) -> List[Correlation]: + """Get correlations between metrics in different categories.""" + all_corr = self.get_correlations(min_coefficient=0.3) + return [c for c in all_corr if c.category_a != c.category_b] + + def get_category_coverage(self) -> Dict[str, bool]: + """Check which metric categories are being tracked.""" + tracked = {series.category for series in self._metrics.values()} + return {cat.name: cat in tracked for cat in MetricCategory} + + # ========================================================================= 
+ # TRIAGE SYSTEM - Common Sense Diagnostics (Occam's Razor) + # ========================================================================= + # + # Five questions that matter: + # 1. Is training working? (loss trend) + # 2. Is it about to explode? (gradient health) + # 3. Am I wasting compute? (efficiency) + # 4. Am I overfitting? (generalization gap) + # 5. What broke and why? (anomaly + correlation) + # + + def triage(self) -> Dict[str, Any]: + """ + Quick diagnostic: Is training healthy? What's wrong? + + Returns a simple, actionable assessment. + Occam's Razor: simplest useful answer. + """ + diagnosis = { + "status": "LISTENING", # Not UNKNOWN - we're actively waiting + "confidence": 0.0, + "checks": {}, + "action": "Collecting initial metrics...", + "details": [], + } + + checks_passed = 0 + checks_total = 0 + + # CHECK 1: Is loss going down? + loss_check = self._check_loss_progress() + diagnosis["checks"]["loss_progress"] = loss_check + checks_total += 1 + if loss_check["ok"]: + checks_passed += 1 + + # CHECK 2: Are gradients healthy? + grad_check = self._check_gradient_health() + diagnosis["checks"]["gradient_health"] = grad_check + checks_total += 1 + if grad_check["ok"]: + checks_passed += 1 + + # CHECK 3: Am I using compute efficiently? + efficiency_check = self._check_efficiency() + diagnosis["checks"]["efficiency"] = efficiency_check + checks_total += 1 + if efficiency_check["ok"]: + checks_passed += 1 + + # CHECK 4: Am I overfitting? + overfit_check = self._check_overfitting() + diagnosis["checks"]["overfitting"] = overfit_check + checks_total += 1 + if overfit_check["ok"]: + checks_passed += 1 + + # CHECK 5: Any anomalies pointing to root cause? 
+ anomaly_check = self._check_anomalies() + diagnosis["checks"]["anomalies"] = anomaly_check + checks_total += 1 + if anomaly_check["ok"]: + checks_passed += 1 + + # Overall status + diagnosis["confidence"] = checks_passed / checks_total if checks_total > 0 else 0 + + if checks_passed == checks_total: + diagnosis["status"] = "HEALTHY" + diagnosis["action"] = "Training looks good. Continue monitoring." + elif checks_passed >= checks_total * 0.6: + diagnosis["status"] = "WARNING" + # Find what's wrong + issues = [k for k, v in diagnosis["checks"].items() if not v["ok"]] + diagnosis["action"] = f"Review: {', '.join(issues)}" + else: + diagnosis["status"] = "CRITICAL" + diagnosis["action"] = "Stop and investigate. Multiple issues detected." + + # Collect all details + for check_name, check_result in diagnosis["checks"].items(): + if check_result.get("detail"): + diagnosis["details"].append(f"{check_name}: {check_result['detail']}") + + return diagnosis + + def _check_loss_progress(self) -> Dict[str, Any]: + """Is loss decreasing as expected?""" + # Find loss metric (try common names) + loss_series = None + for name in ["loss", "train_loss", "nll_loss", "ce_loss"]: + if name in self._metrics: + loss_series = self._metrics[name] + break + + if loss_series is None or loss_series.count < 3: + return {"ok": True, "detail": "Waiting for loss metrics (need 3+)", "status": "waiting"} + + trend = loss_series.trend() + roc = loss_series.rate_of_change() + + if trend == "falling": + return {"ok": True, "detail": f"Loss falling (Δ={roc:.4f}/step)", "status": "good"} + elif trend == "stable" and loss_series.current < 1.0: + return {"ok": True, "detail": f"Loss stable at {loss_series.current:.4f}", "status": "converged"} + elif trend == "rising": + return {"ok": False, "detail": f"Loss RISING! 
Current: {loss_series.current:.4f}", "status": "diverging"} + elif trend == "volatile": + return {"ok": False, "detail": f"Loss unstable (std={loss_series.std:.4f})", "status": "unstable"} + else: + return {"ok": True, "detail": f"Loss: {loss_series.current:.4f} (trend unclear)", "status": "stable"} + + def _check_gradient_health(self) -> Dict[str, Any]: + """Are gradients in a healthy range?""" + grad_series = None + for name in ["grad_norm", "gradient_norm", "global_grad_norm"]: + if name in self._metrics: + grad_series = self._metrics[name] + break + + if grad_series is None or grad_series.count < 2: + return {"ok": True, "detail": "Waiting for grad_norm metrics", "status": "waiting"} + + current = grad_series.current + + # Vanishing gradients + if current < 1e-7: + return {"ok": False, "detail": f"VANISHING gradients: {current:.2e}", "status": "vanishing"} + + # Exploding gradients + if current > 100: + return {"ok": False, "detail": f"EXPLODING gradients: {current:.2f}", "status": "exploding"} + + # Healthy range + if 1e-5 < current < 10: + return {"ok": True, "detail": f"Gradients healthy: {current:.4f}", "status": "healthy"} + + # Warning zone + return {"ok": True, "detail": f"Gradients marginal: {current:.4f}", "status": "marginal"} + + def _check_efficiency(self) -> Dict[str, Any]: + """Am I using compute efficiently?""" + # Check MFU (Model FLOP Utilization) + mfu_series = self._metrics.get("mfu") + if mfu_series and mfu_series.count > 0: + mfu = mfu_series.current + if mfu < 0.1: + return {"ok": False, "detail": f"Low GPU utilization: {mfu*100:.1f}%", "status": "inefficient"} + elif mfu < 0.3: + return {"ok": True, "detail": f"Moderate efficiency: {mfu*100:.1f}%", "status": "moderate"} + else: + return {"ok": True, "detail": f"Good efficiency: {mfu*100:.1f}%", "status": "efficient"} + + # Fallback: check timing + time_series = self._metrics.get("dt") or self._metrics.get("time") or self._metrics.get("batch_time") + if time_series and time_series.count > 
2: + trend = time_series.trend() + if trend == "rising": + return {"ok": False, "detail": "Step time increasing (slowdown)", "status": "degrading"} + return {"ok": True, "detail": f"Step time: {time_series.current:.3f}s", "status": "stable"} + + return {"ok": True, "detail": "Need mfu or dt/time metrics", "status": "waiting"} + + def _check_overfitting(self) -> Dict[str, Any]: + """Is model overfitting?""" + train_loss = None + val_loss = None + + # Find train and val loss + for name in ["loss", "train_loss"]: + if name in self._metrics: + train_loss = self._metrics[name] + break + + for name in ["val_loss", "eval_loss", "test_loss"]: + if name in self._metrics: + val_loss = self._metrics[name] + break + + if train_loss is None or val_loss is None: + return {"ok": True, "detail": "Need train_loss + val_loss to check", "status": "waiting"} + + if train_loss.count < 3 or val_loss.count < 3: + return {"ok": True, "detail": f"Collecting ({train_loss.count}/3 train, {val_loss.count}/3 val)", "status": "waiting"} + + gap = val_loss.current - train_loss.current + gap_pct = gap / train_loss.current if train_loss.current > 0 else 0 + + # Check if gap is widening + train_trend = train_loss.trend() + val_trend = val_loss.trend() + + if train_trend == "falling" and val_trend == "rising": + return {"ok": False, "detail": f"OVERFITTING: train↓ val↑ (gap={gap:.4f})", "status": "overfitting"} + + if gap_pct > 0.5: # Val loss 50% higher than train + return {"ok": False, "detail": f"Large generalization gap: {gap_pct*100:.1f}%", "status": "high_gap"} + + if gap_pct > 0.2: + return {"ok": True, "detail": f"Moderate gap: {gap_pct*100:.1f}%", "status": "moderate_gap"} + + return {"ok": True, "detail": f"Good generalization (gap={gap:.4f})", "status": "healthy"} + + def _check_anomalies(self) -> Dict[str, Any]: + """Any recent anomalies that need attention?""" + recent = self.recent_anomalies + + if not recent: + return {"ok": True, "detail": "No anomalies", "status": "clean"} + + 
critical = [a for a in recent if a.severity == "critical"] + major = [a for a in recent if a.severity == "major"] + + if critical: + names = list(set(a.metric_name for a in critical)) + return {"ok": False, "detail": f"CRITICAL anomalies in: {', '.join(names)}", "status": "critical"} + + if major: + names = list(set(a.metric_name for a in major)) + return {"ok": False, "detail": f"Major anomalies in: {', '.join(names)}", "status": "major"} + + return {"ok": True, "detail": f"{len(recent)} minor anomalies", "status": "minor"} + + def quick_status(self) -> str: + """One-line status for dashboards.""" + t = self.triage() + return f"[{t['status']}] {t['action']} (confidence: {t['confidence']*100:.0f}%)" + + def __repr__(self) -> str: + return f"" diff --git a/cascade/analysis/tracer.py b/cascade/analysis/tracer.py new file mode 100644 index 0000000000000000000000000000000000000000..5842fb4bc74d5434e310d0093520eab33d71617e --- /dev/null +++ b/cascade/analysis/tracer.py @@ -0,0 +1,487 @@ +""" +Cascade Analysis - Bidirectional Causation Tracer. + +Trace cause-effect chains forwards and backwards through time. +Find root causes. Predict cascading effects. 
+""" + +from typing import List, Dict, Any, Optional, Set +from collections import deque +from dataclasses import dataclass, field + +from cascade.core.event import Event, CausationLink, CausationChain +from cascade.core.graph import CausationGraph + + +@dataclass +class RootCauseAnalysis: + """Results of a root cause analysis.""" + target_event: Event + root_causes: List[Event] + chains: List[CausationChain] + deepest_depth: int = 0 + narrative: str = "" + + +@dataclass +class ImpactAnalysis: + """Results of an impact/forward analysis.""" + source_event: Event + effects: List[Event] + chains: List[CausationChain] + total_impact_count: int = 0 + severity_score: float = 0.0 + narrative: str = "" + + +@dataclass +class CascadePrediction: + """Prediction of likely cascade from an event.""" + source_event: Event + predicted_effects: List[Dict[str, Any]] # [{event_type, probability, time_estimate}, ...] + risk_score: float = 0.0 + intervention_points: List[str] = field(default_factory=list) + narrative: str = "" + + +class Tracer: + """ + Bidirectional causation tracer. + + Traces cause-effect chains through the causation graph: + - Backwards: "What caused this?" → find root causes + - Forwards: "What will this cause?" → predict cascades + + Example: + >>> tracer = Tracer(graph) + >>> + >>> # What caused this gradient explosion? + >>> causes = tracer.trace_backwards("evt_123") + >>> + >>> # What will this learning rate change cause? + >>> effects = tracer.trace_forwards("evt_456") + >>> + >>> # Deep root cause analysis + >>> roots = tracer.find_root_causes("evt_789") + """ + + def __init__(self, graph: CausationGraph): + """ + Initialize tracer with a causation graph. + + Args: + graph: The causation graph to trace through + """ + self.graph = graph + self._prediction_model = None # Future: ML model for predictions + + def trace_backwards(self, event_id: str, max_depth: int = 1000) -> List[CausationChain]: + """ + Trace causation backwards: what caused this event? 
+ + Args: + event_id: ID of the event to trace from + max_depth: Maximum depth to trace (default: 1000 - effectively unlimited) + + Returns: + List of CausationChain objects, one per causal path found + """ + target = self.graph.get_event(event_id) + if not target: + return [] + + chains = [] + self._trace_backwards_recursive(event_id, [], [], max_depth, chains) + + # Sort by depth (longest chain first for root cause analysis) + chains.sort(key=lambda c: c.depth, reverse=True) + return chains + + def _trace_backwards_recursive( + self, + current_id: str, + current_events: List[Event], + current_links: List[CausationLink], + depth_remaining: int, + results: List[CausationChain], + visited: Optional[Set[str]] = None + ) -> None: + """Recursive helper for backwards tracing.""" + if visited is None: + visited = set() + + if current_id in visited: + return # Avoid cycles + visited.add(current_id) + + current_event = self.graph.get_event(current_id) + if not current_event: + return + + current_events = [current_event] + current_events + + if depth_remaining <= 0: + # Max depth reached, record this chain + if len(current_events) > 1: + results.append(self._build_chain(current_events, current_links)) + return + + causes = self.graph.get_causes(current_id) + + if not causes: + # This is a root - record the chain + if len(current_events) >= 1: + results.append(self._build_chain(current_events, current_links)) + return + + for cause in causes: + link = self.graph.get_link(cause.event_id, current_id) + new_links = [link] + current_links if link else current_links + + self._trace_backwards_recursive( + cause.event_id, + current_events, + new_links, + depth_remaining - 1, + results, + visited.copy() + ) + + def trace_forwards(self, event_id: str, max_depth: int = 1000) -> List[CausationChain]: + """ + Trace causation forwards: what will this event cause? 
+ + Args: + event_id: ID of the event to trace from + max_depth: Maximum depth to trace (default: 1000 - effectively unlimited) + + Returns: + List of CausationChain objects, one per effect path found + """ + source = self.graph.get_event(event_id) + if not source: + return [] + + chains = [] + self._trace_forwards_recursive(event_id, [], [], max_depth, chains) + + # Sort by depth + chains.sort(key=lambda c: c.depth, reverse=True) + return chains + + def _trace_forwards_recursive( + self, + current_id: str, + current_events: List[Event], + current_links: List[CausationLink], + depth_remaining: int, + results: List[CausationChain], + visited: Optional[Set[str]] = None + ) -> None: + """Recursive helper for forwards tracing.""" + if visited is None: + visited = set() + + if current_id in visited: + return + visited.add(current_id) + + current_event = self.graph.get_event(current_id) + if not current_event: + return + + current_events = current_events + [current_event] + + if depth_remaining <= 0: + if len(current_events) > 1: + results.append(self._build_chain(current_events, current_links)) + return + + effects = self.graph.get_effects(current_id) + + if not effects: + # This is a leaf - record the chain + if len(current_events) >= 1: + results.append(self._build_chain(current_events, current_links)) + return + + for effect in effects: + link = self.graph.get_link(current_id, effect.event_id) + new_links = current_links + [link] if link else current_links + + self._trace_forwards_recursive( + effect.event_id, + current_events, + new_links, + depth_remaining - 1, + results, + visited.copy() + ) + + def find_root_causes(self, event_id: str, max_depth: int = 1000) -> RootCauseAnalysis: + """ + Deep root cause analysis: find the ultimate origins. + + Traces all the way back to find events with no causes. 
+ + Args: + event_id: ID of the event to analyze + max_depth: Maximum depth to search (default: 1000 - effectively unlimited) + + Returns: + RootCauseAnalysis with root causes and narrative + """ + target = self.graph.get_event(event_id) + if not target: + return RootCauseAnalysis( + target_event=None, + root_causes=[], + chains=[], + ) + + chains = self.trace_backwards(event_id, max_depth) + + # Extract root causes (events at the start of chains) + root_causes = [] + seen = set() + for chain in chains: + if chain.events: + root = chain.events[0] + if root.event_id not in seen: + root_causes.append(root) + seen.add(root.event_id) + + # Build narrative + narrative = self._build_root_cause_narrative(target, root_causes, chains) + + return RootCauseAnalysis( + target_event=target, + root_causes=root_causes, + chains=chains, + deepest_depth=max(c.depth for c in chains) if chains else 0, + narrative=narrative, + ) + + def analyze_impact(self, event_id: str, max_depth: int = 1000) -> ImpactAnalysis: + """ + Impact analysis: what were ALL downstream effects? + + Traces forward to find everything this event set in motion. 
+ + Args: + event_id: ID of the event to analyze + max_depth: Maximum depth to search (default: 1000 - effectively unlimited) + + Returns: + ImpactAnalysis with effects and severity score + """ + source = self.graph.get_event(event_id) + if not source: + return ImpactAnalysis( + source_event=None, + effects=[], + chains=[], + ) + + chains = self.trace_forwards(event_id, max_depth) + + # Extract all effects + effects = [] + seen = set() + for chain in chains: + for event in chain.events[1:]: # Skip source + if event.event_id not in seen: + effects.append(event) + seen.add(event.event_id) + + # Calculate severity + severity = self._calculate_impact_severity(source, effects) + + # Build narrative + narrative = self._build_impact_narrative(source, effects, chains) + + return ImpactAnalysis( + source_event=source, + effects=effects, + chains=chains, + total_impact_count=len(effects), + severity_score=severity, + narrative=narrative, + ) + + def predict_cascade(self, event_id: str) -> CascadePrediction: + """ + Predict likely cascade from this event. + + Uses learned patterns to forecast effects BEFORE they happen. + This is the "Minority Report" capability. 
+ + Args: + event_id: ID of the event to predict from + + Returns: + CascadePrediction with risk scores and intervention points + """ + source = self.graph.get_event(event_id) + if not source: + return CascadePrediction( + source_event=None, + predicted_effects=[], + ) + + # Get historical patterns for this event type + similar_events = self.graph.get_events_by_type(source.event_type) + + # Count what typically follows - use all available history for better predictions + # No artificial cap - system learns from full history + effect_counts: Dict[str, int] = {} + analysis_window = similar_events # Full history, no slice + for similar in analysis_window: + effects = self.graph.get_effects(similar.event_id) + for effect in effects: + key = effect.event_type + effect_counts[key] = effect_counts.get(key, 0) + 1 + + # Convert to predictions + total = len(analysis_window) + predictions = [] + for event_type, count in sorted(effect_counts.items(), key=lambda x: -x[1]): + predictions.append({ + "event_type": event_type, + "probability": count / total if total > 0 else 0, + "historical_count": count, + }) + + # Calculate risk score + risk_score = self._calculate_risk_score(source, predictions) + + # Identify intervention points + intervention_points = self._find_intervention_points(source, predictions) + + return CascadePrediction( + source_event=source, + predicted_effects=predictions[:10], # Top 10 + risk_score=risk_score, + intervention_points=intervention_points, + narrative=f"Based on {total} similar events, predicting {len(predictions)} likely effects.", + ) + + def _build_chain(self, events: List[Event], links: List[CausationLink]) -> CausationChain: + """Build a CausationChain from events and links.""" + total_strength = 1.0 + for link in links: + total_strength *= link.strength + + return CausationChain( + events=events, + links=links, + total_strength=total_strength, + depth=len(links), + ) + + def _build_root_cause_narrative( + self, + target: Event, + roots: 
List[Event], + chains: List[CausationChain] + ) -> str: + """Build human-readable narrative for root cause analysis.""" + if not roots: + return f"No root causes found for {target.event_type}" + + lines = [f"Root cause analysis for {target.event_type}:"] + lines.append(f"Found {len(roots)} root cause(s) across {len(chains)} causal chain(s).") + lines.append("") + + for i, root in enumerate(roots[:5], 1): # Top 5 + lines.append(f"{i}. {root.component}/{root.event_type}") + if root.data: + key_data = list(root.data.items())[:3] + lines.append(f" Data: {dict(key_data)}") + + return "\n".join(lines) + + def _build_impact_narrative( + self, + source: Event, + effects: List[Event], + chains: List[CausationChain] + ) -> str: + """Build human-readable narrative for impact analysis.""" + if not effects: + return f"No downstream effects found for {source.event_type}" + + lines = [f"Impact analysis for {source.event_type}:"] + lines.append(f"Found {len(effects)} downstream effect(s).") + lines.append("") + + # Group by event type + by_type: Dict[str, int] = {} + for effect in effects: + by_type[effect.event_type] = by_type.get(effect.event_type, 0) + 1 + + for event_type, count in sorted(by_type.items(), key=lambda x: -x[1]): + lines.append(f" • {event_type}: {count} occurrence(s)") + + return "\n".join(lines) + + def _calculate_impact_severity(self, source: Event, effects: List[Event]) -> float: + """Calculate severity score for an impact (0.0 to 1.0).""" + if not effects: + return 0.0 + + # Factors: number of effects, types of effects + count_score = min(1.0, len(effects) / 20) # 20+ effects = max + + # High-severity event types + severe_types = {'error', 'anomaly', 'crash', 'failure', 'explosion'} + severe_count = sum(1 for e in effects if e.event_type in severe_types) + severity_score = min(1.0, severe_count / 5) + + return (count_score + severity_score) / 2 + + def _calculate_risk_score( + self, + source: Event, + predictions: List[Dict[str, Any]] + ) -> float: + 
"""Calculate risk score for a cascade prediction.""" + if not predictions: + return 0.0 + + # High-risk event types + risky_types = {'error', 'anomaly', 'crash', 'failure', 'explosion', 'nan', 'overflow'} + + risk = 0.0 + for pred in predictions: + if pred["event_type"] in risky_types: + risk += pred["probability"] * 2 # Double weight for risky + else: + risk += pred["probability"] * 0.5 + + return min(1.0, risk) + + def _find_intervention_points( + self, + source: Event, + predictions: List[Dict[str, Any]] + ) -> List[str]: + """Identify points where intervention could prevent bad cascades.""" + points = [] + + # Look at source event data for intervention hints + if 'learning_rate' in source.data: + points.append("Reduce learning rate") + if 'gradient' in source.event_type.lower(): + points.append("Apply gradient clipping") + if source.data.get('loss', 0) > 10: + points.append("Check loss function / data") + + # Check predictions for severe outcomes + for pred in predictions: + if pred["event_type"] == "nan" and pred["probability"] > 0.3: + points.append("Enable NaN detection early stopping") + if pred["event_type"] == "overflow" and pred["probability"] > 0.3: + points.append("Apply gradient scaling") + + return points diff --git a/cascade/bridge.py b/cascade/bridge.py new file mode 100644 index 0000000000000000000000000000000000000000..3302921aa8b1eb67746ed2346701a1562b262b86 --- /dev/null +++ b/cascade/bridge.py @@ -0,0 +1,265 @@ +""" +HuggingFace → IPFS Bridge + +Makes every CASCADE instance a node in the IPFS network. +Serves lattice content to DHT without running a full daemon. + +Uses js-ipfs HTTP API compatible endpoints via ipfs-http-client. +For HF Spaces, we use Helia (browser/Node IPFS) style serving. 
+""" + +import json +import hashlib +from pathlib import Path +from typing import Optional, Dict, Any +import threading +import time + +# Optional: for full IPFS integration +try: + import ipfshttpclient + HAS_IPFS_CLIENT = True +except ImportError: + HAS_IPFS_CLIENT = False + +from cascade.ipld import chain_to_ipld, chain_to_cid, encode_to_dag_cbor + + +class LatticeServer: + """ + Serves lattice content over IPFS-compatible protocols. + + Can run in multiple modes: + 1. Gateway mode: HTTP endpoints that mirror IPFS gateway API + 2. DHT mode: Announce content to IPFS DHT (needs daemon) + 3. Hybrid: Both + """ + + def __init__(self, lattice_dir: Path = None): + if lattice_dir is None: + # Try relative to this file first, then cwd + candidate = Path(__file__).resolve().parent.parent / "lattice" + if not candidate.exists(): + candidate = Path.cwd() / "lattice" + self.lattice_dir = candidate + else: + self.lattice_dir = lattice_dir + self.ipld_dir = self.lattice_dir / "ipld" + self._index: Dict[str, Path] = {} # CID -> file path + self._build_index() + + def _build_index(self): + """Index all known CIDs to their local files.""" + # Index CBOR files + if self.ipld_dir.exists(): + for cbor_file in self.ipld_dir.glob("*.cbor"): + ipld_json = cbor_file.with_suffix(".ipld.json") + if ipld_json.exists(): + meta = json.loads(ipld_json.read_text()) + # Try both 'cid' and '_cid' keys + cid = meta.get("cid") or meta.get("_cid") + if cid: + self._index[cid] = cbor_file + + # Index JSON chain files (compute CID on the fly) + for json_file in self.lattice_dir.glob("*.json"): + if json_file.name == "README.md": + continue + try: + chain_data = json.loads(json_file.read_text()) + cid = chain_to_cid(chain_data) + self._index[cid] = json_file + except: + pass + + print(f"Indexed {len(self._index)} CIDs") + + def resolve(self, cid: str) -> Optional[bytes]: + """Resolve a CID to its content.""" + if cid in self._index: + filepath = self._index[cid] + if filepath.suffix == ".cbor": + 
return filepath.read_bytes() + else: + # JSON file - return as CBOR for consistency + chain_data = json.loads(filepath.read_text()) + ipld_data = chain_to_ipld(chain_data) + return encode_to_dag_cbor(ipld_data) + return None + + def list_cids(self) -> list: + """List all available CIDs.""" + return list(self._index.keys()) + + def get_gateway_response(self, cid: str) -> tuple: + """ + Return (content, content_type, status_code) for gateway-style serving. + """ + content = self.resolve(cid) + if content: + return (content, "application/cbor", 200) + return (b"CID not found", "text/plain", 404) + + def announce_to_dht(self, ipfs_api: str = "/ip4/127.0.0.1/tcp/5001"): + """ + Announce all CIDs to IPFS DHT. + Requires running IPFS daemon. + """ + if not HAS_IPFS_CLIENT: + print("ipfshttpclient not installed. Run: pip install ipfshttpclient") + return + + try: + client = ipfshttpclient.connect(ipfs_api) + except Exception as e: + print(f"Could not connect to IPFS daemon: {e}") + print("Start daemon with: ipfs daemon") + return + + for cid, filepath in self._index.items(): + try: + # Add file to local IPFS node + if filepath.suffix == ".cbor": + result = client.add(str(filepath)) + print(f"Announced {filepath.name}: {result['Hash']}") + except Exception as e: + print(f"Failed to announce {cid}: {e}") + + def start_gateway(self, host: str = "0.0.0.0", port: int = 8080): + """ + Start a simple HTTP gateway for serving lattice content. 
+ + Compatible with IPFS gateway URL format: + GET /ipfs/{cid} + """ + from http.server import HTTPServer, BaseHTTPRequestHandler + + server = self + + class GatewayHandler(BaseHTTPRequestHandler): + def do_GET(self): + # Parse /ipfs/{cid} or just /{cid} + path = self.path.strip("/") + if path.startswith("ipfs/"): + cid = path[5:] + else: + cid = path + + content, content_type, status = server.get_gateway_response(cid) + + self.send_response(status) + self.send_header("Content-Type", content_type) + self.send_header("Content-Length", len(content)) + self.send_header("Access-Control-Allow-Origin", "*") + self.end_headers() + self.wfile.write(content) + + def do_HEAD(self): + path = self.path.strip("/") + if path.startswith("ipfs/"): + cid = path[5:] + else: + cid = path + + _, content_type, status = server.get_gateway_response(cid) + + self.send_response(status) + self.send_header("Content-Type", content_type) + self.send_header("Access-Control-Allow-Origin", "*") + self.end_headers() + + def log_message(self, format, *args): + print(f"[Gateway] {args[0]}") + + httpd = HTTPServer((host, port), GatewayHandler) + print(f"Lattice gateway running at http://{host}:{port}") + print(f"Serving {len(self._index)} CIDs") + print(f"\nTry: http://localhost:{port}/ipfs/bafyreidixjlzdat7ex72foi6vm3vnskhzguovxj6ondbazrqks7v6ahmei") + httpd.serve_forever() + + +def create_gradio_gateway(): + """ + Create a Gradio interface that serves as IPFS gateway. + Suitable for HuggingFace Spaces deployment. + """ + try: + import gradio as gr + except ImportError: + print("Gradio not installed. 
Run: pip install gradio") + return None + + server = LatticeServer() + + def resolve_cid(cid: str) -> str: + """Resolve CID and return content as hex + JSON decode attempt.""" + content = server.resolve(cid.strip()) + if content is None: + return f"❌ CID not found: {cid}\n\nAvailable CIDs:\n" + "\n".join(server.list_cids()) + + # Try to decode as CBOR → JSON for display + try: + import dag_cbor + decoded = dag_cbor.decode(content) + return f"✓ Found! ({len(content)} bytes)\n\n{json.dumps(decoded, indent=2, default=str)}" + except: + return f"✓ Found! ({len(content)} bytes)\n\nRaw hex: {content.hex()[:200]}..." + + def list_all() -> str: + """List all available CIDs.""" + cids = server.list_cids() + lines = [f"=== Lattice Index ({len(cids)} chains) ===\n"] + for cid in cids: + filepath = server._index[cid] + lines.append(f"• {filepath.stem}") + lines.append(f" {cid}\n") + return "\n".join(lines) + + with gr.Blocks(title="CASCADE Lattice Gateway") as app: + gr.Markdown("# 🌐 CASCADE Lattice Gateway") + gr.Markdown("*The neural internetwork, content-addressed.*") + + with gr.Tab("Resolve CID"): + cid_input = gr.Textbox( + label="CID", + placeholder="bafyrei...", + value="bafyreidixjlzdat7ex72foi6vm3vnskhzguovxj6ondbazrqks7v6ahmei" + ) + resolve_btn = gr.Button("Resolve") + output = gr.Textbox(label="Content", lines=20) + resolve_btn.click(resolve_cid, inputs=cid_input, outputs=output) + + with gr.Tab("Browse Lattice"): + list_btn = gr.Button("List All CIDs") + list_output = gr.Textbox(label="Available Chains", lines=20) + list_btn.click(list_all, outputs=list_output) + + gr.Markdown(""" + --- + **What is this?** + + This gateway serves the CASCADE lattice — a cryptographic provenance network for AI agents. + + Every chain has a CID (Content IDentifier). Same content = same CID. Forever. 
+ + - **Genesis**: `bafyreidixjlzdat7ex72foi6vm3vnskhzguovxj6ondbazrqks7v6ahmei` + - Protocol: [IPLD](https://ipld.io/) (InterPlanetary Linked Data) + """) + + return app + + +if __name__ == "__main__": + import sys + + if "--gradio" in sys.argv: + app = create_gradio_gateway() + if app: + app.launch() + elif "--announce" in sys.argv: + server = LatticeServer() + server.announce_to_dht() + else: + # Default: run HTTP gateway + server = LatticeServer() + server.start_gateway(port=8080) diff --git a/cascade/cli_main.py b/cascade/cli_main.py new file mode 100644 index 0000000000000000000000000000000000000000..81ed14d75e260c20a7dfaa6775a6984c27a66846 --- /dev/null +++ b/cascade/cli_main.py @@ -0,0 +1,851 @@ +""" +CASCADE CLI - Full-featured Rich TUI for cascade-ai. + +Exposes all CASCADE capabilities: +- Lattice: stats, list, inspect, chains, pin, export, watch +- Model: observe, fingerprint +- Data: entities, provenance, pii scan +- System: logs, analyze, ingest +- Proxy: start intercepting proxy +""" + +import argparse +import sys +import json +from pathlib import Path +from datetime import datetime + +# Rich imports with fallback +try: + from rich.console import Console + from rich.table import Table + from rich.panel import Panel + from rich.tree import Tree + from rich.progress import Progress, SpinnerColumn, TextColumn + from rich.text import Text + from rich.markdown import Markdown + from rich.syntax import Syntax + from rich import box + HAS_RICH = True +except ImportError: + HAS_RICH = False + +console = Console() if HAS_RICH else None + + +# ═══════════════════════════════════════════════════════════════════════════════ +# LATTICE COMMANDS +# ═══════════════════════════════════════════════════════════════════════════════ + +def cmd_stats(args): + """Show lattice statistics with Rich panels.""" + from cascade.observation import ObservationManager + + manager = ObservationManager() + stats = manager.get_stats() + + if HAS_RICH: + stats_table = 
Table(show_header=False, box=box.SIMPLE, padding=(0, 2)) + stats_table.add_column("Key", style="cyan") + stats_table.add_column("Value", style="green") + + stats_table.add_row("Genesis Root", f"[bold magenta]{stats['genesis_root']}[/]") + stats_table.add_row("", "") + stats_table.add_row("Total Observations", str(stats['total_observations'])) + stats_table.add_row(" └─ Model", str(stats['model_observations'])) + stats_table.add_row(" └─ Data", str(stats['data_observations'])) + stats_table.add_row(" └─ System", str(stats['system_observations'])) + stats_table.add_row("", "") + stats_table.add_row("Registered Models", str(stats['registered_models'])) + stats_table.add_row("Unique Models Observed", str(stats['unique_models'])) + + panel = Panel( + stats_table, + title="[bold cyan]CASCADE LATTICE[/]", + subtitle="[dim]The Neural Internetwork[/]", + border_style="cyan", + ) + console.print(panel) + else: + print(f""" +CASCADE LATTICE STATS +═════════════════════ +Genesis Root: {stats['genesis_root']} + +Observations: + Total: {stats['total_observations']} + Model: {stats['model_observations']} + Data: {stats['data_observations']} + System: {stats['system_observations']} + +Models: + Registered: {stats['registered_models']} + Observed: {stats['unique_models']} +""") + + +def cmd_list(args): + """List recent observations.""" + from cascade.observation import ObservationManager + + manager = ObservationManager() + observations = manager.list_observations(limit=args.limit) + + if not observations: + if HAS_RICH: + console.print("[yellow]No observations yet.[/]") + else: + print("No observations yet.") + return + + if HAS_RICH: + table = Table(title=f"Recent Observations", box=box.ROUNDED) + table.add_column("Type", style="cyan", width=8) + table.add_column("Source", style="white", max_width=40) + table.add_column("Merkle Root", style="magenta") + table.add_column("Time", style="dim") + + for obs in observations: + obs_type = obs.get('observation_type', '?')[:7] + source = 
obs.get('source_id', 'unknown')[:39] + merkle = obs.get('merkle_root', '?')[:16] + timestamp = obs.get('timestamp', '') + if timestamp: + try: + if isinstance(timestamp, (int, float)): + timestamp = datetime.fromtimestamp(timestamp).strftime('%H:%M:%S') + else: + timestamp = str(timestamp)[:8] + except: + timestamp = '?' + + table.add_row(obs_type, source, merkle, timestamp) + + console.print(table) + console.print(f"[dim]Showing {len(observations)} of {manager.get_stats()['total_observations']}[/]") + else: + print(f"\n{'TYPE':<8} {'SOURCE':<40} {'MERKLE ROOT':<20}") + print("─" * 70) + for obs in observations: + print(f"{obs.get('observation_type', '?')[:7]:<8} {obs.get('source_id', '?')[:39]:<40} {obs.get('merkle_root', '?')[:19]:<20}") + + +def cmd_inspect(args): + """Inspect a specific observation by merkle root.""" + from cascade.observation import ObservationManager + + manager = ObservationManager() + obs = manager.get_observation(args.root) + + if not obs: + if HAS_RICH: + console.print(f"[red]Observation not found:[/] {args.root}") + else: + print(f"Observation not found: {args.root}") + return + + if HAS_RICH: + tree = Tree(f"[bold magenta]{args.root}[/]") + + for key, value in obs.items(): + if isinstance(value, dict): + branch = tree.add(f"[cyan]{key}[/]") + for k, v in value.items(): + branch.add(f"[dim]{k}:[/] {v}") + elif isinstance(value, list): + branch = tree.add(f"[cyan]{key}[/] ({len(value)} items)") + for item in value[:5]: + branch.add(str(item)[:60]) + if len(value) > 5: + branch.add(f"[dim]... 
and {len(value) - 5} more[/]") + else: + tree.add(f"[cyan]{key}:[/] {value}") + + console.print(Panel(tree, title="Observation Details", border_style="magenta")) + else: + print(json.dumps(obs, indent=2, default=str)) + + +def cmd_chains(args): + """List all chains in the lattice.""" + from cascade.viz.lattice_gateway import load_lattice_data + + data = load_lattice_data() + chains = data.get('chains', []) + + if HAS_RICH: + table = Table(title="Lattice Chains", box=box.ROUNDED) + table.add_column("Name", style="cyan") + table.add_column("Merkle Root", style="magenta") + table.add_column("Records", justify="right") + table.add_column("CID", style="dim") + + for chain in chains: + name = chain.get('name', '?') + root = chain.get('merkle_root', '?')[:16] + records = len(chain.get('records', {})) + cid = chain.get('cid', 'Not pinned') + if cid and cid != 'Not pinned': + cid = cid[:20] + '...' + + style = "bold green" if name == 'genesis' else None + table.add_row(name, root, str(records), cid, style=style) + + console.print(table) + console.print(f"\n[dim]Genesis: {data.get('genesis_root', 'N/A')}[/]") + else: + print(f"Chains in lattice: {len(chains)}") + for chain in chains: + print(f" {chain.get('name')}: {chain.get('merkle_root', '?')[:16]} ({len(chain.get('records', {}))} records)") + + +def cmd_pin(args): + """Pin observation to IPFS.""" + from cascade.observation import ObservationManager + + manager = ObservationManager() + obs = manager.get_observation(args.root) + + if not obs: + if HAS_RICH: + console.print(f"[red]Observation not found:[/] {args.root}") + else: + print(f"Observation not found: {args.root}") + return + + if HAS_RICH: + with console.status("[cyan]Pinning to IPFS...[/]"): + cid = manager.pin_to_ipfs(obs) + + if cid: + console.print(f"[green]✓ Pinned to IPFS[/]") + console.print(f" CID: [magenta]{cid}[/]") + console.print(f" URL: https://storacha.link/ipfs/{cid}") + else: + console.print("[red]✗ Failed to pin[/]") + else: + print(f"Pinning 
{args.root}...") + cid = manager.pin_to_ipfs(obs) + if cid: + print(f"✓ Pinned: {cid}") + else: + print("✗ Failed") + + +def cmd_export(args): + """Export lattice or chain to file.""" + from cascade.viz.lattice_gateway import load_lattice_data + + data = load_lattice_data() + + if args.chain: + chains = [c for c in data.get('chains', []) if c['name'] == args.chain] + if not chains: + msg = f"Chain not found: {args.chain}" + console.print(f"[red]{msg}[/]") if HAS_RICH else print(msg) + return + export_data = chains[0] + else: + export_data = data + + output = Path(args.output) + output.write_text(json.dumps(export_data, indent=2, default=str)) + + msg = f"✓ Exported to {output}" + console.print(f"[green]{msg}[/]") if HAS_RICH else print(msg) + + +def cmd_watch(args): + """Watch live observations in real-time.""" + from cascade.observation import ObservationManager + import time + + manager = ObservationManager() + last_count = 0 + + if HAS_RICH: + console.print("[cyan]Watching for observations... (Ctrl+C to stop)[/]\n") + else: + print("Watching... (Ctrl+C to stop)") + + try: + while True: + stats = manager.get_stats() + current = stats['total_observations'] + + if current > last_count: + new_obs = manager.list_observations(limit=current - last_count) + for obs in reversed(new_obs): + if HAS_RICH: + console.print( + f"[green]●[/] [{datetime.now().strftime('%H:%M:%S')}] " + f"[cyan]{obs.get('observation_type', '?')}[/] " + f"[white]{obs.get('source_id', '?')[:40]}[/] " + f"[magenta]{obs.get('merkle_root', '?')[:16]}[/]" + ) + else: + print(f"● {obs.get('observation_type', '?')} {obs.get('merkle_root', '?')[:16]}") + last_count = current + + time.sleep(1) + except KeyboardInterrupt: + msg = "\nStopped watching." 
+ console.print(f"[yellow]{msg}[/]") if HAS_RICH else print(msg) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# MODEL COMMANDS +# ═══════════════════════════════════════════════════════════════════════════════ + +def cmd_observe(args): + """Manually observe a model interaction.""" + from cascade import observe + + result = observe( + model_id=args.model, + input_data=args.input, + output_data=args.output, + observation_type='model', + ) + + if HAS_RICH: + console.print(f"[green]✓ Observed[/]") + console.print(f" Merkle Root: [magenta]{result.get('merkle_root', 'N/A')}[/]") + else: + print(f"Observed: {result.get('merkle_root', 'N/A')}") + + +def cmd_fingerprint(args): + """Generate model fingerprint.""" + try: + from cascade.forensics.fingerprints import ModelFingerprinter + + if HAS_RICH: + with console.status(f"[cyan]Fingerprinting {args.model}...[/]"): + fp = ModelFingerprinter() + result = fp.fingerprint(args.model) + + if result: + table = Table(title=f"Fingerprint: {args.model}", box=box.ROUNDED) + table.add_column("Property", style="cyan") + table.add_column("Value", style="white") + + for key, value in result.items(): + if isinstance(value, dict): + value = json.dumps(value)[:50] + '...' 
+ table.add_row(str(key), str(value)[:60]) + + console.print(table) + else: + console.print("[yellow]Could not fingerprint model[/]") + else: + fp = ModelFingerprinter() + result = fp.fingerprint(args.model) + print(json.dumps(result, indent=2, default=str)) + except Exception as e: + msg = f"Error: {e}" + console.print(f"[red]{msg}[/]") if HAS_RICH else print(msg) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# DATA COMMANDS +# ═══════════════════════════════════════════════════════════════════════════════ + +def cmd_entities(args): + """Run entity resolution on a file.""" + try: + from cascade.data.entities import EntityResolver + + if HAS_RICH: + with console.status(f"[cyan]Resolving entities in {args.file}...[/]"): + resolver = EntityResolver() + result = resolver.resolve_file(args.file) + + if result: + console.print(f"[green]✓ Found {len(result)} entities[/]") + + table = Table(box=box.SIMPLE) + table.add_column("Entity", style="cyan") + table.add_column("Type", style="magenta") + table.add_column("Count", justify="right") + + for entity in result[:20]: + table.add_row( + str(entity.get('name', '?'))[:30], + entity.get('type', '?'), + str(entity.get('count', 1)) + ) + + console.print(table) + if len(result) > 20: + console.print(f"[dim]... 
and {len(result) - 20} more[/]") + else: + resolver = EntityResolver() + result = resolver.resolve_file(args.file) + print(f"Found {len(result)} entities") + except Exception as e: + msg = f"Error: {e}" + console.print(f"[red]{msg}[/]") if HAS_RICH else print(msg) + + +def cmd_pii(args): + """Scan for PII in a file.""" + try: + from cascade.data.pii import PIIScanner + + if HAS_RICH: + with console.status(f"[cyan]Scanning {args.file} for PII...[/]"): + scanner = PIIScanner() + results = scanner.scan_file(args.file) + + if results: + console.print(f"[yellow]⚠ Found {len(results)} potential PII instances[/]") + + table = Table(box=box.ROUNDED) + table.add_column("Type", style="red") + table.add_column("Value", style="yellow") + table.add_column("Location", style="dim") + + for pii in results[:20]: + val = pii.get('value', '?') + table.add_row( + pii.get('type', '?'), + val[:30] + '...' if len(val) > 30 else val, + str(pii.get('location', '?')) + ) + + console.print(table) + else: + console.print("[green]✓ No PII detected[/]") + else: + scanner = PIIScanner() + results = scanner.scan_file(args.file) + print(f"Found {len(results)} PII instances") + except Exception as e: + msg = f"Error: {e}" + console.print(f"[red]{msg}[/]") if HAS_RICH else print(msg) + + +def cmd_provenance(args): + """Show data provenance for a file/dataset.""" + try: + from cascade.data.provenance import DataProvenance + + if HAS_RICH: + with console.status(f"[cyan]Analyzing provenance...[/]"): + prov = DataProvenance() + result = prov.analyze(args.path) + + if result: + tree = Tree(f"[bold cyan]{args.path}[/]") + + if 'hash' in result: + tree.add(f"[magenta]Hash:[/] {result['hash']}") + if 'sources' in result: + sources = tree.add("[cyan]Sources[/]") + for src in result['sources']: + sources.add(str(src)) + if 'transformations' in result: + transforms = tree.add("[cyan]Transformations[/]") + for t in result['transformations']: + transforms.add(str(t)) + + console.print(Panel(tree, title="Data 
Provenance", border_style="cyan")) + else: + prov = DataProvenance() + result = prov.analyze(args.path) + print(json.dumps(result, indent=2, default=str)) + except Exception as e: + msg = f"Error: {e}" + console.print(f"[red]{msg}[/]") if HAS_RICH else print(msg) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# SYSTEM COMMANDS +# ═══════════════════════════════════════════════════════════════════════════════ + +def cmd_ingest(args): + """Ingest logs/files into the lattice.""" + try: + from cascade.system.repo_ingester import RepoIngester + + if HAS_RICH: + with console.status(f"[cyan]Ingesting {args.path}...[/]"): + ingester = RepoIngester() + result = ingester.ingest(args.path) + + console.print(f"[green]✓ Ingested[/]") + console.print(f" Files: {result.get('files', 0)}") + console.print(f" Observations: {result.get('observations', 0)}") + console.print(f" Merkle Root: [magenta]{result.get('merkle_root', 'N/A')}[/]") + else: + ingester = RepoIngester() + result = ingester.ingest(args.path) + print(f"Ingested: {result}") + except Exception as e: + msg = f"Error: {e}" + console.print(f"[red]{msg}[/]") if HAS_RICH else print(msg) + + +def cmd_analyze(args): + """Analyze a log file or folder.""" + try: + from cascade.system.omnidirectional_analyzer import OmnidirectionalAnalyzer + + if HAS_RICH: + with console.status(f"[cyan]Analyzing {args.path}...[/]"): + analyzer = OmnidirectionalAnalyzer() + result = analyzer.analyze(args.path) + + if result: + console.print(Panel( + Syntax(json.dumps(result, indent=2, default=str), "json"), + title="Analysis Result", + border_style="cyan" + )) + else: + analyzer = OmnidirectionalAnalyzer() + result = analyzer.analyze(args.path) + print(json.dumps(result, indent=2, default=str)) + except Exception as e: + msg = f"Error: {e}" + console.print(f"[red]{msg}[/]") if HAS_RICH else print(msg) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# PROXY & INIT +# 
═══════════════════════════════════════════════════════════════════════════════ + +def cmd_proxy(args): + """Start the CASCADE proxy server.""" + if HAS_RICH: + console.print(Panel( + f"""[cyan]CASCADE Proxy Server[/] + +Listening on [bold]{args.host}:{args.port}[/] + +Set these environment variables in your app: +[green] + OPENAI_BASE_URL=http://localhost:{args.port}/v1 + ANTHROPIC_BASE_URL=http://localhost:{args.port}/anthropic +[/] +Press Ctrl+C to stop.""", + title="🌐 Proxy Mode", + border_style="cyan", + )) + else: + print(f"CASCADE Proxy on {args.host}:{args.port}") + + from cascade.proxy import run_proxy + run_proxy(host=args.host, port=args.port, verbose=not args.quiet) + + +def cmd_init(args): + """Show initialization instructions.""" + if HAS_RICH: + md = """ +# CASCADE Setup + +## Option 1: Auto-Patch (Python) +```python +import cascade +cascade.init() + +# Now every call emits a receipt +from openai import OpenAI +client = OpenAI() +client.chat.completions.create(...) # ← automatically observed +``` + +## Option 2: Proxy Mode (Any Language) +```bash +cascade proxy --port 7777 +``` +Then set environment variables: +```bash +export OPENAI_BASE_URL=http://localhost:7777/v1 +export ANTHROPIC_BASE_URL=http://localhost:7777/anthropic +``` + +## Option 3: Manual Observation +```python +from cascade import observe +observe(model_id="my-model", input_data="prompt", output_data="response") +``` + +--- +**Genesis Root:** `89f940c1a4b7aa65` +""" + console.print(Panel(Markdown(md), title="[bold cyan]CASCADE[/]", border_style="cyan")) + else: + print(""" +CASCADE - Universal AI Provenance Layer + +OPTION 1: Auto-Patch (Python) + import cascade + cascade.init() + +OPTION 2: Proxy Mode (Any Language) + cascade proxy + export OPENAI_BASE_URL=http://localhost:7777/v1 + +OPTION 3: Manual + from cascade import observe + observe(model_id="...", input_data="...", output_data="...") +""") + + +def cmd_version(args): + """Show version.""" + try: + from cascade import 
__version__ + version = __version__ + except: + version = "0.1.1" + + if HAS_RICH: + console.print(f"[cyan]cascade-ai[/] [bold]{version}[/]") + console.print(f"[dim]Genesis: 89f940c1a4b7aa65[/]") + else: + print(f"cascade-ai {version}") + + +# ═══════════════════════════════════════════════════════════════════════════════ +# HOLD COMMANDS - Inference-Level Halt Protocol +# ═══════════════════════════════════════════════════════════════════════════════ + +def cmd_hold_status(args): + """Show HOLD system status.""" + try: + from cascade.hold import Hold + hold = Hold.get() + + if HAS_RICH: + from rich.table import Table + + table = Table(title="🛑 HOLD System Status", box=box.SIMPLE) + table.add_column("Property", style="cyan") + table.add_column("Value", style="green") + + table.add_row("Hold Count", str(hold._hold_count)) + table.add_row("Override Count", str(hold._override_count)) + table.add_row("Timeout", f"{hold.timeout}s") + table.add_row("Auto Accept", str(hold.auto_accept)) + table.add_row("Listeners", str(len(hold._listeners))) + table.add_row("Last Merkle", hold._last_merkle or "None") + table.add_row("Current Hold", "Active" if hold._current_hold else "None") + + console.print(table) + else: + print(f"HOLD Count: {hold._hold_count}") + print(f"Override Count: {hold._override_count}") + print(f"Timeout: {hold.timeout}s") + print(f"Listeners: {len(hold._listeners)}") + except Exception as e: + if HAS_RICH: + console.print(f"[red]Error: {e}[/]") + else: + print(f"Error: {e}") + + +def cmd_hold_info(args): + """Show HOLD usage information.""" + info = """ +🛑 HOLD - Inference-Level Halt Protocol + +HOLD pauses AI inference so humans can observe and intervene. 
+ +USAGE IN YOUR CODE: + from cascade.hold import Hold + + hold = Hold.get() + + # In your inference loop: + probs = model.predict(observation) + + resolution = hold.yield_point( + action_probs=probs, + value=value_estimate, + observation=obs, + brain_id="my_model", + # Optional informational wealth: + action_labels=["up", "down", "left", "right"], + latent=model.latent, + attention=model.attention, + features=model.features, + imagination=model.imagine(), + ) + + action = resolution.action # Final action (AI or override) + was_override = resolution.was_override # True if human intervened + +REGISTERING LISTENERS: + def my_handler(hold_point): + print(f"HOLD: {hold_point.action_probs}") + # Send to UI, game engine, logger, etc. + + hold.register_listener(my_handler) + +RESOLVING HOLDS: + hold.resolve(action=3, source="human") # Override with action 3 + hold.accept() # Accept AI's choice +""" + if HAS_RICH: + console.print(Panel(info, title="[bold red]HOLD[/]", border_style="red")) + else: + print(info) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# MAIN +# ═══════════════════════════════════════════════════════════════════════════════ + +def main(): + """Main CLI entry point.""" + parser = argparse.ArgumentParser( + prog="cascade", + description="CASCADE - Universal AI Provenance Layer", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + cascade stats Show lattice statistics + cascade list -n 20 List recent observations + cascade chains List all chains + cascade inspect Inspect an observation + cascade watch Live observation feed + cascade proxy Start proxy server + cascade fingerprint Fingerprint a model + cascade pii Scan file for PII + cascade ingest Ingest logs/files + """ + ) + parser.add_argument("--version", "-v", action="store_true", help="Show version") + + subparsers = parser.add_subparsers(dest="command", help="Commands") + + # ─── Lattice commands ─── + 
subparsers.add_parser("stats", help="Show lattice statistics").set_defaults(func=cmd_stats) + subparsers.add_parser("chains", help="List all chains").set_defaults(func=cmd_chains) + subparsers.add_parser("init", help="Show setup instructions").set_defaults(func=cmd_init) + subparsers.add_parser("watch", help="Watch live observations").set_defaults(func=cmd_watch) + + list_p = subparsers.add_parser("list", help="List recent observations") + list_p.add_argument("--limit", "-n", type=int, default=10, help="Number to show") + list_p.set_defaults(func=cmd_list) + + inspect_p = subparsers.add_parser("inspect", help="Inspect an observation") + inspect_p.add_argument("root", help="Merkle root to inspect") + inspect_p.set_defaults(func=cmd_inspect) + + pin_p = subparsers.add_parser("pin", help="Pin observation to IPFS") + pin_p.add_argument("root", help="Merkle root to pin") + pin_p.set_defaults(func=cmd_pin) + + export_p = subparsers.add_parser("export", help="Export lattice/chain to JSON") + export_p.add_argument("--chain", "-c", help="Export specific chain") + export_p.add_argument("--output", "-o", default="cascade_export.json", help="Output file") + export_p.set_defaults(func=cmd_export) + + # ─── Model commands ─── + observe_p = subparsers.add_parser("observe", help="Manual observation") + observe_p.add_argument("--model", "-m", required=True, help="Model ID") + observe_p.add_argument("--input", "-i", required=True, help="Input data") + observe_p.add_argument("--output", "-o", required=True, help="Output data") + observe_p.set_defaults(func=cmd_observe) + + fp_p = subparsers.add_parser("fingerprint", help="Fingerprint a model") + fp_p.add_argument("model", help="Model name/path") + fp_p.set_defaults(func=cmd_fingerprint) + + # ─── Data commands ─── + entities_p = subparsers.add_parser("entities", help="Entity resolution") + entities_p.add_argument("file", help="File to analyze") + entities_p.set_defaults(func=cmd_entities) + + pii_p = subparsers.add_parser("pii", 
help="Scan for PII") + pii_p.add_argument("file", help="File to scan") + pii_p.set_defaults(func=cmd_pii) + + prov_p = subparsers.add_parser("provenance", help="Data provenance") + prov_p.add_argument("path", help="File or dataset path") + prov_p.set_defaults(func=cmd_provenance) + + # ─── System commands ─── + ingest_p = subparsers.add_parser("ingest", help="Ingest logs/files") + ingest_p.add_argument("path", help="Path to ingest") + ingest_p.set_defaults(func=cmd_ingest) + + analyze_p = subparsers.add_parser("analyze", help="Analyze logs/files") + analyze_p.add_argument("path", help="Path to analyze") + analyze_p.set_defaults(func=cmd_analyze) + + # ─── Proxy ─── + proxy_p = subparsers.add_parser("proxy", help="Start proxy server") + proxy_p.add_argument("--host", default="0.0.0.0", help="Host to bind") + proxy_p.add_argument("--port", "-p", type=int, default=7777, help="Port") + proxy_p.add_argument("--quiet", "-q", action="store_true", help="Quiet mode") + proxy_p.set_defaults(func=cmd_proxy) + + # ─── HOLD - Inference-Level Halt Protocol ─── + hold_p = subparsers.add_parser("hold", help="Show HOLD usage and API info") + hold_p.set_defaults(func=cmd_hold_info) + + hold_status_p = subparsers.add_parser("hold-status", help="Show HOLD system status") + hold_status_p.set_defaults(func=cmd_hold_status) + + # Parse + args = parser.parse_args() + + if args.version: + cmd_version(args) + return + + if not args.command: + if HAS_RICH: + console.print(Panel( + """[cyan]CASCADE[/] - Universal AI Provenance Layer + +[bold]Lattice Commands:[/] + [green]stats[/] Show lattice statistics + [green]chains[/] List all chains + [green]list[/] List recent observations + [green]inspect[/] Inspect an observation + [green]watch[/] Live observation feed + [green]pin[/] Pin to IPFS + [green]export[/] Export to JSON + +[bold]Model Commands:[/] + [green]observe[/] Manual observation + [green]fingerprint[/] Fingerprint a model + +[bold]Data Commands:[/] + [green]entities[/] Entity 
resolution + [green]pii[/] PII scanner + [green]provenance[/] Data provenance + +[bold]System Commands:[/] + [green]ingest[/] Ingest files/logs + [green]analyze[/] Analyze files + +[bold]HOLD (Inference Halt):[/] + [green]hold[/] Show HOLD usage and API info + [green]hold-status[/] Show HOLD system status + +[bold]Other:[/] + [green]proxy[/] Start proxy server + [green]init[/] Setup instructions + +Use [cyan]cascade --help[/] for details.""", + title="[bold magenta]🌀 CASCADE[/]", + subtitle="[dim]pip install cascade-ai[/]", + border_style="magenta", + )) + else: + parser.print_help() + return + + args.func(args) + + +if __name__ == "__main__": + main() diff --git a/cascade/core/__init__.py b/cascade/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..28619e26b5b8ad42e9e65ca1a491b82d2a81c130 --- /dev/null +++ b/cascade/core/__init__.py @@ -0,0 +1,13 @@ +"""Cascade Core module - fundamental data structures and algorithms.""" + +from cascade.core.event import Event, CausationLink, CausationChain +from cascade.core.graph import CausationGraph +from cascade.core.adapter import SymbioticAdapter + +__all__ = [ + "Event", + "CausationLink", + "CausationChain", + "CausationGraph", + "SymbioticAdapter", +] diff --git a/cascade/core/adapter.py b/cascade/core/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..4e685a9dcc232c06da4c2c35231683760d114636 --- /dev/null +++ b/cascade/core/adapter.py @@ -0,0 +1,470 @@ +""" +Cascade Core - Symbiotic Adapter. + +The heart of Cascade's system-agnostic design. The adapter uses Kleene fixed-point +convergence to interpret ANY signal format and convert it to Events. + +"It doesn't hook into your system — it becomes part of it." 
+""" + +import time +import json +import re +from typing import Any, Dict, List, Optional, Callable, Type +from dataclasses import dataclass + +from cascade.core.event import Event + + +@dataclass +class SignalPattern: + """A learned pattern for interpreting signals.""" + pattern_type: str # 'dict', 'string', 'tensor', 'protobuf', etc. + component: str + event_type: str + extractor: Optional[Callable[[Any], Dict[str, Any]]] = None + confidence: float = 0.0 + match_count: int = 0 + + +class SymbioticAdapter: + """ + Self-interpreting adapter that converges to any signal format. + + The adapter observes signals from the host system and learns how to + interpret them through fixed-point iteration. It starts with naive + interpretations and refines them until stable. + + This is the key to Cascade's system-agnostic design: + - No framework-specific hooks required + - No configuration needed + - Feed it ANY signal format, it adapts + + Example: + >>> adapter = SymbioticAdapter() + >>> + >>> # Feed it different signal formats + >>> adapter.interpret({"loss": 0.5, "epoch": 10}) + >>> adapter.interpret("2024-01-01 12:00:00 ERROR training failed") + >>> adapter.interpret(torch.tensor([0.1, 0.2, 0.3])) + >>> + >>> # It learns patterns and gets better at interpretation + >>> print(adapter.learned_patterns) + """ + + def __init__(self): + """Initialize the symbiotic adapter.""" + self._patterns: List[SignalPattern] = [] + self._signal_count = 0 + self._interpretation_cache: Dict[str, SignalPattern] = {} + + # Built-in interpreters for common formats + self._builtin_interpreters = { + dict: self._interpret_dict, + str: self._interpret_string, + list: self._interpret_list, + } + + # Regex patterns for log line parsing + self._log_patterns = [ + # ISO timestamp with level: "2024-01-01 12:00:00 ERROR message" + re.compile(r'^(\d{4}-\d{2}-\d{2}[T\s]\d{2}:\d{2}:\d{2}(?:\.\d+)?)\s+(\w+)\s+(.*)$'), + # Simple timestamp: "12:00:00.123 component message" + 
re.compile(r'^(\d{2}:\d{2}:\d{2}(?:\.\d+)?)\s+(\w+)\s+(.*)$'), + # Pipe-delimited: "timestamp|level|component|key:value" + re.compile(r'^([^|]+)\|(\w+)\|(\w+)\|(.*)$'), + ] + + # Metric extraction patterns - ONLY extract real training metrics + # Be strict to avoid extracting garbage from config lines + self._metric_patterns = [ + # Standard training metrics with = or : + re.compile(r'\b(loss|val_loss|train_loss|accuracy|acc|val_acc|lr|learning_rate|epoch|step|iter|iteration|mfu|tokens_per_sec|samples_per_sec|grad_norm|perplexity|ppl)[=:]\s*([+-]?\d+\.?\d*(?:e[+-]?\d+)?)', re.I), + # "iter X: loss=Y" format from nanoGPT + re.compile(r'iter\s+(\d+).*loss[=:]?\s*([+-]?\d+\.?\d*)', re.I), + # "step X loss Y" format + re.compile(r'step\s+(\d+).*loss\s*[=:]?\s*([+-]?\d+\.?\d*)', re.I), + ] + + def interpret(self, signal: Any) -> Event: + """ + Interpret any signal into a Cascade Event. + + Uses Kleene fixed-point iteration to converge on the best interpretation. + + Args: + signal: Any signal from the host system + + Returns: + Event: The interpreted event + """ + self._signal_count += 1 + + # Get signal type + signal_type = type(signal) + + # Try cached pattern first + cache_key = self._get_cache_key(signal) + if cache_key in self._interpretation_cache: + pattern = self._interpretation_cache[cache_key] + pattern.match_count += 1 + return self._apply_pattern(signal, pattern) + + # Try built-in interpreter + if signal_type in self._builtin_interpreters: + event = self._builtin_interpreters[signal_type](signal) + self._learn_pattern(signal, event) + return event + + # Try tensor-like objects (duck typing) + if hasattr(signal, 'numpy') or hasattr(signal, 'detach'): + event = self._interpret_tensor(signal) + self._learn_pattern(signal, event) + return event + + # Try protobuf-like objects + if hasattr(signal, 'SerializeToString'): + event = self._interpret_protobuf(signal) + self._learn_pattern(signal, event) + return event + + # Fallback: convert to string and interpret + 
event = self._interpret_string(str(signal)) + return event + + def _interpret_dict(self, signal: Dict[str, Any]) -> Event: + """Interpret a dictionary signal.""" + # Extract common fields + timestamp = signal.get('timestamp', signal.get('time', time.time())) + if isinstance(timestamp, str): + try: + from datetime import datetime + timestamp = datetime.fromisoformat(timestamp).timestamp() + except: + timestamp = time.time() + + component = signal.get('component', signal.get('source', 'unknown')) + event_type = signal.get('event_type', signal.get('type', 'state_change')) + + # Everything else goes in data + reserved = {'timestamp', 'time', 'component', 'source', 'event_type', 'type'} + data = {k: v for k, v in signal.items() if k not in reserved} + + return Event( + timestamp=timestamp, + component=component, + event_type=event_type, + data=data, + source_signal=signal, + ) + + def _interpret_string(self, signal: str) -> Event: + """Interpret a string signal (log line, message, etc.).""" + signal = signal.strip() + + # Try each log pattern + for pattern in self._log_patterns: + match = pattern.match(signal) + if match: + groups = match.groups() + if len(groups) >= 3: + timestamp_str, level_or_component, rest = groups[0], groups[1], groups[-1] + + # Parse timestamp + try: + from datetime import datetime + timestamp = datetime.fromisoformat(timestamp_str.replace(' ', 'T')).timestamp() + except: + timestamp = time.time() + + # Extract metrics from the rest + data = self._extract_metrics(rest) + data['raw_message'] = rest + + # Determine event type from keywords + event_type = self._infer_event_type(signal) + + return Event( + timestamp=timestamp, + component=level_or_component.lower(), + event_type=event_type, + data=data, + source_signal=signal, + ) + + # Fallback: extract what we can with smarter component detection + data = self._extract_metrics(signal) + data['raw_message'] = signal + + # Infer component from content + component = self._infer_component(signal) + + 
return Event( + timestamp=time.time(), + component=component, + event_type=self._infer_event_type(signal), + data=data, + source_signal=signal, + ) + + def _interpret_list(self, signal: List[Any]) -> Event: + """Interpret a list signal.""" + # Convert to dict with indices + data = {f'item_{i}': v for i, v in enumerate(signal)} + data['length'] = len(signal) + + # Check if it looks like numeric data + if all(isinstance(x, (int, float)) for x in signal): + data['mean'] = sum(signal) / len(signal) if signal else 0 + data['min'] = min(signal) if signal else 0 + data['max'] = max(signal) if signal else 0 + + return Event( + timestamp=time.time(), + component='data', + event_type='list_signal', + data=data, + source_signal=signal, + ) + + def _interpret_tensor(self, signal: Any) -> Event: + """Interpret a tensor-like signal (PyTorch, NumPy, etc.).""" + # Try to get numpy array + try: + if hasattr(signal, 'detach'): + arr = signal.detach().cpu().numpy() + elif hasattr(signal, 'numpy'): + arr = signal.numpy() + else: + arr = signal + + data = { + 'shape': list(arr.shape) if hasattr(arr, 'shape') else [], + 'dtype': str(arr.dtype) if hasattr(arr, 'dtype') else 'unknown', + 'mean': float(arr.mean()) if hasattr(arr, 'mean') else 0, + 'std': float(arr.std()) if hasattr(arr, 'std') else 0, + 'min': float(arr.min()) if hasattr(arr, 'min') else 0, + 'max': float(arr.max()) if hasattr(arr, 'max') else 0, + } + + # Check for NaN/Inf (common in gradient explosions) + if hasattr(arr, 'isnan'): + data['has_nan'] = bool(arr.isnan().any()) + if hasattr(arr, 'isinf'): + data['has_inf'] = bool(arr.isinf().any()) + + except Exception as e: + data = {'error': str(e), 'type': str(type(signal))} + + return Event( + timestamp=time.time(), + component='tensor', + event_type='tensor_signal', + data=data, + source_signal=None, # Don't store tensor to save memory + ) + + def _interpret_protobuf(self, signal: Any) -> Event: + """Interpret a protobuf-like signal.""" + try: + # Try to convert to dict 
+ if hasattr(signal, 'DESCRIPTOR'): + from google.protobuf.json_format import MessageToDict + data = MessageToDict(signal) + else: + data = {'raw': str(signal)} + except: + data = {'raw': str(signal)} + + return Event( + timestamp=time.time(), + component='protobuf', + event_type='protobuf_signal', + data=data, + source_signal=None, + ) + + def _extract_metrics(self, text: str) -> Dict[str, Any]: + """Extract numeric metrics from text - STRICT, only real training metrics.""" + metrics = {} + + # nanoGPT format: "iter 0: loss=4.2176, time 46.76ms, mfu 0.62%" + nano_match = re.search(r'iter\s+(\d+).*loss[=:]?\s*([\d.]+)', text, re.I) + if nano_match: + metrics['iter'] = int(nano_match.group(1)) + metrics['loss'] = float(nano_match.group(2)) + + # Diffusers/tqdm format: "step_loss=0.1234" or "step_loss: 0.1234" + step_loss_match = re.search(r'step_loss[=:]\s*([\d.e+-]+)', text, re.I) + if step_loss_match: + metrics['loss'] = float(step_loss_match.group(1)) + + # train_loss format from accelerator.log + train_loss_match = re.search(r'train_loss[=:]\s*([\d.e+-]+)', text, re.I) + if train_loss_match: + metrics['loss'] = float(train_loss_match.group(1)) + + # tqdm progress format: " 5%|█ | 5/100 [00:30<09:30, step_loss=0.234, lr=1e-5]" + tqdm_match = re.search(r'(\d+)%\|.*\|\s*(\d+)/(\d+)', text) + if tqdm_match: + metrics['progress_pct'] = int(tqdm_match.group(1)) + metrics['step'] = int(tqdm_match.group(2)) + metrics['total_steps'] = int(tqdm_match.group(3)) + + # Generic loss patterns + generic_loss = re.search(r'\bloss[=:]\s*([\d.e+-]+)', text, re.I) + if generic_loss and 'loss' not in metrics: + metrics['loss'] = float(generic_loss.group(1)) + + # mfu extraction + mfu_match = re.search(r'mfu\s*[=:]?\s*([\d.]+)%?', text, re.I) + if mfu_match: + metrics['mfu'] = float(mfu_match.group(1)) + + # time extraction (ms) + time_match = re.search(r'time\s*[=:]?\s*([\d.]+)\s*ms', text, re.I) + if time_match: + metrics['time_ms'] = float(time_match.group(1)) + + # learning rate 
- multiple formats + lr_match = re.search(r'\b(?:lr|learning_rate)\s*[=:]\s*([\d.e+-]+)', text, re.I) + if lr_match: + metrics['lr'] = float(lr_match.group(1)) + + # epoch/step for other frameworks + epoch_match = re.search(r'\bepoch\s*[=:]\s*(\d+)', text, re.I) + if epoch_match: + metrics['epoch'] = int(epoch_match.group(1)) + + step_match = re.search(r'\bstep\s*[=:]\s*(\d+)', text, re.I) + if step_match and 'step' not in metrics: + metrics['step'] = int(step_match.group(1)) + + # global_step from diffusers + global_step_match = re.search(r'global_step[=:]\s*(\d+)', text, re.I) + if global_step_match: + metrics['step'] = int(global_step_match.group(1)) + + return metrics + + def _infer_event_type(self, text: str) -> str: + """Infer event type from text content.""" + text_lower = text.lower() + + # Training iteration logs (highest priority) + if re.search(r'iter\s+\d+.*loss', text_lower): + return 'training_step' + if re.search(r'step\s+\d+.*loss', text_lower): + return 'training_step' + + if any(kw in text_lower for kw in ['error', 'exception', 'failed', 'crash']): + return 'error' + if any(kw in text_lower for kw in ['warning', 'warn']): + return 'warning' + if any(kw in text_lower for kw in ['gradient', 'backward']): + return 'training' + if 'loss' in text_lower and 'val' in text_lower: + return 'validation' + if any(kw in text_lower for kw in ['inference', 'predict', 'forward']): + return 'inference' + if any(kw in text_lower for kw in ['epoch', 'step', 'iteration', 'iter']): + return 'progress' + if any(kw in text_lower for kw in ['nan', 'inf', 'explode', 'overflow']): + return 'anomaly' + if any(kw in text_lower for kw in ['save', 'checkpoint', 'load', 'saving']): + return 'checkpoint' + if any(kw in text_lower for kw in ['config', 'setting', 'parameter', 'device', 'gpu', 'cuda']): + return 'config' + if any(kw in text_lower for kw in ['initializ', 'loading model', 'compiling']): + return 'init' + + return 'state_change' + + def _infer_component(self, text: 
str) -> str: + """Infer component from text content - NO MORE 'unknown'.""" + text_lower = text.lower() + + # Training/optimizer related + if any(kw in text_lower for kw in ['iter', 'step', 'epoch', 'batch']): + return 'trainer' + if any(kw in text_lower for kw in ['loss', 'backward', 'gradient']): + return 'loss' + if any(kw in text_lower for kw in ['optim', 'adam', 'sgd', 'lr', 'learning']): + return 'optimizer' + if any(kw in text_lower for kw in ['model', 'layer', 'param', 'weight']): + return 'model' + if any(kw in text_lower for kw in ['data', 'batch', 'loader', 'dataset']): + return 'data' + if any(kw in text_lower for kw in ['cuda', 'gpu', 'device', 'memory']): + return 'device' + if any(kw in text_lower for kw in ['checkpoint', 'save', 'load']): + return 'checkpoint' + if any(kw in text_lower for kw in ['config', 'setting', 'override']): + return 'config' + if any(kw in text_lower for kw in ['eval', 'valid', 'test']): + return 'evaluator' + if any(kw in text_lower for kw in ['token', 'vocab', 'embed']): + return 'tokenizer' + + return 'system' # Generic fallback, not "unknown" + + def _get_cache_key(self, signal: Any) -> str: + """Generate a cache key for a signal's structure.""" + if isinstance(signal, dict): + # Key based on dict keys + return f"dict:{':'.join(sorted(signal.keys()))}" + elif isinstance(signal, str): + # Key based on first word + first_word = signal.split()[0] if signal.split() else '' + return f"str:{first_word[:20]}" + else: + return f"type:{type(signal).__name__}" + + def _learn_pattern(self, signal: Any, event: Event) -> None: + """Learn a pattern from a successful interpretation.""" + cache_key = self._get_cache_key(signal) + pattern = SignalPattern( + pattern_type=type(signal).__name__, + component=event.component, + event_type=event.event_type, + confidence=0.5, + match_count=1, + ) + self._interpretation_cache[cache_key] = pattern + self._patterns.append(pattern) + + def _apply_pattern(self, signal: Any, pattern: SignalPattern) -> 
Event: + """Apply a learned pattern to interpret a signal.""" + # Re-interpret with learned hints - use direct interpreters to avoid recursion + if isinstance(signal, dict): + event = self._interpret_dict(signal) + # Apply learned component/type if more confident + if pattern.confidence > 0.7: + return Event( + timestamp=event.timestamp, + component=pattern.component, + event_type=pattern.event_type, + data=event.data, + source_signal=signal, + ) + return event + elif isinstance(signal, str): + return self._interpret_string(signal) + elif isinstance(signal, list): + return self._interpret_list(signal) + else: + # Fallback: interpret as string without recursion + return self._interpret_string(str(signal)) + + @property + def learned_patterns(self) -> List[SignalPattern]: + """Get all learned signal patterns.""" + return sorted(self._patterns, key=lambda p: p.match_count, reverse=True) + + @property + def signal_count(self) -> int: + """Total number of signals interpreted.""" + return self._signal_count + + def __repr__(self) -> str: + return f"" diff --git a/cascade/core/event.py b/cascade/core/event.py new file mode 100644 index 0000000000000000000000000000000000000000..1d25522b31822612a67d1c1dca15215d33e42493 --- /dev/null +++ b/cascade/core/event.py @@ -0,0 +1,177 @@ +""" +Cascade Core - Event and CausationLink primitives. + +These are the fundamental data structures that represent causation. +""" + +from dataclasses import dataclass, field +from typing import Dict, List, Any, Optional +from datetime import datetime +import time +import uuid + + +def _generate_event_id() -> str: + """Generate a unique event ID with timestamp prefix for ordering.""" + timestamp = int(time.time() * 1000000) + unique = uuid.uuid4().hex[:8] + return f"evt_{timestamp}_{unique}" + + +@dataclass +class Event: + """ + A discrete event in the causation graph. + + Events are the nodes in your causation graph. 
Each event represents + something that happened in your system at a point in time. + + Attributes: + event_id: Unique identifier (auto-generated if not provided) + timestamp: Unix timestamp when event occurred + component: Which system component generated this event + event_type: Category of event (e.g., 'training', 'inference', 'error') + data: Arbitrary key-value data associated with the event + source_signal: The original signal that created this event (for debugging) + + Example: + >>> event = Event( + ... timestamp=time.time(), + ... component="neural_network", + ... event_type="gradient_explosion", + ... data={"layer": "fc3", "magnitude": 1e12} + ... ) + """ + timestamp: float + component: str + event_type: str + data: Dict[str, Any] = field(default_factory=dict) + event_id: str = field(default_factory=_generate_event_id) + source_signal: Optional[Any] = field(default=None, repr=False) + + def __post_init__(self): + """Ensure timestamp is float.""" + if isinstance(self.timestamp, datetime): + self.timestamp = self.timestamp.timestamp() + + def to_dict(self) -> Dict[str, Any]: + """Serialize event to dictionary.""" + return { + "event_id": self.event_id, + "timestamp": self.timestamp, + "component": self.component, + "event_type": self.event_type, + "data": self.data, + } + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> "Event": + """Deserialize event from dictionary.""" + return cls( + event_id=d.get("event_id", _generate_event_id()), + timestamp=d["timestamp"], + component=d["component"], + event_type=d["event_type"], + data=d.get("data", {}), + ) + + def __hash__(self): + return hash(self.event_id) + + def __eq__(self, other): + if isinstance(other, Event): + return self.event_id == other.event_id + return False + + +@dataclass +class CausationLink: + """ + A causal relationship between two events. + + Links are the edges in your causation graph. Each link represents + a cause-effect relationship: event A caused event B. 
+ + Attributes: + from_event: ID of the causing event + to_event: ID of the caused event + causation_type: How the causation was detected + - 'temporal': A happened shortly before B + - 'correlation': A and B metrics moved together + - 'threshold': A crossed a threshold triggering B + - 'direct': Explicit causation declared in code + strength: Confidence in the causal relationship (0.0 to 1.0) + explanation: Human-readable explanation of the link + metrics_involved: Which metrics connect these events + + Example: + >>> link = CausationLink( + ... from_event="evt_123", + ... to_event="evt_456", + ... causation_type="threshold", + ... strength=0.95, + ... explanation="Loss exceeded 10.0, triggering gradient clipping" + ... ) + """ + from_event: str + to_event: str + causation_type: str # 'temporal', 'correlation', 'threshold', 'direct' + strength: float = 1.0 + explanation: str = "" + metrics_involved: List[str] = field(default_factory=list) + + def __post_init__(self): + """Validate strength is in range.""" + self.strength = max(0.0, min(1.0, self.strength)) + + def to_dict(self) -> Dict[str, Any]: + """Serialize link to dictionary.""" + return { + "from_event": self.from_event, + "to_event": self.to_event, + "causation_type": self.causation_type, + "strength": self.strength, + "explanation": self.explanation, + "metrics_involved": self.metrics_involved, + } + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> "CausationLink": + """Deserialize link from dictionary.""" + return cls( + from_event=d["from_event"], + to_event=d["to_event"], + causation_type=d["causation_type"], + strength=d.get("strength", 1.0), + explanation=d.get("explanation", ""), + metrics_involved=d.get("metrics_involved", []), + ) + + +@dataclass +class CausationChain: + """ + A chain of causal events from origin to destination. + + Represents a full causal path through the graph. 
+ + Attributes: + events: List of events in causal order + links: List of links connecting the events + total_strength: Combined strength of all links + depth: Number of hops in the chain + narrative: Human-readable story of what happened + """ + events: List[Event] + links: List[CausationLink] + total_strength: float = 1.0 + depth: int = 0 + narrative: str = "" + + def __post_init__(self): + self.depth = len(self.links) + if not self.total_strength and self.links: + # Calculate combined strength + self.total_strength = 1.0 + for link in self.links: + self.total_strength *= link.strength diff --git a/cascade/core/graph.py b/cascade/core/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..3741763eeb0f48b087009750aaba820030b082b3 --- /dev/null +++ b/cascade/core/graph.py @@ -0,0 +1,292 @@ +""" +Cascade Core - Causation Graph Engine. + +The graph stores events and their causal relationships, enabling +bidirectional traversal through time. +""" + +import threading +from typing import Dict, List, Optional, Set, Any, Iterator +from collections import defaultdict +from datetime import datetime + +try: + import networkx as nx + HAS_NETWORKX = True +except ImportError: + HAS_NETWORKX = False + +from cascade.core.event import Event, CausationLink + + +class CausationGraph: + """ + A directed graph of causal relationships between events. + + The graph enables bidirectional traversal: + - Backwards: "What caused this event?" + - Forwards: "What did this event cause?" + + Thread-safe for concurrent event ingestion. + + Example: + >>> graph = CausationGraph() + >>> graph.add_event(event1) + >>> graph.add_event(event2) + >>> graph.add_link(CausationLink( + ... from_event=event1.event_id, + ... to_event=event2.event_id, + ... causation_type="temporal", + ... strength=0.9 + ... 
)) + >>> + >>> # Find what caused event2 + >>> causes = graph.get_causes(event2.event_id) + """ + + def __init__(self): + """Initialize an empty causation graph.""" + self._lock = threading.RLock() + + # Event storage + self._events: Dict[str, Event] = {} + self._events_by_component: Dict[str, List[str]] = defaultdict(list) + self._events_by_type: Dict[str, List[str]] = defaultdict(list) + self._events_by_time: List[str] = [] # Ordered by timestamp + + # Link storage + self._links: Dict[str, CausationLink] = {} # link_id -> link + self._causes: Dict[str, Set[str]] = defaultdict(set) # event_id -> set of cause event_ids + self._effects: Dict[str, Set[str]] = defaultdict(set) # event_id -> set of effect event_ids + + # NetworkX graph for advanced algorithms (optional) + if HAS_NETWORKX: + self._nx_graph = nx.DiGraph() + else: + self._nx_graph = None + + # Statistics + self._event_count = 0 + self._link_count = 0 + + def add_event(self, event: Event) -> None: + """ + Add an event to the graph. + + Thread-safe. Automatically detects potential causations with recent events. + + Args: + event: The event to add + """ + with self._lock: + if event.event_id in self._events: + return # Already exists + + self._events[event.event_id] = event + self._events_by_component[event.component].append(event.event_id) + self._events_by_type[event.event_type].append(event.event_id) + self._events_by_time.append(event.event_id) + self._event_count += 1 + + if self._nx_graph is not None: + self._nx_graph.add_node(event.event_id, **event.to_dict()) + + def add_link(self, link: CausationLink) -> None: + """ + Add a causal link between two events. + + Thread-safe. 
+ + Args: + link: The causation link to add + """ + with self._lock: + link_id = f"{link.from_event}->{link.to_event}" + + if link_id in self._links: + # Update existing link if new one is stronger + if link.strength > self._links[link_id].strength: + self._links[link_id] = link + return + + self._links[link_id] = link + self._causes[link.to_event].add(link.from_event) + self._effects[link.from_event].add(link.to_event) + self._link_count += 1 + + if self._nx_graph is not None: + self._nx_graph.add_edge( + link.from_event, + link.to_event, + **link.to_dict() + ) + + def get_event(self, event_id: str) -> Optional[Event]: + """Get an event by ID.""" + with self._lock: + return self._events.get(event_id) + + def get_causes(self, event_id: str) -> List[Event]: + """ + Get all events that directly caused this event. + + Args: + event_id: ID of the effect event + + Returns: + List of causing events + """ + with self._lock: + cause_ids = self._causes.get(event_id, set()) + return [self._events[cid] for cid in cause_ids if cid in self._events] + + def get_effects(self, event_id: str) -> List[Event]: + """ + Get all events that were directly caused by this event. + + Args: + event_id: ID of the cause event + + Returns: + List of effect events + """ + with self._lock: + effect_ids = self._effects.get(event_id, set()) + return [self._events[eid] for eid in effect_ids if eid in self._events] + + def get_link(self, from_event: str, to_event: str) -> Optional[CausationLink]: + """Get the causation link between two events.""" + with self._lock: + link_id = f"{from_event}->{to_event}" + return self._links.get(link_id) + + def get_all_links(self) -> List[CausationLink]: + """Get all causal links in the graph.""" + with self._lock: + return list(self._links.values()) + + def get_component_connections(self) -> Dict[str, Dict[str, float]]: + """ + Aggregate causal links into component-to-component connections. 
+ + Returns: + Dict mapping (from_component, to_component) -> total strength + """ + with self._lock: + connections: Dict[tuple, float] = {} + + for link in self._links.values(): + from_event = self._events.get(link.from_event) + to_event = self._events.get(link.to_event) + + if from_event and to_event: + from_comp = from_event.component + to_comp = to_event.component + + if from_comp != to_comp: # Skip self-links + key = (from_comp, to_comp) + connections[key] = connections.get(key, 0) + link.strength + + return connections + + def get_recent_events(self, count: int = 100) -> List[Event]: + """Get the most recent events by timestamp.""" + with self._lock: + ids = self._events_by_time[-count:] + return [self._events[eid] for eid in reversed(ids)] + + def get_events_by_component(self, component: str) -> List[Event]: + """Get all events from a specific component.""" + with self._lock: + ids = self._events_by_component.get(component, []) + return [self._events[eid] for eid in ids] + + def get_events_by_type(self, event_type: str) -> List[Event]: + """Get all events of a specific type.""" + with self._lock: + ids = self._events_by_type.get(event_type, []) + return [self._events[eid] for eid in ids] + + def find_path(self, from_event: str, to_event: str) -> Optional[List[str]]: + """ + Find the shortest causal path between two events. + + Uses NetworkX if available, otherwise falls back to BFS. 
+ + Args: + from_event: Starting event ID + to_event: Target event ID + + Returns: + List of event IDs in the path, or None if no path exists + """ + with self._lock: + if self._nx_graph is not None: + try: + return nx.shortest_path(self._nx_graph, from_event, to_event) + except nx.NetworkXNoPath: + return None + except nx.NodeNotFound: + return None + else: + # BFS fallback + return self._bfs_path(from_event, to_event) + + def _bfs_path(self, from_event: str, to_event: str) -> Optional[List[str]]: + """BFS path finding without NetworkX.""" + from collections import deque + + if from_event not in self._events or to_event not in self._events: + return None + + queue = deque([(from_event, [from_event])]) + visited = {from_event} + + while queue: + current, path = queue.popleft() + + if current == to_event: + return path + + for effect_id in self._effects.get(current, set()): + if effect_id not in visited: + visited.add(effect_id) + queue.append((effect_id, path + [effect_id])) + + return None + + def get_root_events(self) -> List[Event]: + """Get events with no causes (entry points).""" + with self._lock: + roots = [] + for event_id, event in self._events.items(): + if not self._causes.get(event_id): + roots.append(event) + return sorted(roots, key=lambda e: e.timestamp) + + def get_leaf_events(self) -> List[Event]: + """Get events with no effects (endpoints).""" + with self._lock: + leaves = [] + for event_id, event in self._events.items(): + if not self._effects.get(event_id): + leaves.append(event) + return sorted(leaves, key=lambda e: e.timestamp, reverse=True) + + def get_stats(self) -> Dict[str, Any]: + """Get statistics about the graph.""" + with self._lock: + return { + "event_count": self._event_count, + "link_count": self._link_count, + "components": list(self._events_by_component.keys()), + "event_types": list(self._events_by_type.keys()), + "root_count": len(self.get_root_events()), + "leaf_count": len(self.get_leaf_events()), + } + + def __len__(self) -> 
int: + return self._event_count + + def __repr__(self) -> str: + return f"" diff --git a/cascade/core/provenance.py b/cascade/core/provenance.py new file mode 100644 index 0000000000000000000000000000000000000000..245d6bb07fc9a0f71500d212ed8c569e4fb3f566 --- /dev/null +++ b/cascade/core/provenance.py @@ -0,0 +1,601 @@ +""" +CASCADE // PROVENANCE ENGINE +Cryptographic lineage tracking for neural network activations. + +Due process infrastructure for AI - immutable evidence chains +that enable governance without prescribing decisions. + +Architecture: + Input → [Layer₀] → [Layer₁] → ... → [Layerₙ] → Output + │ │ │ + ▼ ▼ ▼ + Hash₀ ──► Hash₁ ──► ... ──► Hashₙ + │ │ + └───────── Merkle Root ─────┘ + +Each hash includes: + - Tensor state (sampled for efficiency) + - Parent hashes (inputs to this layer) + - Layer identity (name, params hash) + - Execution context (order, timestamp) + +This creates verifiable, tamper-evident records of +what happened inside the network. +""" + +import hashlib +import json +import time +from dataclasses import dataclass, field, asdict +from typing import Dict, List, Optional, Any, Tuple +from collections import OrderedDict +import numpy as np + + +@dataclass +class ProvenanceRecord: + """Immutable record of a single layer's activation state.""" + + # Identity + layer_name: str + layer_idx: int + + # Lineage + state_hash: str # Hash of this layer's output + parent_hashes: List[str] # Hashes of inputs (usually 1, but attention has multiple) + params_hash: Optional[str] = None # Hash of layer weights (frozen reference) + + # Tensor metadata + shape: List[int] = field(default_factory=list) + dtype: str = "float32" + + # Statistics (for visualization, not hashed) + stats: Dict[str, float] = field(default_factory=dict) + + # Execution context + execution_order: int = 0 + timestamp: float = field(default_factory=time.time) + + # Merkle tree position + merkle_depth: int = 0 + merkle_path: List[str] = field(default_factory=list) + + def 
to_dict(self) -> Dict[str, Any]: + """Serialize for JSON export.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'ProvenanceRecord': + """Deserialize from JSON.""" + return cls(**data) + + +@dataclass +class ProvenanceChain: + """Complete provenance chain for a forward pass.""" + + # Session identity + session_id: str + model_id: str + model_hash: str + + # Input/output + input_hash: str + output_hash: Optional[str] = None + + # The chain itself + records: Dict[str, ProvenanceRecord] = field(default_factory=OrderedDict) + + # External system roots (for inter-system linking) + # When this chain depends on another system's computation, + # include their merkle_root here. This creates the lattice. + external_roots: List[str] = field(default_factory=list) + + # Merkle root (computed after chain complete) + merkle_root: Optional[str] = None + + # Metadata + created_at: float = field(default_factory=time.time) + finalized: bool = False + + def add_record(self, record: ProvenanceRecord) -> None: + """Add a record to the chain. 
Chain must not be finalized.""" + if self.finalized: + raise ValueError("Cannot add to finalized chain") + self.records[record.layer_name] = record + + def finalize(self) -> str: + """Compute Merkle root and lock the chain.""" + if self.finalized: + return self.merkle_root + + # Build Merkle tree from record hashes + external roots + # External roots create cryptographic proof of inter-system dependency + hashes = [r.state_hash for r in self.records.values()] + hashes.extend(self.external_roots) # Include external system roots + self.merkle_root = compute_merkle_root(hashes) + self.finalized = True + return self.merkle_root + + def verify(self) -> Tuple[bool, Optional[str]]: + """Verify chain integrity.""" + if not self.finalized: + return False, "Chain not finalized" + + # Recompute Merkle root (including external roots) + hashes = [r.state_hash for r in self.records.values()] + hashes.extend(self.external_roots) # Must include external roots + computed_root = compute_merkle_root(hashes) + + if computed_root != self.merkle_root: + return False, f"Merkle root mismatch: {computed_root} != {self.merkle_root}" + + return True, None + + def link_external(self, external_merkle_root: str, source_id: str = None) -> None: + """ + Link this chain to another system's merkle root. + + This creates the neural internetwork - cryptographic proof + that this computation depended on another system's output. 
+ + Args: + external_merkle_root: The merkle root from the external system + source_id: Optional identifier of the source system + """ + if self.finalized: + raise ValueError("Cannot link external root to finalized chain") + self.external_roots.append(external_merkle_root) + + def get_lineage(self, layer_name: str) -> List[ProvenanceRecord]: + """Trace back from a layer to its ancestors.""" + if layer_name not in self.records: + return [] + + lineage = [] + current = self.records[layer_name] + visited = set() + + def trace_back(record: ProvenanceRecord): + if record.layer_name in visited: + return + visited.add(record.layer_name) + lineage.append(record) + + for parent_hash in record.parent_hashes: + # Find record with this hash + for r in self.records.values(): + if r.state_hash == parent_hash: + trace_back(r) + break + + trace_back(current) + return lineage + + def to_dict(self) -> Dict[str, Any]: + """Serialize entire chain.""" + return { + "session_id": self.session_id, + "model_id": self.model_id, + "model_hash": self.model_hash, + "input_hash": self.input_hash, + "output_hash": self.output_hash, + "external_roots": self.external_roots, # Inter-system links + "merkle_root": self.merkle_root, + "created_at": self.created_at, + "finalized": self.finalized, + "records": {k: v.to_dict() for k, v in self.records.items()} + } + + def to_json(self, indent: int = 2) -> str: + """Export as JSON.""" + return json.dumps(self.to_dict(), indent=indent) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'ProvenanceChain': + """Deserialize from dict.""" + records = OrderedDict() + for k, v in data.get("records", {}).items(): + records[k] = ProvenanceRecord.from_dict(v) + + chain = cls( + session_id=data["session_id"], + model_id=data["model_id"], + model_hash=data["model_hash"], + input_hash=data["input_hash"], + output_hash=data.get("output_hash"), + external_roots=data.get("external_roots", []), # Inter-system links + merkle_root=data.get("merkle_root"), + 
created_at=data.get("created_at", time.time()), + finalized=data.get("finalized", False), + ) + chain.records = records + return chain + + +# ============================================================================= +# HASHING FUNCTIONS +# ============================================================================= + +def hash_tensor(tensor, sample_size: int = 1000) -> str: + """ + Compute deterministic hash of tensor state. + + Samples tensor for efficiency - full hash would be too slow + for large activations. Sample is deterministic (first N elements + after flatten) so hash is reproducible. + + Args: + tensor: PyTorch tensor or numpy array + sample_size: Number of elements to sample + + Returns: + 16-character hex hash + """ + # Convert to numpy if needed + if hasattr(tensor, 'detach'): + # PyTorch tensor + arr = tensor.detach().cpu().float().numpy() + elif hasattr(tensor, 'numpy'): + arr = tensor.numpy() + else: + arr = np.array(tensor) + + # Flatten and sample + flat = arr.flatten() + sample = flat[:min(sample_size, len(flat))] + + # Hash the bytes + # Include shape in hash so same values in different shapes hash differently + shape_bytes = str(arr.shape).encode('utf-8') + tensor_bytes = sample.astype(np.float32).tobytes() + + combined = shape_bytes + tensor_bytes + return hashlib.sha256(combined).hexdigest()[:16] + + +def hash_params(module) -> str: + """ + Hash a module's parameters (weights, biases). + + This creates a frozen reference to the model state at observation time. + If weights change, this hash changes. + """ + param_hashes = [] + + for name, param in module.named_parameters(recurse=False): + if param is not None: + h = hash_tensor(param.data, sample_size=500) + param_hashes.append(f"{name}:{h}") + + if not param_hashes: + return "no_params" + + combined = "|".join(sorted(param_hashes)) + return hashlib.sha256(combined.encode()).hexdigest()[:16] + + +def hash_model(model) -> str: + """ + Hash entire model state. 
+ + This is the model's identity hash - changes if any weight changes. + """ + all_hashes = [] + + for name, param in model.named_parameters(): + h = hash_tensor(param.data, sample_size=100) + all_hashes.append(f"{name}:{h}") + + combined = "|".join(all_hashes) + return hashlib.sha256(combined.encode()).hexdigest()[:32] + + +def hash_input(data: Any) -> str: + """ + Hash input data (text, tokens, images, etc). + """ + if isinstance(data, str): + return hashlib.sha256(data.encode('utf-8')).hexdigest()[:16] + elif hasattr(data, 'detach'): + return hash_tensor(data) + elif isinstance(data, dict): + # Tokenizer output + combined = json.dumps({k: str(v) for k, v in sorted(data.items())}) + return hashlib.sha256(combined.encode()).hexdigest()[:16] + else: + return hashlib.sha256(str(data).encode()).hexdigest()[:16] + + +def compute_merkle_root(hashes: List[str]) -> str: + """ + Compute Merkle root from list of hashes. + + Standard Merkle tree construction - pairs hashes bottom-up + until single root remains. + """ + if not hashes: + return hashlib.sha256(b"empty").hexdigest()[:16] + + if len(hashes) == 1: + return hashes[0] + + # Pad to even length + if len(hashes) % 2 == 1: + hashes = hashes + [hashes[-1]] + + # Compute next level + next_level = [] + for i in range(0, len(hashes), 2): + combined = hashes[i] + hashes[i + 1] + next_hash = hashlib.sha256(combined.encode()).hexdigest()[:16] + next_level.append(next_hash) + + return compute_merkle_root(next_level) + + +# ============================================================================= +# PROVENANCE TRACKER (attaches to model) +# ============================================================================= + +class ProvenanceTracker: + """ + Tracks provenance during model forward pass. 
+ + Usage: + tracker = ProvenanceTracker(model, model_id="gpt2") + tracker.start_session(input_text) + + # Run forward pass - hooks capture everything + output = model(**inputs) + + chain = tracker.finalize_session() + print(chain.merkle_root) + + NEW: Now writes to tape file (JSONL) for redundant logging! + Correlative with the Live Tracer - both systems log independently. + """ + + def __init__(self, model, model_id: str, log_dir: str = "./logs"): + self.model = model + self.model_id = model_id + self.model_hash = hash_model(model) + + self.hooks = [] + self.current_chain: Optional[ProvenanceChain] = None + self.execution_counter = 0 + self.last_hash = None # Track for parent linking + self.layer_hashes: Dict[str, str] = {} # layer_name -> hash + + # === TAPE FILE FOR REDUNDANT LOGGING === + from pathlib import Path + from threading import Lock + self._log_dir = Path(log_dir) + self._log_dir.mkdir(parents=True, exist_ok=True) + self._session_id = int(time.time()) + self._tape_path = self._log_dir / f"provenance_tape_{self._session_id}.jsonl" + self._tape_file = None + self._tape_lock = Lock() + self._record_count = 0 + + def start_session(self, input_data: Any) -> str: + """Start a new provenance tracking session.""" + import uuid + + session_id = str(uuid.uuid4())[:8] + input_hash = hash_input(input_data) + + self.current_chain = ProvenanceChain( + session_id=session_id, + model_id=self.model_id, + model_hash=self.model_hash, + input_hash=input_hash + ) + + self.execution_counter = 0 + self.last_hash = input_hash + self.layer_hashes = {"input": input_hash} + + # Register hooks + self._register_hooks() + + return session_id + + def _register_hooks(self): + """Register forward hooks on all modules.""" + self._remove_hooks() # Clean up any existing + + for name, module in self.model.named_modules(): + if name: # Skip root + hook = module.register_forward_hook( + self._make_hook(name) + ) + self.hooks.append(hook) + + def _make_hook(self, layer_name: str): + 
"""Create a forward hook for a specific layer.""" + def hook(module, inp, out): + # Extract tensor + tensor = None + if hasattr(out, 'detach'): + tensor = out + elif isinstance(out, tuple) and len(out) > 0 and hasattr(out[0], 'detach'): + tensor = out[0] + elif hasattr(out, 'last_hidden_state'): + tensor = out.last_hidden_state + elif hasattr(out, 'logits'): + tensor = out.logits + + if tensor is None or not hasattr(tensor, 'numel') or tensor.numel() == 0: + return + + # Compute hashes + state_hash = hash_tensor(tensor) + params_hash = hash_params(module) + + # Determine parent hashes + # For now, use last layer's hash. More sophisticated: track actual data flow. + parent_hashes = [self.last_hash] if self.last_hash else [] + + # Compute stats + t = tensor.float() + stats = { + "mean": t.mean().item(), + "std": t.std().item(), + "min": t.min().item(), + "max": t.max().item(), + "sparsity": (tensor == 0).float().mean().item(), + } + + # Create record + record = ProvenanceRecord( + layer_name=layer_name, + layer_idx=self.execution_counter, + state_hash=state_hash, + parent_hashes=parent_hashes, + params_hash=params_hash, + shape=list(tensor.shape), + dtype=str(tensor.dtype), + stats=stats, + execution_order=self.execution_counter, + ) + + # Add to chain + if self.current_chain: + self.current_chain.add_record(record) + + # === WRITE TO TAPE (REDUNDANT LOGGING) === + self._write_to_tape(record) + + # Update tracking + self.last_hash = state_hash + self.layer_hashes[layer_name] = state_hash + self.execution_counter += 1 + self._record_count += 1 + + return hook + + def _write_to_tape(self, record: ProvenanceRecord): + """Write provenance record to tape file for redundant logging.""" + import json + try: + with self._tape_lock: + if self._tape_file is None: + self._tape_file = open(self._tape_path, "a", encoding="utf-8") + print(f"[CASCADE] 📼 Provenance tape started: {self._tape_path}") + + tape_record = { + "seq": self._record_count, + "record": record.to_dict(), + 
"session_id": self._session_id, + "model_id": self.model_id, + } + self._tape_file.write(json.dumps(tape_record, default=str) + "\n") + self._tape_file.flush() + except Exception as e: + pass # Don't let tape errors break the main flow + + def close_tape(self): + """Close the tape file.""" + with self._tape_lock: + if self._tape_file: + self._tape_file.close() + self._tape_file = None + print(f"[CASCADE] 📼 Provenance tape closed: {self._record_count} records → {self._tape_path}") + + def get_tape_path(self): + """Get the current tape file path.""" + return self._tape_path + + def _remove_hooks(self): + """Remove all registered hooks.""" + for hook in self.hooks: + hook.remove() + self.hooks = [] + + def finalize_session(self, output_data: Any = None) -> ProvenanceChain: + """Finalize session, compute Merkle root, return chain.""" + self._remove_hooks() + + if self.current_chain is None: + raise ValueError("No active session") + + if output_data is not None: + self.current_chain.output_hash = hash_input(output_data) + + self.current_chain.finalize() + + # Close tape (session complete) + self.close_tape() + + chain = self.current_chain + self.current_chain = None + + return chain + + +# ============================================================================= +# VERIFICATION & COMPARISON +# ============================================================================= + +def verify_chain(chain: ProvenanceChain) -> Tuple[bool, str]: + """Verify a provenance chain's integrity.""" + return chain.verify() + + +def compare_chains(chain_a: ProvenanceChain, chain_b: ProvenanceChain) -> Dict[str, Any]: + """ + Compare two provenance chains. + + Useful for: + - Same model, different inputs (where did outputs diverge?) 
+ - Different models, same input (structural comparison) + - Same everything (reproducibility check) + """ + result = { + "model_match": chain_a.model_hash == chain_b.model_hash, + "input_match": chain_a.input_hash == chain_b.input_hash, + "output_match": chain_a.output_hash == chain_b.output_hash, + "merkle_match": chain_a.merkle_root == chain_b.merkle_root, + "divergence_points": [], + "a_only_layers": [], + "b_only_layers": [], + "matching_layers": [], + } + + a_layers = set(chain_a.records.keys()) + b_layers = set(chain_b.records.keys()) + + result["a_only_layers"] = list(a_layers - b_layers) + result["b_only_layers"] = list(b_layers - a_layers) + + # Compare matching layers + for layer in a_layers & b_layers: + rec_a = chain_a.records[layer] + rec_b = chain_b.records[layer] + + if rec_a.state_hash == rec_b.state_hash: + result["matching_layers"].append(layer) + else: + result["divergence_points"].append({ + "layer": layer, + "hash_a": rec_a.state_hash, + "hash_b": rec_b.state_hash, + "stats_a": rec_a.stats, + "stats_b": rec_b.stats, + }) + + return result + + +def export_chain_for_audit(chain: ProvenanceChain, filepath: str) -> None: + """Export chain to file for external audit.""" + with open(filepath, 'w') as f: + f.write(chain.to_json(indent=2)) + + +def import_chain_for_audit(filepath: str) -> ProvenanceChain: + """Import chain from audit file.""" + with open(filepath, 'r') as f: + data = json.load(f) + return ProvenanceChain.from_dict(data) diff --git a/cascade/core/web3_bridge.py b/cascade/core/web3_bridge.py new file mode 100644 index 0000000000000000000000000000000000000000..70e059549a34e619ad0adbe4add3de08c311f0d6 --- /dev/null +++ b/cascade/core/web3_bridge.py @@ -0,0 +1,846 @@ +""" +CASCADE // WEB3 BRIDGE +Blockchain integration for AI provenance. + +The bridge between neural networks and decentralized infrastructure. 
+ +┌─────────────────────────────────────────────────────────────────┐ +│ THE IMMUTABLE RECORD │ +│ │ +│ AI Inference ──► Provenance Chain ──► Merkle Root ──► Chain │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────┐ │ +│ │ ETHEREUM / SOLANA / etc │ │ +│ │ ┌───────────────────────────┐ │ │ +│ │ │ Attestation Contract │ │ │ +│ │ │ - Model hash │ │ │ +│ │ │ - Input hash │ │ │ +│ │ │ - Merkle root │ │ │ +│ │ │ - Timestamp │ │ │ +│ │ └───────────────────────────┘ │ │ +│ └─────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ IPFS / Arweave / Filecoin │ +│ (Full provenance chain storage) │ +└─────────────────────────────────────────────────────────────────┘ + +Web3 provides: + - Timestamping (block finality) + - Immutability (blockchain consensus) + - Decentralized storage (IPFS) + - Public verifiability (anyone can audit) + - Economic incentives (staking, reputation) + +This module provides: + - EIP-712 typed data signatures (Ethereum standard) + - IPFS CID computation (content addressing) + - Smart contract ABI for attestation + - Multi-chain attestation format + - NFT metadata for provenance tokens +""" + +import hashlib +import json +import time +import struct +from typing import Dict, List, Optional, Any, Tuple +from dataclasses import dataclass, field, asdict +import base64 + +try: + from .provenance import ProvenanceChain, ProvenanceRecord, compute_merkle_root +except ImportError: + from provenance import ProvenanceChain, ProvenanceRecord, compute_merkle_root + + +# ============================================================================= +# CONSTANTS +# ============================================================================= + +# EIP-712 Domain for CASCADE attestations +CASCADE_DOMAIN = { + "name": "CASCADE Provenance", + "version": "1", + "chainId": 1, # Ethereum mainnet, override for other chains + "verifyingContract": "0x0000000000000000000000000000000000000000", # Set on deployment +} + +# Attestation type definition for EIP-712 +ATTESTATION_TYPES 
= { + "Attestation": [ + {"name": "model_hash", "type": "bytes32"}, + {"name": "input_hash", "type": "bytes32"}, + {"name": "merkle_root", "type": "bytes32"}, + {"name": "timestamp", "type": "uint256"}, + {"name": "session_id", "type": "string"}, + {"name": "layer_count", "type": "uint256"}, + ] +} + + +# ============================================================================= +# ATTESTATION RECORD +# ============================================================================= + +@dataclass +class Web3Attestation: + """ + Blockchain-ready attestation of AI inference provenance. + + This is the "receipt" that can be posted on-chain. + Minimal data for on-chain storage, full data on IPFS. + """ + + # Core identity + model_hash: str # 32-byte hash of model weights + input_hash: str # 32-byte hash of input data + output_hash: str # 32-byte hash of output + merkle_root: str # Merkle root of provenance chain + + # Metadata + session_id: str # Unique session identifier + timestamp: int # Unix timestamp + layer_count: int # Number of layers in chain + + # Content addressing + ipfs_cid: Optional[str] = None # IPFS CID for full chain + arweave_id: Optional[str] = None # Arweave transaction ID + + # Signatures (set by wallet) + signature: Optional[str] = None # EIP-712 signature + signer: Optional[str] = None # Ethereum address + + # Chain info + chain_id: int = 1 # 1=Ethereum, 137=Polygon, etc. + contract_address: Optional[str] = None + tx_hash: Optional[str] = None # Transaction hash after posting + + def to_eip712_message(self, domain: Optional[Dict] = None) -> Dict[str, Any]: + """ + Format as EIP-712 typed data for signing. + + This is the standard Ethereum signing format that wallets understand. 
+ """ + domain = domain or CASCADE_DOMAIN + + return { + "types": { + "EIP712Domain": [ + {"name": "name", "type": "string"}, + {"name": "version", "type": "string"}, + {"name": "chainId", "type": "uint256"}, + {"name": "verifyingContract", "type": "address"}, + ], + **ATTESTATION_TYPES + }, + "primaryType": "Attestation", + "domain": domain, + "message": { + "model_hash": self._to_bytes32(self.model_hash), + "input_hash": self._to_bytes32(self.input_hash), + "merkle_root": self._to_bytes32(self.merkle_root), + "timestamp": self.timestamp, + "session_id": self.session_id, + "layer_count": self.layer_count, + } + } + + def _to_bytes32(self, hex_str: str) -> str: + """Pad hash to bytes32 format.""" + # Remove 0x prefix if present + clean = hex_str.replace("0x", "") + # Pad to 64 chars (32 bytes) + padded = clean.zfill(64) + return "0x" + padded + + def to_contract_args(self) -> Tuple: + """ + Format for smart contract function call. + + Returns tuple matching: + function attest(bytes32 modelHash, bytes32 inputHash, bytes32 merkleRoot, + string memory sessionId, uint256 layerCount) + """ + return ( + bytes.fromhex(self.model_hash.replace("0x", "").zfill(64)), + bytes.fromhex(self.input_hash.replace("0x", "").zfill(64)), + bytes.fromhex(self.merkle_root.replace("0x", "").zfill(64)), + self.session_id, + self.layer_count, + ) + + def to_dict(self) -> Dict[str, Any]: + """Serialize for storage/transmission.""" + return asdict(self) + + def to_json(self) -> str: + """JSON export.""" + return json.dumps(self.to_dict(), indent=2) + + @classmethod + def from_chain(cls, chain: ProvenanceChain) -> 'Web3Attestation': + """Create attestation from provenance chain.""" + if not chain.finalized: + chain.finalize() + + return cls( + model_hash=chain.model_hash, + input_hash=chain.input_hash, + output_hash=chain.output_hash or "0" * 16, + merkle_root=chain.merkle_root, + session_id=chain.session_id, + timestamp=int(chain.created_at), + layer_count=len(chain.records), + ) + + +# 
============================================================================= +# IPFS CONTENT ADDRESSING +# ============================================================================= + +def compute_ipfs_cid_v0(data: bytes) -> str: + """ + Compute IPFS CID v0 (Qm...) for data. + + This is a simplified computation - actual IPFS uses more complex + chunking for large files. Suitable for JSON chain data. + + CIDv0 format: Base58(0x12 || 0x20 || SHA256(data)) + """ + # SHA-256 hash + sha_hash = hashlib.sha256(data).digest() + + # Multihash prefix: 0x12 (sha2-256), 0x20 (32 bytes) + multihash = bytes([0x12, 0x20]) + sha_hash + + # Base58 encode (Bitcoin alphabet) + return base58_encode(multihash) + + +def compute_ipfs_cid_v1(data: bytes) -> str: + """ + Compute IPFS CID v1 (bafy...) for data. + + CIDv1 format: multibase || version || codec || multihash + """ + # SHA-256 hash + sha_hash = hashlib.sha256(data).digest() + + # Build CIDv1: + # 0x01 = CID version 1 + # 0x55 = raw binary codec (could also use 0x71 for dag-cbor) + # 0x12 = sha2-256 + # 0x20 = 32 bytes + cid_bytes = bytes([0x01, 0x55, 0x12, 0x20]) + sha_hash + + # Base32 lower with 'b' prefix (multibase) + import base64 + b32 = base64.b32encode(cid_bytes).decode('ascii').lower().rstrip('=') + return 'b' + b32 + + +def base58_encode(data: bytes) -> str: + """Base58 encoding (Bitcoin alphabet).""" + ALPHABET = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz" + + # Count leading zeros + leading_zeros = 0 + for byte in data: + if byte == 0: + leading_zeros += 1 + else: + break + + # Convert to integer + num = int.from_bytes(data, 'big') + + # Convert to base58 + result = "" + while num > 0: + num, remainder = divmod(num, 58) + result = ALPHABET[remainder] + result + + # Add leading '1's for each leading zero byte + return '1' * leading_zeros + result + + +def chain_to_ipfs_ready(chain: ProvenanceChain) -> Tuple[bytes, str]: + """ + Prepare provenance chain for IPFS upload. 
+ + Returns: + (data_bytes, cid) - The data to upload and its expected CID + """ + json_data = chain.to_json().encode('utf-8') + cid = compute_ipfs_cid_v0(json_data) + return json_data, cid + + +# ============================================================================= +# SMART CONTRACT ABI +# ============================================================================= + +CASCADE_ATTESTATION_ABI = [ + { + "name": "Attest", + "type": "event", + "inputs": [ + {"name": "attester", "type": "address", "indexed": True}, + {"name": "modelHash", "type": "bytes32", "indexed": True}, + {"name": "merkleRoot", "type": "bytes32", "indexed": False}, + {"name": "sessionId", "type": "string", "indexed": False}, + {"name": "timestamp", "type": "uint256", "indexed": False}, + ] + }, + { + "name": "attest", + "type": "function", + "stateMutability": "nonpayable", + "inputs": [ + {"name": "modelHash", "type": "bytes32"}, + {"name": "inputHash", "type": "bytes32"}, + {"name": "merkleRoot", "type": "bytes32"}, + {"name": "sessionId", "type": "string"}, + {"name": "layerCount", "type": "uint256"}, + ], + "outputs": [{"name": "attestationId", "type": "uint256"}] + }, + { + "name": "verify", + "type": "function", + "stateMutability": "view", + "inputs": [ + {"name": "attestationId", "type": "uint256"}, + ], + "outputs": [ + {"name": "valid", "type": "bool"}, + {"name": "attester", "type": "address"}, + {"name": "modelHash", "type": "bytes32"}, + {"name": "merkleRoot", "type": "bytes32"}, + ] + }, + { + "name": "getAttestation", + "type": "function", + "stateMutability": "view", + "inputs": [ + {"name": "attestationId", "type": "uint256"}, + ], + "outputs": [ + {"name": "attester", "type": "address"}, + {"name": "modelHash", "type": "bytes32"}, + {"name": "inputHash", "type": "bytes32"}, + {"name": "merkleRoot", "type": "bytes32"}, + {"name": "sessionId", "type": "string"}, + {"name": "layerCount", "type": "uint256"}, + {"name": "timestamp", "type": "uint256"}, + ] + }, + { + "name": 
"attestationsByModel", + "type": "function", + "stateMutability": "view", + "inputs": [ + {"name": "modelHash", "type": "bytes32"}, + ], + "outputs": [ + {"name": "attestationIds", "type": "uint256[]"}, + ] + }, +] + + +# Solidity source for the attestation contract +CASCADE_ATTESTATION_SOLIDITY = ''' +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.19; + +/** + * @title CascadeAttestation + * @notice On-chain attestation of AI inference provenance + * @dev Stores Merkle roots for off-chain provenance chains + */ +contract CascadeAttestation { + + struct Attestation { + address attester; + bytes32 modelHash; + bytes32 inputHash; + bytes32 merkleRoot; + string sessionId; + uint256 layerCount; + uint256 timestamp; + string ipfsCid; // Optional: full chain on IPFS + } + + // Attestation storage + mapping(uint256 => Attestation) public attestations; + uint256 public attestationCount; + + // Index by model + mapping(bytes32 => uint256[]) public attestationsByModel; + + // Index by attester + mapping(address => uint256[]) public attestationsByAttester; + + // Events + event Attested( + uint256 indexed attestationId, + address indexed attester, + bytes32 indexed modelHash, + bytes32 merkleRoot, + string sessionId + ); + + /** + * @notice Create a new attestation + * @param modelHash Hash of the model weights + * @param inputHash Hash of the input data + * @param merkleRoot Merkle root of the provenance chain + * @param sessionId Unique session identifier + * @param layerCount Number of layers in the chain + * @return attestationId The ID of the new attestation + */ + function attest( + bytes32 modelHash, + bytes32 inputHash, + bytes32 merkleRoot, + string memory sessionId, + uint256 layerCount + ) external returns (uint256 attestationId) { + attestationId = attestationCount++; + + attestations[attestationId] = Attestation({ + attester: msg.sender, + modelHash: modelHash, + inputHash: inputHash, + merkleRoot: merkleRoot, + sessionId: sessionId, + layerCount: 
layerCount, + timestamp: block.timestamp, + ipfsCid: "" + }); + + attestationsByModel[modelHash].push(attestationId); + attestationsByAttester[msg.sender].push(attestationId); + + emit Attested(attestationId, msg.sender, modelHash, merkleRoot, sessionId); + + return attestationId; + } + + /** + * @notice Attest with IPFS CID for full chain data + */ + function attestWithIPFS( + bytes32 modelHash, + bytes32 inputHash, + bytes32 merkleRoot, + string memory sessionId, + uint256 layerCount, + string memory ipfsCid + ) external returns (uint256 attestationId) { + attestationId = this.attest(modelHash, inputHash, merkleRoot, sessionId, layerCount); + attestations[attestationId].ipfsCid = ipfsCid; + return attestationId; + } + + /** + * @notice Verify an attestation exists and return core data + */ + function verify(uint256 attestationId) external view returns ( + bool valid, + address attester, + bytes32 modelHash, + bytes32 merkleRoot + ) { + if (attestationId >= attestationCount) { + return (false, address(0), bytes32(0), bytes32(0)); + } + + Attestation storage a = attestations[attestationId]; + return (true, a.attester, a.modelHash, a.merkleRoot); + } + + /** + * @notice Get all attestations for a model + */ + function getModelAttestations(bytes32 modelHash) external view returns (uint256[] memory) { + return attestationsByModel[modelHash]; + } + + /** + * @notice Get all attestations by an address + */ + function getAttesterAttestations(address attester) external view returns (uint256[] memory) { + return attestationsByAttester[attester]; + } +} +''' + + +# ============================================================================= +# NFT METADATA (for provenance tokens) +# ============================================================================= + +def generate_nft_metadata(chain: ProvenanceChain, + image_url: Optional[str] = None, + animation_url: Optional[str] = None) -> Dict[str, Any]: + """ + Generate ERC-721 compatible metadata for a provenance NFT. 
+ + Each unique model×input×output combination could be an NFT, + proving that this specific inference happened. + """ + if not chain.finalized: + chain.finalize() + + # Generate attributes from chain + attributes = [ + {"trait_type": "Model Hash", "value": chain.model_hash[:16]}, + {"trait_type": "Input Hash", "value": chain.input_hash}, + {"trait_type": "Merkle Root", "value": chain.merkle_root}, + {"trait_type": "Layer Count", "value": len(chain.records)}, + {"trait_type": "Timestamp", "value": int(chain.created_at)}, + ] + + # Add layer statistics as traits + if chain.records: + total_params = 0 + layer_types = set() + for record in chain.records.values(): + if record.params_hash != "no_params": + total_params += 1 + # Extract layer type from name + parts = record.layer_name.split('.') + if len(parts) >= 2: + layer_types.add(parts[-1]) + + attributes.append({"trait_type": "Parameterized Layers", "value": total_params}) + for lt in list(layer_types)[:5]: # Max 5 layer types + attributes.append({"trait_type": f"Has {lt}", "value": "Yes"}) + + return { + "name": f"CASCADE Provenance #{chain.session_id}", + "description": f"Cryptographic proof of AI inference. Model: {chain.model_id}. 
" + f"This NFT attests that a specific input was processed through " + f"the model, producing a verifiable Merkle root of all layer activations.", + "image": image_url or "ipfs://QmDefaultCascadeImage", # Placeholder + "animation_url": animation_url, # Could link to 3D visualization + "external_url": f"https://cascade.ai/verify/{chain.session_id}", + "attributes": attributes, + "properties": { + "model_id": chain.model_id, + "model_hash": chain.model_hash, + "input_hash": chain.input_hash, + "output_hash": chain.output_hash, + "merkle_root": chain.merkle_root, + "session_id": chain.session_id, + "layer_count": len(chain.records), + "created_at": chain.created_at, + } + } + + +# ============================================================================= +# MULTI-CHAIN SUPPORT +# ============================================================================= + +CHAIN_CONFIGS = { + "ethereum": { + "chain_id": 1, + "name": "Ethereum Mainnet", + "explorer": "https://etherscan.io", + "native_token": "ETH", + }, + "polygon": { + "chain_id": 137, + "name": "Polygon", + "explorer": "https://polygonscan.com", + "native_token": "MATIC", + }, + "arbitrum": { + "chain_id": 42161, + "name": "Arbitrum One", + "explorer": "https://arbiscan.io", + "native_token": "ETH", + }, + "optimism": { + "chain_id": 10, + "name": "Optimism", + "explorer": "https://optimistic.etherscan.io", + "native_token": "ETH", + }, + "base": { + "chain_id": 8453, + "name": "Base", + "explorer": "https://basescan.org", + "native_token": "ETH", + }, + "solana": { + "chain_id": -1, # Not EVM + "name": "Solana", + "explorer": "https://solscan.io", + "native_token": "SOL", + }, +} + + +def get_chain_config(chain_name: str) -> Dict[str, Any]: + """Get configuration for a specific blockchain.""" + return CHAIN_CONFIGS.get(chain_name.lower(), CHAIN_CONFIGS["ethereum"]) + + +# ============================================================================= +# WEB3 EXPORT UTILITIES +# 
============================================================================= + +def export_for_web3(chain: ProvenanceChain, + chain_name: str = "ethereum", + include_full_chain: bool = True) -> Dict[str, Any]: + """ + Export provenance chain in Web3-ready format. + + Returns everything needed to post attestation on-chain. + """ + attestation = Web3Attestation.from_chain(chain) + chain_config = get_chain_config(chain_name) + + result = { + "attestation": attestation.to_dict(), + "eip712": attestation.to_eip712_message({ + **CASCADE_DOMAIN, + "chainId": chain_config["chain_id"] + }), + "contract_abi": CASCADE_ATTESTATION_ABI, + "chain_config": chain_config, + } + + if include_full_chain: + data, cid = chain_to_ipfs_ready(chain) + result["ipfs"] = { + "data": base64.b64encode(data).decode('ascii'), + "cid": cid, + "size_bytes": len(data), + } + + return result + + +def generate_verification_page(attestation: Web3Attestation, + chain: Optional[ProvenanceChain] = None) -> str: + """ + Generate an HTML verification page for an attestation. + + This can be hosted anywhere and allows public verification. + """ + records_html = "" + if chain: + for record in chain.records.values(): + records_html += f""" + + {record.layer_name} + {record.state_hash} + {record.shape} + {record.stats.get('mean', 0):.4f} + + """ + + return f""" + + + CASCADE Provenance Verification + + + + +
+

🔗 CASCADE Provenance Verification

+ +
+ Merkle Root: {attestation.merkle_root} +
+ +

Attestation Details

+

Session ID

+
{attestation.session_id}
+ +

Model Hash

+
{attestation.model_hash}
+ +

Input Hash

+
{attestation.input_hash}
+ +

Output Hash

+
{attestation.output_hash}
+ +

Timestamp

+
{attestation.timestamp} ({time.strftime('%Y-%m-%d %H:%M:%S UTC', time.gmtime(attestation.timestamp))})
+ +

Layer Count

+
{attestation.layer_count} layers
+ + {"

Provenance Chain

" + records_html + "
LayerState HashShapeMean
" if chain else ""} + +

On-Chain Verification

+

{"✓ Verified on " + get_chain_config('ethereum')['name'] + "" if attestation.tx_hash else "⏳ Pending on-chain attestation"}

+ {f"

Transaction

" if attestation.tx_hash else ""} + +

IPFS Storage

+

{f"{attestation.ipfs_cid}" if attestation.ipfs_cid else "Full chain not yet pinned to IPFS"}

+ +
+

CASCADE Provenance Engine • Due process infrastructure for AI

+
+ + +""" + + +# ============================================================================= +# SIGNATURE UTILITIES (for wallet integration) +# ============================================================================= + +def prepare_for_signing(attestation: Web3Attestation, + chain_name: str = "ethereum") -> Dict[str, Any]: + """ + Prepare attestation for wallet signing (MetaMask, etc). + + Returns the EIP-712 message that wallets can sign. + """ + chain_config = get_chain_config(chain_name) + + eip712 = attestation.to_eip712_message({ + **CASCADE_DOMAIN, + "chainId": chain_config["chain_id"] + }) + + return { + "method": "eth_signTypedData_v4", + "params": [ + None, # Address filled by wallet + json.dumps(eip712) + ], + "display": { + "title": "Sign CASCADE Attestation", + "description": f"Attest that model {attestation.model_hash[:16]}... " + f"processed input {attestation.input_hash[:16]}...", + "merkle_root": attestation.merkle_root, + } + } + + +def verify_signature(attestation: Web3Attestation, + signature: str, + expected_signer: str) -> Tuple[bool, str]: + """ + Verify an EIP-712 signature. + + Note: Full verification requires eth_utils/web3.py. + This is a structural check only. 
+ """ + if not signature or len(signature) < 130: + return False, "Invalid signature length" + + if not signature.startswith("0x"): + return False, "Signature must start with 0x" + + # Extract r, s, v components + try: + sig_bytes = bytes.fromhex(signature[2:]) + if len(sig_bytes) != 65: + return False, f"Signature must be 65 bytes, got {len(sig_bytes)}" + + r = sig_bytes[:32] + s = sig_bytes[32:64] + v = sig_bytes[64] + + # v should be 27 or 28 (or 0/1 for some implementations) + if v not in [0, 1, 27, 28]: + return False, f"Invalid v value: {v}" + + # Structural validation passed + # Full cryptographic verification requires ecrecover + return True, "Signature structure valid (full verification requires web3.py)" + + except Exception as e: + return False, f"Signature parsing error: {str(e)}" + + +# ============================================================================= +# CONVENIENCE FUNCTIONS +# ============================================================================= + +def attest_inference(chain: ProvenanceChain, + chain_name: str = "ethereum") -> Web3Attestation: + """ + One-liner to create attestation from provenance chain. + + Usage: + attestation = attest_inference(chain) + print(attestation.merkle_root) + """ + if not chain.finalized: + chain.finalize() + + attestation = Web3Attestation.from_chain(chain) + + # Compute IPFS CID + data, cid = chain_to_ipfs_ready(chain) + attestation.ipfs_cid = cid + + # Set chain + attestation.chain_id = get_chain_config(chain_name)["chain_id"] + + return attestation + + +def quick_verify(merkle_root: str, layer_hashes: List[str]) -> bool: + """ + Quick verification that layer hashes produce expected Merkle root. 
+ """ + computed = compute_merkle_root(layer_hashes) + return computed == merkle_root + + +# ============================================================================= +# COMMAND LINE INTERFACE +# ============================================================================= + +if __name__ == "__main__": + import sys + + print("CASCADE // WEB3 BRIDGE") + print("=" * 50) + print() + print("Smart Contract (Solidity):") + print("-" * 50) + print(CASCADE_ATTESTATION_SOLIDITY[:500] + "...") + print() + print("Contract ABI:") + print("-" * 50) + print(json.dumps(CASCADE_ATTESTATION_ABI, indent=2)[:500] + "...") + print() + print("Supported Chains:") + print("-" * 50) + for name, config in CHAIN_CONFIGS.items(): + print(f" {name}: Chain ID {config['chain_id']}") + print() + print("Usage:") + print(" from cascade.core.web3_bridge import attest_inference, export_for_web3") + print(" attestation = attest_inference(provenance_chain)") + print(" web3_data = export_for_web3(provenance_chain, 'polygon')") diff --git a/cascade/data/__init__.py b/cascade/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8d21a5767c82184de894016faa48bb2257cb0ee4 --- /dev/null +++ b/cascade/data/__init__.py @@ -0,0 +1,112 @@ +""" +CASCADE Data Observatory + +Dataset observation with the same rigor as model observation. +Tracks provenance, schema, lineage using W3C PROV-O standard. 
+""" + +from .entities import ( + DatasetEntity, + Activity, + Agent, + Relationship, + RelationType, + ActivityType, + AgentType, + create_system_agent, + create_model_agent, + create_user_agent, +) +from .observer import DatasetObserver, ObservationContext +from .provenance import ProvenanceGraph +from .schema import SchemaObserver, DatasetSchema, FieldSchema, hash_content +from .croissant import CroissantExporter, export_to_croissant +from .hub import HubIntegration, AccountabilityBundle, push_to_hub, pull_from_hub +from .license import ( + SPDXLicense, + LicenseCategory, + LicenseRestriction, + LicenseCompatibility, + LicenseAnalyzer, + SPDX_LICENSES, + get_license, + check_license_compatibility, + get_derived_license, +) +from .pii import ( + PIIType, + PIISeverity, + PIIMatch, + PIIScanResult, + PIIScanner, + scan_for_pii, + quick_pii_check, +) +from .live import ( + LiveDocumentTracer, + TraceEvent, + TraceEventType, + DocumentSpan, + DocumentAssociation, + ConsoleTraceRenderer, + create_live_tracer, +) + +__all__ = [ + # Entities (PROV-O) + "DatasetEntity", + "Activity", + "Agent", + "Relationship", + "RelationType", + "ActivityType", + "AgentType", + "create_system_agent", + "create_model_agent", + "create_user_agent", + # Observer + "DatasetObserver", + "ObservationContext", + # Provenance + "ProvenanceGraph", + # Schema + "SchemaObserver", + "DatasetSchema", + "FieldSchema", + "hash_content", + # Export + "CroissantExporter", + "export_to_croissant", + # Accountability + "AccountabilityBundle", + # Hub + "HubIntegration", + "push_to_hub", + "pull_from_hub", + # License + "SPDXLicense", + "LicenseCategory", + "LicenseRestriction", + "LicenseCompatibility", + "LicenseAnalyzer", + "SPDX_LICENSES", + "get_license", + "check_license_compatibility", + "get_derived_license", + # PII Detection + "PIIType", + "PIISeverity", + "PIIMatch", + "PIIScanResult", + "PIIScanner", + "scan_for_pii", + "quick_pii_check", + # Live Document Tracing + "LiveDocumentTracer", + 
"TraceEvent", + "TraceEventType", + "DocumentSpan", + "DocumentAssociation", + "ConsoleTraceRenderer", + "create_live_tracer", +] diff --git a/cascade/data/croissant.py b/cascade/data/croissant.py new file mode 100644 index 0000000000000000000000000000000000000000..e90ad4f9875cb6e92f5b5751e0099117fa330ca3 --- /dev/null +++ b/cascade/data/croissant.py @@ -0,0 +1,289 @@ +""" +Croissant Exporter + +Exports provenance graph to MLCommons Croissant format. +Croissant is the emerging standard for ML dataset metadata. + +Reference: https://github.com/mlcommons/croissant +""" + +import json +import time +from datetime import datetime +from typing import Any, Dict, List, Optional + +from .entities import DatasetEntity, Activity, Agent +from .provenance import ProvenanceGraph + + +class CroissantExporter: + """ + Export provenance to Croissant JSON-LD format. + + Croissant layers: + 1. Metadata - description, license, citation + 2. Resources - file descriptions + 3. Structure - record sets and fields + 4. ML Semantics - task types, splits + + We add provenance as an extension. + """ + + CROISSANT_VERSION = "1.0" + CROISSANT_CONTEXT = "http://mlcommons.org/croissant/1.0" + + def __init__(self, graph: ProvenanceGraph): + self.graph = graph + + def export( + self, + name: str = None, + description: str = None, + license_url: str = None, + citation: str = None, + url: str = None, + include_provenance: bool = True, + ) -> Dict[str, Any]: + """ + Export to Croissant JSON-LD. 
+ + Args: + name: Dataset name (defaults to graph name) + description: Dataset description + license_url: License URL + citation: Citation text + url: Dataset URL + include_provenance: Whether to include CASCADE provenance extension + + Returns: + Croissant JSON-LD document + """ + name = name or self.graph.name + + doc = { + "@context": { + "@vocab": "http://schema.org/", + "sc": "http://schema.org/", + "cr": "http://mlcommons.org/croissant/", + "rai": "http://mlcommons.org/croissant/RAI/", + "spdx": "http://spdx.org/rdf/terms#", + }, + "@type": "sc:Dataset", + "name": name, + "conformsTo": self.CROISSANT_CONTEXT, + "dateCreated": datetime.fromtimestamp(self.graph.created_at).isoformat(), + "dateModified": datetime.now().isoformat(), + } + + if description: + doc["description"] = description + if license_url: + doc["license"] = license_url + if citation: + doc["citation"] = citation + if url: + doc["url"] = url + + # Add distributions (file objects) + doc["distribution"] = self._build_distributions() + + # Add record sets + doc["recordSet"] = self._build_record_sets() + + # Add provenance extension + if include_provenance: + doc["cr:provenance"] = self._build_provenance_extension() + + return doc + + def _build_distributions(self) -> List[Dict[str, Any]]: + """Build distribution (FileObject) entries.""" + distributions = [] + + for entity in self.graph.list_entities(): + dist = { + "@type": "cr:FileObject", + "@id": entity.id, + "name": entity.name, + } + + if entity.source_uri: + dist["contentUrl"] = entity.source_uri + + if entity.content_hash: + dist["sha256"] = entity.content_hash + + # License information (SPDX) + if entity.license_id: + dist["spdx:license"] = entity.license_id + if entity.license_url: + dist["sc:license"] = entity.license_url + else: + # Auto-generate SPDX license URL + dist["sc:license"] = f"https://spdx.org/licenses/{entity.license_id}.html" + + # Infer encoding format from source type + format_map = { + "hf_dataset": 
"application/x-arrow", + "hf_hub": "application/x-arrow", + "parquet": "application/x-parquet", + "csv": "text/csv", + "json": "application/json", + "jsonl": "application/x-jsonlines", + } + if entity.source_type in format_map: + dist["encodingFormat"] = format_map[entity.source_type] + + if entity.size_bytes: + dist["contentSize"] = f"{entity.size_bytes} bytes" + + distributions.append(dist) + + return distributions + + def _build_record_sets(self) -> List[Dict[str, Any]]: + """Build RecordSet entries from entity schemas.""" + record_sets = [] + + for entity in self.graph.list_entities(): + schema = entity.attributes.get("schema") + if not schema: + continue + + fields = [] + for field_name, field_info in schema.get("fields", {}).items(): + field_entry = { + "@type": "cr:Field", + "name": field_name, + "dataType": self._map_dtype_to_croissant(field_info.get("dtype", "string")), + } + + if field_info.get("description"): + field_entry["description"] = field_info["description"] + + # Source reference + field_entry["source"] = { + "fileObject": {"@id": entity.id}, + "extract": {"column": field_name}, + } + + fields.append(field_entry) + + if fields: + record_set = { + "@type": "cr:RecordSet", + "@id": f"recordset_{entity.id}", + "name": f"{entity.name}_records", + "field": fields, + } + + if entity.record_count: + record_set["cr:recordCount"] = entity.record_count + + record_sets.append(record_set) + + return record_sets + + def _map_dtype_to_croissant(self, dtype: str) -> str: + """Map internal dtype to Croissant/schema.org type.""" + type_map = { + "string": "sc:Text", + "int8": "sc:Integer", + "int16": "sc:Integer", + "int32": "sc:Integer", + "int64": "sc:Integer", + "uint8": "sc:Integer", + "uint16": "sc:Integer", + "uint32": "sc:Integer", + "uint64": "sc:Integer", + "float16": "sc:Float", + "float32": "sc:Float", + "float64": "sc:Float", + "bool": "sc:Boolean", + "binary": "sc:Text", # Base64 encoded + "image": "sc:ImageObject", + "audio": "sc:AudioObject", + 
"categorical": "sc:Text", # With enumeration + "list": "sc:ItemList", + "struct": "sc:StructuredValue", + } + return type_map.get(dtype, "sc:Text") + + def _build_provenance_extension(self) -> Dict[str, Any]: + """Build CASCADE provenance extension.""" + return { + "@type": "cascade:ProvenanceGraph", + "cascade:rootHash": self.graph.root_hash, + "cascade:createdAt": datetime.fromtimestamp(self.graph.created_at).isoformat(), + + # Entities with lineage + "cascade:entities": [ + { + "@id": e.id, + "cascade:name": e.name, + "cascade:contentHash": e.content_hash, + "cascade:schemaHash": e.schema_hash, + "cascade:version": e.version, + "cascade:recordCount": e.record_count, + "cascade:derivedFrom": self.graph.get_lineage(e.id, "upstream"), + } + for e in self.graph.list_entities() + ], + + # Activities + "cascade:activities": [ + { + "@id": a.id, + "cascade:type": a.activity_type.value, + "cascade:name": a.name, + "cascade:startedAt": datetime.fromtimestamp(a.started_at).isoformat() if a.started_at else None, + "cascade:endedAt": datetime.fromtimestamp(a.ended_at).isoformat() if a.ended_at else None, + "cascade:inputs": a.inputs, + "cascade:outputs": a.outputs, + "cascade:parameters": a.parameters, + } + for a in self.graph.list_activities() + ], + + # Agents + "cascade:agents": [ + { + "@id": a.id, + "cascade:type": a.agent_type.value, + "cascade:name": a.name, + "cascade:version": a.version, + } + for a in self.graph.list_agents() + ], + } + + def to_json(self, **kwargs) -> str: + """Export to JSON string.""" + return json.dumps(self.export(**kwargs), indent=2, default=str) + + def save(self, path: str, **kwargs): + """Save to file.""" + with open(path, "w", encoding="utf-8") as f: + f.write(self.to_json(**kwargs)) + + +def export_to_croissant( + graph: ProvenanceGraph, + name: str = None, + description: str = None, + **kwargs, +) -> Dict[str, Any]: + """ + Convenience function to export provenance to Croissant. 
+ + Args: + graph: The provenance graph to export + name: Dataset name + description: Dataset description + **kwargs: Additional export options + + Returns: + Croissant JSON-LD document + """ + exporter = CroissantExporter(graph) + return exporter.export(name=name, description=description, **kwargs) diff --git a/cascade/data/entities.py b/cascade/data/entities.py new file mode 100644 index 0000000000000000000000000000000000000000..8950ffe6812e53c00318de7544419a256928244e --- /dev/null +++ b/cascade/data/entities.py @@ -0,0 +1,349 @@ +""" +PROV-O Entities for Dataset Observation + +W3C PROV Data Model: +- Entity: A physical, digital, or conceptual thing (the dataset) +- Activity: Something that occurs over time and acts upon entities +- Agent: Something that bears responsibility for an activity + +Relationships: +- wasGeneratedBy: Entity → Activity +- wasDerivedFrom: Entity → Entity +- wasAttributedTo: Entity → Agent +- used: Activity → Entity +- wasAssociatedWith: Activity → Agent +""" + +import hashlib +import json +import time +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Optional, Union + + +class RelationType(Enum): + """W3C PROV-O relationship types.""" + # Entity relationships + WAS_GENERATED_BY = "wasGeneratedBy" # Entity → Activity + WAS_DERIVED_FROM = "wasDerivedFrom" # Entity → Entity + WAS_ATTRIBUTED_TO = "wasAttributedTo" # Entity → Agent + WAS_REVISION_OF = "wasRevisionOf" # Entity → Entity (versioning) + HAD_PRIMARY_SOURCE = "hadPrimarySource" # Entity → Entity + + # Activity relationships + USED = "used" # Activity → Entity + WAS_ASSOCIATED_WITH = "wasAssociatedWith" # Activity → Agent + WAS_INFORMED_BY = "wasInformedBy" # Activity → Activity + WAS_STARTED_BY = "wasStartedBy" # Activity → Entity + WAS_ENDED_BY = "wasEndedBy" # Activity → Entity + + # Agent relationships + ACTED_ON_BEHALF_OF = "actedOnBehalfOf" # Agent → Agent + + +@dataclass +class Relationship: + 
"""A provenance relationship between two nodes.""" + relation_type: RelationType + source_id: str + target_id: str + timestamp: float = field(default_factory=time.time) + attributes: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + return { + "type": self.relation_type.value, + "source": self.source_id, + "target": self.target_id, + "timestamp": self.timestamp, + "attributes": self.attributes, + } + + def to_prov_n(self) -> str: + """Export as PROV-N notation.""" + return f"{self.relation_type.value}({self.source_id}, {self.target_id})" + + +@dataclass +class DatasetEntity: + """ + A dataset entity in the provenance graph. + + Corresponds to prov:Entity - any physical, digital, or conceptual thing. + In our case: a dataset, a version of a dataset, or a split. + """ + id: str + name: str + + # Content identification + content_hash: Optional[str] = None # SHA-256 of data content + schema_hash: Optional[str] = None # SHA-256 of schema/features + + # Versioning + version: Optional[str] = None + previous_version: Optional[str] = None + + # Source + source_type: str = "unknown" # hf_hub, local, s3, gcs, etc. 
+ source_uri: Optional[str] = None + + # License (SPDX identifier) + license_id: Optional[str] = None # e.g., "MIT", "CC-BY-4.0", "Apache-2.0" + license_url: Optional[str] = None # URL to license text + + # Statistics + record_count: Optional[int] = None + size_bytes: Optional[int] = None + splits: Dict[str, int] = field(default_factory=dict) # split_name → count + + # Metadata + attributes: Dict[str, Any] = field(default_factory=dict) + + # Timestamps + created_at: float = field(default_factory=time.time) + + def __post_init__(self): + """Generate ID if not provided.""" + if not self.id: + self.id = f"entity:{self.name}:{int(self.created_at * 1000)}" + + def compute_hash(self) -> str: + """Compute entity hash from content.""" + content = json.dumps({ + "id": self.id, + "name": self.name, + "content_hash": self.content_hash, + "schema_hash": self.schema_hash, + "version": self.version, + "record_count": self.record_count, + }, sort_keys=True) + return hashlib.sha256(content.encode()).hexdigest() + + def to_dict(self) -> Dict[str, Any]: + return { + "@type": "prov:Entity", + "@id": self.id, + "name": self.name, + "content_hash": self.content_hash, + "schema_hash": self.schema_hash, + "version": self.version, + "previous_version": self.previous_version, + "source_type": self.source_type, + "source_uri": self.source_uri, + "license_id": self.license_id, + "license_url": self.license_url, + "record_count": self.record_count, + "size_bytes": self.size_bytes, + "splits": self.splits, + "attributes": self.attributes, + "created_at": self.created_at, + } + + def to_prov_n(self) -> str: + """Export as PROV-N notation.""" + attrs = ", ".join([ + f'prov:label="{self.name}"', + f'cascade:contentHash="{self.content_hash or "unknown"}"', + f'cascade:recordCount="{self.record_count or 0}"', + f'cascade:license="{self.license_id or "unknown"}"', + ]) + return f"entity({self.id}, [{attrs}])" + + +class ActivityType(Enum): + """Types of dataset activities.""" + INGEST = "ingest" # 
Load from source + TRANSFORM = "transform" # Filter, map, join, etc. + SPLIT = "split" # Train/test/val split + AUGMENT = "augment" # Data augmentation + CLEAN = "clean" # Cleaning/preprocessing + MERGE = "merge" # Combining datasets + SAMPLE = "sample" # Sampling/subsetting + EXPORT = "export" # Export to format + TRAIN = "train" # Model training (consumption) + EVALUATE = "evaluate" # Model evaluation + INFERENCE = "inference" # Model inference + ENTITY_RESOLUTION = "entity_resolution" # Data Unity matching + + +@dataclass +class Activity: + """ + An activity in the provenance graph. + + Corresponds to prov:Activity - something that occurs over time + and acts upon or with entities. + """ + id: str + activity_type: ActivityType + name: str + + # Timing + started_at: Optional[float] = None + ended_at: Optional[float] = None + + # Input/Output tracking + inputs: List[str] = field(default_factory=list) # Entity IDs + outputs: List[str] = field(default_factory=list) # Entity IDs + + # Agent who performed this + agent_id: Optional[str] = None + + # Parameters/configuration used + parameters: Dict[str, Any] = field(default_factory=dict) + + # Metadata + attributes: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self): + if not self.id: + self.id = f"activity:{self.activity_type.value}:{int(time.time() * 1000)}" + if self.started_at is None: + self.started_at = time.time() + + def start(self): + """Mark activity as started.""" + self.started_at = time.time() + + def end(self): + """Mark activity as ended.""" + self.ended_at = time.time() + + @property + def duration(self) -> Optional[float]: + """Duration in seconds.""" + if self.started_at and self.ended_at: + return self.ended_at - self.started_at + return None + + def add_input(self, entity_id: str): + """Record an input entity.""" + if entity_id not in self.inputs: + self.inputs.append(entity_id) + + def add_output(self, entity_id: str): + """Record an output entity.""" + if entity_id not in 
self.outputs: + self.outputs.append(entity_id) + + def to_dict(self) -> Dict[str, Any]: + return { + "@type": "prov:Activity", + "@id": self.id, + "activity_type": self.activity_type.value, + "name": self.name, + "started_at": self.started_at, + "ended_at": self.ended_at, + "duration": self.duration, + "inputs": self.inputs, + "outputs": self.outputs, + "agent_id": self.agent_id, + "parameters": self.parameters, + "attributes": self.attributes, + } + + def to_prov_n(self) -> str: + """Export as PROV-N notation.""" + start = datetime.fromtimestamp(self.started_at).isoformat() if self.started_at else "-" + end = datetime.fromtimestamp(self.ended_at).isoformat() if self.ended_at else "-" + attrs = f'prov:label="{self.name}", cascade:type="{self.activity_type.value}"' + return f"activity({self.id}, {start}, {end}, [{attrs}])" + + +class AgentType(Enum): + """Types of agents.""" + PERSON = "person" + ORGANIZATION = "organization" + SOFTWARE = "software" + MODEL = "model" + PIPELINE = "pipeline" + SYSTEM = "system" + + +@dataclass +class Agent: + """ + An agent in the provenance graph. + + Corresponds to prov:Agent - something that bears responsibility + for an activity taking place. + """ + id: str + agent_type: AgentType + name: str + + # For software/model agents + version: Optional[str] = None + + # For organizational hierarchy + parent_agent_id: Optional[str] = None + + # Contact/identification + identifier: Optional[str] = None # HF username, email, etc. 
+ + # Metadata + attributes: Dict[str, Any] = field(default_factory=dict) + + # Timestamp + created_at: float = field(default_factory=time.time) + + def __post_init__(self): + if not self.id: + self.id = f"agent:{self.agent_type.value}:{self.name}".replace(" ", "_").lower() + + def to_dict(self) -> Dict[str, Any]: + return { + "@type": "prov:Agent", + "@id": self.id, + "agent_type": self.agent_type.value, + "name": self.name, + "version": self.version, + "parent_agent_id": self.parent_agent_id, + "identifier": self.identifier, + "attributes": self.attributes, + "created_at": self.created_at, + } + + def to_prov_n(self) -> str: + """Export as PROV-N notation.""" + attrs = f'prov:label="{self.name}", cascade:type="{self.agent_type.value}"' + if self.version: + attrs += f', cascade:version="{self.version}"' + return f"agent({self.id}, [{attrs}])" + + +# Convenience factory functions +def create_system_agent(name: str = "cascade", version: str = "1.0.0") -> Agent: + """Create a system agent for automated operations.""" + return Agent( + id=f"agent:system:{name}", + agent_type=AgentType.SYSTEM, + name=name, + version=version, + ) + + +def create_model_agent(model_id: str, version: str = None) -> Agent: + """Create an agent representing an ML model.""" + return Agent( + id=f"agent:model:{model_id.replace('/', '_')}", + agent_type=AgentType.MODEL, + name=model_id, + version=version, + identifier=model_id, + ) + + +def create_user_agent(username: str, org: str = None) -> Agent: + """Create an agent representing a user.""" + agent = Agent( + id=f"agent:person:{username}", + agent_type=AgentType.PERSON, + name=username, + identifier=username, + ) + if org: + agent.parent_agent_id = f"agent:organization:{org}" + return agent diff --git a/cascade/data/hub.py b/cascade/data/hub.py new file mode 100644 index 0000000000000000000000000000000000000000..a06e52c9ccfe919844a056524c8db9eb99aee3a7 --- /dev/null +++ b/cascade/data/hub.py @@ -0,0 +1,533 @@ +""" +HuggingFace Hub 
import json
import time
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional


class AccountabilityBundle:
    """
    Complete W3C PROV-O accountability package.

    When a dataset is extracted, this bundle provides a full audit trail:
    - Who created/modified it (agents)
    - What transformations occurred (activities)
    - Where it came from (entity lineage)
    - When everything happened (timestamps)
    - How to verify integrity (hashes)
    """

    def __init__(self, graph: "ProvenanceGraph"):
        """
        Args:
            graph: The provenance graph the bundle is built from.
        """
        self.graph = graph
        # Always UTC so bundle timestamps are comparable across machines.
        self.created_at = datetime.now(timezone.utc).isoformat()

    @staticmethod
    def _iso_utc(ts: Optional[float]) -> Optional[str]:
        """Convert epoch seconds to a UTC ISO-8601 string; None passes through.

        Fixes two issues in the original code: conversion used the naive local
        timezone (inconsistent with the UTC `created_at`), and the truthiness
        check (`if ts`) silently dropped a legitimate epoch-0 timestamp.
        """
        if ts is None:
            return None
        return datetime.fromtimestamp(ts, tz=timezone.utc).isoformat()

    def to_prov_o_jsonld(self) -> Dict[str, Any]:
        """Export W3C PROV-O JSON-LD (interoperable standard)."""
        return self.graph.to_prov_jsonld()

    def to_prov_n(self) -> str:
        """Export W3C PROV-N notation (human readable)."""
        return self.graph.to_prov_n()

    def to_activity_log(self) -> List[Dict[str, Any]]:
        """Export the activity log for audit (one dict per activity, JSONL-ready)."""
        activities = []
        for activity in self.graph.list_activities():
            activities.append({
                "id": activity.id,
                "name": activity.name,
                "type": activity.activity_type.value,
                "started_at": self._iso_utc(activity.started_at),
                "ended_at": self._iso_utc(activity.ended_at),
                "duration_seconds": activity.duration,
                "inputs": activity.inputs,
                "outputs": activity.outputs,
                "parameters": activity.parameters,
                "attributes": activity.attributes,
            })
        return activities

    def to_agent_attributions(self) -> Dict[str, Any]:
        """Export agent attributions for accountability.

        Returns a dict with every known agent plus an attribution matrix
        (which agent was associated with which activity).
        """
        agents = {}
        for agent in self.graph.list_agents():
            agents[agent.id] = {
                "name": agent.name,
                "type": agent.agent_type.value,
                "version": agent.version,
                "identifier": agent.identifier,
                "attributes": agent.attributes,
            }

        # Build attribution matrix: which agent did what.
        attributions = []
        for rel in self.graph.list_relationships():
            if rel.relation_type.value == "wasAssociatedWith":
                activity = self.graph.get_activity(rel.source_id)
                agent = self.graph.get_agent(rel.target_id)
                # Skip dangling relationships whose endpoints are missing.
                if activity and agent:
                    attributions.append({
                        "activity_id": activity.id,
                        "activity_name": activity.name,
                        "agent_id": agent.id,
                        "agent_name": agent.name,
                        "timestamp": self._iso_utc(activity.started_at),
                    })

        return {
            "agents": agents,
            "attributions": attributions,
            "total_agents": len(agents),
            "total_attributions": len(attributions),
        }

    def to_integrity_manifest(self) -> Dict[str, Any]:
        """Export integrity manifest for verification.

        Records the graph root hash plus per-entity content/schema hashes so a
        consumer can recompute and compare them to detect tampering.
        """
        is_valid, invalid_ids = self.graph.verify_integrity()

        return {
            "root_hash": self.graph.root_hash,
            "created_at": self.created_at,
            "is_valid": is_valid,
            "invalid_entity_ids": invalid_ids,
            "entity_hashes": {
                entity.id: {
                    "content_hash": entity.content_hash,
                    "schema_hash": entity.schema_hash,
                }
                for entity in self.graph.list_entities()
            },
            "verification_note": (
                "To verify: recompute content hashes and compare against this manifest. "
                "Any mismatch indicates data tampering."
            ),
        }

    def export(self, output_dir: str):
        """Export all accountability artifacts to a directory.

        Writes the seven files listed by `summary()["files_included"]`.
        All files are written as UTF-8 explicitly (platform-independent).
        """
        import os
        os.makedirs(output_dir, exist_ok=True)

        def _dump_json(filename: str, payload: Any):
            # Centralized writer: explicit UTF-8, pretty-printed JSON.
            with open(os.path.join(output_dir, filename), "w", encoding="utf-8") as f:
                json.dump(payload, f, indent=2, default=str)

        # 1. CASCADE provenance JSON (native format)
        _dump_json("cascade_provenance.json", self.graph.to_dict())

        # 2. W3C PROV-O JSON-LD
        _dump_json("prov_o.jsonld", self.to_prov_o_jsonld())

        # 3. W3C PROV-N notation
        with open(os.path.join(output_dir, "prov_n.txt"), "w", encoding="utf-8") as f:
            f.write(self.to_prov_n())

        # 4. Activity log (JSONL: one activity per line)
        with open(os.path.join(output_dir, "activities.jsonl"), "w", encoding="utf-8") as f:
            for activity in self.to_activity_log():
                f.write(json.dumps(activity, default=str) + "\n")

        # 5. Agent attributions
        _dump_json("agents.json", self.to_agent_attributions())

        # 6. Integrity manifest
        _dump_json("integrity_manifest.json", self.to_integrity_manifest())

        # 7. Croissant metadata (MLCommons)
        exporter = CroissantExporter(self.graph)
        croissant_content = exporter.to_json(name="dataset", url="local://")
        with open(os.path.join(output_dir, "croissant.json"), "w", encoding="utf-8") as f:
            f.write(croissant_content)

    def summary(self) -> Dict[str, Any]:
        """Summary of the accountability bundle (counts and file list)."""
        stats = self.graph.stats
        return {
            "bundle_created_at": self.created_at,
            "graph_name": self.graph.name,
            "root_hash": self.graph.root_hash,
            "entities": stats["entities"],
            "activities": stats["activities"],
            "agents": stats["agents"],
            "relationships": stats["relationships"],
            "files_included": [
                "cascade_provenance.json",
                "prov_o.jsonld",
                "prov_n.txt",
                "activities.jsonl",
                "agents.json",
                "integrity_manifest.json",
                "croissant.json",
            ],
        }
class HubIntegration:
    """
    Integration with HuggingFace Hub for dataset provenance.

    Stores complete accountability bundle:
    1. cascade_provenance.json - CASCADE native format
    2. prov_o.jsonld - W3C PROV-O JSON-LD (interoperable)
    3. prov_n.txt - W3C PROV-N notation (human readable)
    4. activities.jsonl - Activity log for audit
    5. agents.json - Agent attributions
    6. integrity_manifest.json - Hash verification
    7. croissant.json - MLCommons Croissant
    8. README.md - Human-readable provenance section
    """

    PROVENANCE_FILENAME = "cascade_provenance.json"
    PROV_O_FILENAME = "prov_o.jsonld"
    PROV_N_FILENAME = "prov_n.txt"
    ACTIVITIES_FILENAME = "activities.jsonl"
    AGENTS_FILENAME = "agents.json"
    INTEGRITY_FILENAME = "integrity_manifest.json"
    CROISSANT_FILENAME = "croissant.json"

    # Markers delimiting the auto-generated provenance section in README.md.
    # FIX: these were empty strings in the extracted source (the HTML comment
    # markers were almost certainly stripped during extraction). Empty markers
    # make update_dataset_card() degenerate: `"" in readme` is always True and
    # the regex `"" + ".*?" + ""` matches the empty string everywhere. HTML
    # comments are invisible on the rendered Hub page.
    # NOTE(review): confirm the exact original marker text upstream.
    README_MARKER_START = "<!-- CASCADE_PROVENANCE_START -->"
    README_MARKER_END = "<!-- CASCADE_PROVENANCE_END -->"

    def __init__(self, token: Optional[str] = None):
        """
        Initialize Hub integration.

        Args:
            token: HuggingFace API token (optional, uses cached token if not provided)
        """
        self.token = token

    def push_provenance(
        self,
        graph: "ProvenanceGraph",
        repo_id: str,
        commit_message: str = "Update provenance",
        private: bool = False,
        include_croissant: bool = True,
        full_accountability: bool = True,
    ) -> str:
        """
        Push complete accountability bundle to HuggingFace Hub.

        Args:
            graph: The provenance graph to push
            repo_id: HuggingFace repo ID (e.g., "username/dataset-name")
            commit_message: Commit message
            private: Whether the repo should be private
            include_croissant: Whether to include Croissant JSON-LD
            full_accountability: Whether to include full W3C PROV-O bundle

        Returns:
            URL of the pushed provenance
        """
        # Lazy import keeps huggingface_hub an optional dependency.
        from huggingface_hub import HfApi, CommitOperationAdd

        api = HfApi(token=self.token)

        # Ensure repo exists (idempotent thanks to exist_ok).
        api.create_repo(
            repo_id=repo_id,
            repo_type="dataset",
            private=private,
            exist_ok=True,
        )

        operations = []
        bundle = AccountabilityBundle(graph)

        def _add(path: str, content: str):
            # All artifacts are committed as UTF-8 text blobs.
            operations.append(CommitOperationAdd(
                path_in_repo=path,
                path_or_fileobj=content.encode("utf-8"),
            ))

        # 1. CASCADE provenance JSON (native format)
        _add(self.PROVENANCE_FILENAME, json.dumps(graph.to_dict(), indent=2, default=str))

        if full_accountability:
            # 2. W3C PROV-O JSON-LD (interoperable standard)
            _add(self.PROV_O_FILENAME,
                 json.dumps(bundle.to_prov_o_jsonld(), indent=2, default=str))

            # 3. W3C PROV-N notation (human readable)
            _add(self.PROV_N_FILENAME, bundle.to_prov_n())

            # 4. Activity log (JSONL for easy grep/audit)
            activities = bundle.to_activity_log()
            _add(self.ACTIVITIES_FILENAME,
                 "\n".join(json.dumps(a, default=str) for a in activities))

            # 5. Agent attributions
            _add(self.AGENTS_FILENAME,
                 json.dumps(bundle.to_agent_attributions(), indent=2, default=str))

            # 6. Integrity manifest (for verification)
            _add(self.INTEGRITY_FILENAME,
                 json.dumps(bundle.to_integrity_manifest(), indent=2, default=str))

        # 7. Croissant JSON-LD (MLCommons standard)
        if include_croissant:
            exporter = CroissantExporter(graph)
            croissant_content = exporter.to_json(
                name=repo_id.split("/")[-1],
                url=f"https://huggingface.co/datasets/{repo_id}",
            )
            _add(self.CROISSANT_FILENAME, croissant_content)

        # Commit all accountability artifacts atomically.
        api.create_commit(
            repo_id=repo_id,
            repo_type="dataset",
            operations=operations,
            commit_message=commit_message,
        )

        return f"https://huggingface.co/datasets/{repo_id}"

    def pull_provenance(self, repo_id: str) -> "Optional[ProvenanceGraph]":
        """
        Pull provenance from HuggingFace Hub.

        Args:
            repo_id: HuggingFace repo ID

        Returns:
            ProvenanceGraph if found, None otherwise (best-effort: any
            download/parse failure is reported and swallowed by design)
        """
        from huggingface_hub import hf_hub_download

        try:
            # Download provenance file
            local_path = hf_hub_download(
                repo_id=repo_id,
                filename=self.PROVENANCE_FILENAME,
                repo_type="dataset",
                token=self.token,
            )

            with open(local_path, "r", encoding="utf-8") as f:
                data = json.load(f)

            return ProvenanceGraph.from_dict(data)

        except Exception as e:
            print(f"Could not pull provenance from {repo_id}: {e}")
            return None

    def get_dataset_provenance_url(self, repo_id: str) -> str:
        """Get URL to provenance file in Hub."""
        return f"https://huggingface.co/datasets/{repo_id}/blob/main/{self.PROVENANCE_FILENAME}"

    def update_dataset_card(
        self,
        repo_id: str,
        graph: "ProvenanceGraph",
    ) -> str:
        """
        Update dataset card with provenance summary.

        Adds/updates a marker-delimited provenance section with:
        - Lineage information
        - Root hash
        - Entity/activity counts

        Args:
            repo_id: HuggingFace repo ID
            graph: Provenance graph

        Returns:
            URL of the updated dataset
        """
        from huggingface_hub import HfApi, hf_hub_download

        api = HfApi(token=self.token)

        # Build provenance section for README
        provenance_section = self._build_readme_section(graph)

        # Get current README; fall back to a minimal stub if the repo has none.
        # FIX: was a bare `except:` which also swallowed KeyboardInterrupt/SystemExit.
        try:
            readme_path = hf_hub_download(
                repo_id=repo_id,
                filename="README.md",
                repo_type="dataset",
                token=self.token,
            )
            with open(readme_path, "r", encoding="utf-8") as f:
                current_readme = f.read()
        except Exception:
            current_readme = f"# {repo_id.split('/')[-1]}\n\n"

        marker_start = self.README_MARKER_START
        marker_end = self.README_MARKER_END

        if marker_start in current_readme:
            # Replace the existing marker-delimited section in place.
            import re
            pattern = re.escape(marker_start) + r".*?" + re.escape(marker_end)
            new_readme = re.sub(
                pattern,
                f"{marker_start}\n{provenance_section}\n{marker_end}",
                current_readme,
                flags=re.DOTALL,
            )
        else:
            # First run: append a new marker-delimited section.
            new_readme = current_readme.rstrip() + f"\n\n{marker_start}\n{provenance_section}\n{marker_end}\n"

        # Push updated README
        api.upload_file(
            path_or_fileobj=new_readme.encode("utf-8"),
            path_in_repo="README.md",
            repo_id=repo_id,
            repo_type="dataset",
            commit_message="Update provenance in README",
        )

        return f"https://huggingface.co/datasets/{repo_id}"

    def _build_readme_section(self, graph: "ProvenanceGraph") -> str:
        """Build the markdown provenance section for README."""
        stats = graph.stats

        lines = [
            "## 🔗 Provenance & Accountability",
            "",
            "This dataset has CASCADE provenance tracking enabled with full W3C PROV-O compliance.",
            "",
            "### Integrity",
            "",
            "| Metric | Value |",
            "|--------|-------|",
            f"| Root Hash | `{graph.root_hash[:16]}...` |",
            f"| Entities | {stats['entities']} |",
            f"| Activities | {stats['activities']} |",
            f"| Agents | {stats['agents']} |",
            f"| Relationships | {stats['relationships']} |",
            "",
        ]

        # Lineage summary (first 5 entities only, to keep the card short).
        entities = graph.list_entities()
        if entities:
            lines.append("### Lineage")
            lines.append("")
            for entity in entities[:5]:
                upstream = graph.get_lineage(entity.id, "upstream")
                if upstream:
                    lines.append(f"- **{entity.name}** derived from: {', '.join(upstream[:3])}")
                else:
                    lines.append(f"- **{entity.name}** (source)")
            if len(entities) > 5:
                lines.append(f"- ... and {len(entities) - 5} more entities")
            lines.append("")

        # Activities summary (first 5).
        activities = graph.list_activities()
        if activities:
            lines.append("### Activities")
            lines.append("")
            for activity in activities[:5]:
                duration = f" ({activity.duration:.2f}s)" if activity.duration else ""
                lines.append(f"- **{activity.name}** [{activity.activity_type.value}]{duration}")
            if len(activities) > 5:
                lines.append(f"- ... and {len(activities) - 5} more activities")
            lines.append("")

        # Agents summary (first 5).
        agents = graph.list_agents()
        if agents:
            lines.append("### Agents (Accountability)")
            lines.append("")
            for agent in agents[:5]:
                lines.append(f"- **{agent.name}** [{agent.agent_type.value}]")
            if len(agents) > 5:
                lines.append(f"- ... and {len(agents) - 5} more agents")
            lines.append("")

        # Accountability bundle file index.
        lines.extend([
            "### Accountability Bundle",
            "",
            "| File | Standard | Description |",
            "|------|----------|-------------|",
            f"| [{self.PROVENANCE_FILENAME}]({self.PROVENANCE_FILENAME}) | CASCADE | Native provenance format |",
            f"| [{self.PROV_O_FILENAME}]({self.PROV_O_FILENAME}) | W3C PROV-O | Interoperable JSON-LD |",
            f"| [{self.PROV_N_FILENAME}]({self.PROV_N_FILENAME}) | W3C PROV-N | Human-readable notation |",
            f"| [{self.ACTIVITIES_FILENAME}]({self.ACTIVITIES_FILENAME}) | JSONL | Activity audit log |",
            f"| [{self.AGENTS_FILENAME}]({self.AGENTS_FILENAME}) | JSON | Agent attributions |",
            f"| [{self.INTEGRITY_FILENAME}]({self.INTEGRITY_FILENAME}) | JSON | Hash verification manifest |",
            f"| [{self.CROISSANT_FILENAME}]({self.CROISSANT_FILENAME}) | MLCommons | Croissant metadata |",
            "",
        ])

        return "\n".join(lines)


def push_to_hub(
    graph: "ProvenanceGraph",
    repo_id: str,
    token: Optional[str] = None,
    private: bool = False,
) -> str:
    """
    Convenience function to push provenance to Hub.

    Args:
        graph: Provenance graph to push
        repo_id: HuggingFace repo ID
        token: HF token (optional)
        private: Whether repo should be private

    Returns:
        URL of the pushed provenance
    """
    hub = HubIntegration(token=token)
    return hub.push_provenance(graph, repo_id, private=private)


def pull_from_hub(repo_id: str, token: Optional[str] = None) -> "Optional[ProvenanceGraph]":
    """
    Convenience function to pull provenance from Hub.

    Args:
        repo_id: HuggingFace repo ID
        token: HF token (optional)

    Returns:
        ProvenanceGraph if found
    """
    hub = HubIntegration(token=token)
    return hub.pull_provenance(repo_id)
"""
SPDX License Tracking for CASCADE

Industry standard license tracking based on:
- SPDX (Software Package Data Exchange) - Linux Foundation
- HuggingFace Dataset Cards license field
- Croissant metadata license property

License Compatibility Rules:
- Permissive (MIT, Apache-2.0) → Can derive into restrictive
- Copyleft (GPL-3.0) → Derivatives must also be copyleft
- NonCommercial (CC-BY-NC-*) → Propagates non-commercial restriction
- ShareAlike (CC-BY-SA-*) → Derivatives must use same license
- NoDerivatives (CC-BY-ND-*) → Cannot create derivatives

References:
- https://spdx.org/licenses/
- https://creativecommons.org/licenses/
"""

from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple, Any


class LicenseCategory(Enum):
    """License categories for compatibility analysis."""
    PERMISSIVE = "permissive"               # MIT, Apache, BSD
    WEAK_COPYLEFT = "weak-copyleft"         # LGPL, MPL
    STRONG_COPYLEFT = "strong-copyleft"     # GPL, AGPL
    CREATIVE_COMMONS = "creative-commons"
    PUBLIC_DOMAIN = "public-domain"         # CC0, Unlicense
    PROPRIETARY = "proprietary"
    UNKNOWN = "unknown"


class LicenseRestriction(Enum):
    """License restrictions that propagate to derivatives."""
    NONE = "none"
    ATTRIBUTION = "attribution"             # Must credit original
    SHARE_ALIKE = "share-alike"             # Derivatives same license
    NON_COMMERCIAL = "non-commercial"       # No commercial use
    NO_DERIVATIVES = "no-derivatives"       # Cannot modify
    COPYLEFT = "copyleft"                   # Must open source derivatives


@dataclass
class SPDXLicense:
    """
    SPDX License Information.

    Based on SPDX License List: https://spdx.org/licenses/
    """
    id: str                                 # SPDX identifier (e.g., "MIT", "Apache-2.0")
    name: str                               # Full name
    category: LicenseCategory = LicenseCategory.UNKNOWN
    restrictions: Set[LicenseRestriction] = field(default_factory=set)
    osi_approved: bool = False              # Open Source Initiative approved
    fsf_libre: bool = False                 # FSF Free/Libre
    url: Optional[str] = None               # License text URL

    def allows_commercial(self) -> bool:
        """Check if license allows commercial use."""
        return LicenseRestriction.NON_COMMERCIAL not in self.restrictions

    def allows_derivatives(self) -> bool:
        """Check if license allows creating derivatives."""
        return LicenseRestriction.NO_DERIVATIVES not in self.restrictions

    def requires_attribution(self) -> bool:
        """Check if license requires attribution."""
        return LicenseRestriction.ATTRIBUTION in self.restrictions

    def requires_share_alike(self) -> bool:
        """Check if license requires same license for derivatives."""
        return (
            LicenseRestriction.SHARE_ALIKE in self.restrictions or
            LicenseRestriction.COPYLEFT in self.restrictions
        )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict."""
        return {
            "spdx_id": self.id,
            "name": self.name,
            "category": self.category.value,
            "restrictions": [r.value for r in self.restrictions],
            "osi_approved": self.osi_approved,
            "fsf_libre": self.fsf_libre,
            "url": self.url,
        }


# SPDX License Registry - Common ML/Data licenses
SPDX_LICENSES: Dict[str, SPDXLicense] = {
    # Public Domain
    "CC0-1.0": SPDXLicense(
        id="CC0-1.0",
        name="Creative Commons Zero v1.0 Universal",
        category=LicenseCategory.PUBLIC_DOMAIN,
        restrictions=set(),
        osi_approved=False,
        fsf_libre=True,
        url="https://creativecommons.org/publicdomain/zero/1.0/",
    ),
    "Unlicense": SPDXLicense(
        id="Unlicense",
        name="The Unlicense",
        category=LicenseCategory.PUBLIC_DOMAIN,
        restrictions=set(),
        osi_approved=True,
        fsf_libre=True,
        url="https://unlicense.org/",
    ),

    # Permissive
    "MIT": SPDXLicense(
        id="MIT",
        name="MIT License",
        category=LicenseCategory.PERMISSIVE,
        restrictions={LicenseRestriction.ATTRIBUTION},
        osi_approved=True,
        fsf_libre=True,
        url="https://opensource.org/licenses/MIT",
    ),
    "Apache-2.0": SPDXLicense(
        id="Apache-2.0",
        name="Apache License 2.0",
        category=LicenseCategory.PERMISSIVE,
        restrictions={LicenseRestriction.ATTRIBUTION},
        osi_approved=True,
        fsf_libre=True,
        url="https://www.apache.org/licenses/LICENSE-2.0",
    ),
    "BSD-2-Clause": SPDXLicense(
        id="BSD-2-Clause",
        name='BSD 2-Clause "Simplified" License',
        category=LicenseCategory.PERMISSIVE,
        restrictions={LicenseRestriction.ATTRIBUTION},
        osi_approved=True,
        fsf_libre=True,
        url="https://opensource.org/licenses/BSD-2-Clause",
    ),
    "BSD-3-Clause": SPDXLicense(
        id="BSD-3-Clause",
        name='BSD 3-Clause "New" or "Revised" License',
        category=LicenseCategory.PERMISSIVE,
        restrictions={LicenseRestriction.ATTRIBUTION},
        osi_approved=True,
        fsf_libre=True,
        url="https://opensource.org/licenses/BSD-3-Clause",
    ),

    # Creative Commons
    "CC-BY-4.0": SPDXLicense(
        id="CC-BY-4.0",
        name="Creative Commons Attribution 4.0",
        category=LicenseCategory.CREATIVE_COMMONS,
        restrictions={LicenseRestriction.ATTRIBUTION},
        osi_approved=False,
        fsf_libre=True,
        url="https://creativecommons.org/licenses/by/4.0/",
    ),
    "CC-BY-SA-4.0": SPDXLicense(
        id="CC-BY-SA-4.0",
        name="Creative Commons Attribution ShareAlike 4.0",
        category=LicenseCategory.CREATIVE_COMMONS,
        restrictions={LicenseRestriction.ATTRIBUTION, LicenseRestriction.SHARE_ALIKE},
        osi_approved=False,
        fsf_libre=True,
        url="https://creativecommons.org/licenses/by-sa/4.0/",
    ),
    "CC-BY-NC-4.0": SPDXLicense(
        id="CC-BY-NC-4.0",
        name="Creative Commons Attribution NonCommercial 4.0",
        category=LicenseCategory.CREATIVE_COMMONS,
        restrictions={LicenseRestriction.ATTRIBUTION, LicenseRestriction.NON_COMMERCIAL},
        osi_approved=False,
        fsf_libre=False,
        url="https://creativecommons.org/licenses/by-nc/4.0/",
    ),
    "CC-BY-NC-SA-4.0": SPDXLicense(
        id="CC-BY-NC-SA-4.0",
        name="Creative Commons Attribution NonCommercial ShareAlike 4.0",
        category=LicenseCategory.CREATIVE_COMMONS,
        restrictions={
            LicenseRestriction.ATTRIBUTION,
            LicenseRestriction.NON_COMMERCIAL,
            LicenseRestriction.SHARE_ALIKE,
        },
        osi_approved=False,
        fsf_libre=False,
        url="https://creativecommons.org/licenses/by-nc-sa/4.0/",
    ),
    "CC-BY-ND-4.0": SPDXLicense(
        id="CC-BY-ND-4.0",
        name="Creative Commons Attribution NoDerivatives 4.0",
        category=LicenseCategory.CREATIVE_COMMONS,
        restrictions={LicenseRestriction.ATTRIBUTION, LicenseRestriction.NO_DERIVATIVES},
        osi_approved=False,
        fsf_libre=False,
        url="https://creativecommons.org/licenses/by-nd/4.0/",
    ),

    # Weak Copyleft
    "LGPL-3.0": SPDXLicense(
        id="LGPL-3.0",
        name="GNU Lesser General Public License v3.0",
        category=LicenseCategory.WEAK_COPYLEFT,
        restrictions={LicenseRestriction.ATTRIBUTION, LicenseRestriction.COPYLEFT},
        osi_approved=True,
        fsf_libre=True,
        url="https://www.gnu.org/licenses/lgpl-3.0.html",
    ),
    "MPL-2.0": SPDXLicense(
        id="MPL-2.0",
        name="Mozilla Public License 2.0",
        category=LicenseCategory.WEAK_COPYLEFT,
        restrictions={LicenseRestriction.ATTRIBUTION, LicenseRestriction.COPYLEFT},
        osi_approved=True,
        fsf_libre=True,
        url="https://www.mozilla.org/en-US/MPL/2.0/",
    ),

    # Strong Copyleft
    "GPL-3.0": SPDXLicense(
        id="GPL-3.0",
        name="GNU General Public License v3.0",
        category=LicenseCategory.STRONG_COPYLEFT,
        restrictions={LicenseRestriction.ATTRIBUTION, LicenseRestriction.COPYLEFT},
        osi_approved=True,
        fsf_libre=True,
        url="https://www.gnu.org/licenses/gpl-3.0.html",
    ),
    "AGPL-3.0": SPDXLicense(
        id="AGPL-3.0",
        name="GNU Affero General Public License v3.0",
        category=LicenseCategory.STRONG_COPYLEFT,
        restrictions={LicenseRestriction.ATTRIBUTION, LicenseRestriction.COPYLEFT},
        osi_approved=True,
        fsf_libre=True,
        url="https://www.gnu.org/licenses/agpl-3.0.html",
    ),

    # ML-Specific
    "OpenRAIL": SPDXLicense(
        id="OpenRAIL",
        name="Open RAIL License",
        category=LicenseCategory.PERMISSIVE,
        restrictions={LicenseRestriction.ATTRIBUTION},
        osi_approved=False,
        fsf_libre=False,
        url="https://huggingface.co/blog/open_rail",
    ),
    "OpenRAIL-M": SPDXLicense(
        id="OpenRAIL-M",
        name="Open RAIL-M License",
        category=LicenseCategory.PERMISSIVE,
        restrictions={LicenseRestriction.ATTRIBUTION},
        osi_approved=False,
        fsf_libre=False,
        url="https://www.licenses.ai/blog/2022/8/26/bigscience-open-rail-m-license",
    ),

    # Special
    "other": SPDXLicense(
        id="other",
        name="Other/Custom License",
        category=LicenseCategory.UNKNOWN,
        restrictions=set(),
        osi_approved=False,
        fsf_libre=False,
        url=None,
    ),
    "unknown": SPDXLicense(
        id="unknown",
        name="Unknown License",
        category=LicenseCategory.UNKNOWN,
        restrictions=set(),
        osi_approved=False,
        fsf_libre=False,
        url=None,
    ),
}


# Aliases from loosely written license names to canonical SPDX IDs.
# Keys MUST already be in normalized form (lowercase, "_"/" " replaced by "-")
# because lookup happens AFTER normalization. Hoisted to module level so the
# table is built once, not on every get_license() call.
# FIX: the original key "public domain" contained a space and was therefore
# unreachable after normalization; it is now "public-domain".
_LICENSE_ALIASES: Dict[str, str] = {
    "mit": "MIT",
    "apache": "Apache-2.0",
    "apache2": "Apache-2.0",
    "gpl": "GPL-3.0",
    "gpl3": "GPL-3.0",
    "lgpl": "LGPL-3.0",
    "bsd": "BSD-3-Clause",
    "cc0": "CC0-1.0",
    "cc-by": "CC-BY-4.0",
    "cc-by-sa": "CC-BY-SA-4.0",
    "cc-by-nc": "CC-BY-NC-4.0",
    "cc-by-nc-sa": "CC-BY-NC-SA-4.0",
    "cc-by-nd": "CC-BY-ND-4.0",
    "unlicense": "Unlicense",
    "public-domain": "CC0-1.0",
    "openrail": "OpenRAIL",
}


def get_license(spdx_id: str) -> SPDXLicense:
    """
    Get license by SPDX identifier.

    Args:
        spdx_id: SPDX license identifier (case-insensitive; common aliases
            such as "mit", "apache2" or "public domain" are accepted)

    Returns:
        SPDXLicense object (the "unknown" license if not found)
    """
    normalized = spdx_id.strip()

    # Direct lookup
    if normalized in SPDX_LICENSES:
        return SPDX_LICENSES[normalized]

    # Case-insensitive lookup (loop var renamed: `license` shadowed a builtin)
    for key, lic in SPDX_LICENSES.items():
        if key.lower() == normalized.lower():
            return lic

    # Alias lookup on the normalized form
    lower_id = normalized.lower().replace("_", "-").replace(" ", "-")
    if lower_id in _LICENSE_ALIASES:
        return SPDX_LICENSES[_LICENSE_ALIASES[lower_id]]

    # Return unknown
    return SPDX_LICENSES["unknown"]


@dataclass
class LicenseCompatibility:
    """Result of license compatibility check."""
    compatible: bool
    derived_license: Optional[SPDXLicense] = None
    issues: List[str] = field(default_factory=list)
    warnings: List[str] = field(default_factory=list)
    attribution_required: List[str] = field(default_factory=list)  # Source IDs requiring attribution


class LicenseAnalyzer:
    """
    Analyze license compatibility for dataset derivation.

    Rules:
    1. No-Derivatives: Cannot create derivatives
    2. Share-Alike: Must use same license
    3. Copyleft: Must use compatible copyleft license
    4. Non-Commercial: Restriction propagates
    5. Attribution: Must credit all sources
    """

    # License compatibility matrix (can this → derive into that?)
    # Rows: source license category, Columns: derived license category.
    # NOTE(review): currently informational — check_compatibility() derives
    # its verdicts from restriction sets, not from this matrix; confirm
    # whether the matrix is meant to be consulted.
    COMPATIBILITY_MATRIX = {
        LicenseCategory.PUBLIC_DOMAIN: {
            LicenseCategory.PUBLIC_DOMAIN: True,
            LicenseCategory.PERMISSIVE: True,
            LicenseCategory.CREATIVE_COMMONS: True,
            LicenseCategory.WEAK_COPYLEFT: True,
            LicenseCategory.STRONG_COPYLEFT: True,
            LicenseCategory.PROPRIETARY: True,
        },
        LicenseCategory.PERMISSIVE: {
            LicenseCategory.PUBLIC_DOMAIN: False,
            LicenseCategory.PERMISSIVE: True,
            LicenseCategory.CREATIVE_COMMONS: True,
            LicenseCategory.WEAK_COPYLEFT: True,
            LicenseCategory.STRONG_COPYLEFT: True,
            LicenseCategory.PROPRIETARY: True,
        },
        LicenseCategory.CREATIVE_COMMONS: {
            LicenseCategory.PUBLIC_DOMAIN: False,
            LicenseCategory.PERMISSIVE: False,   # Depends on specific CC
            LicenseCategory.CREATIVE_COMMONS: True,  # Depends on specific CC
            LicenseCategory.WEAK_COPYLEFT: False,
            LicenseCategory.STRONG_COPYLEFT: False,
            LicenseCategory.PROPRIETARY: False,
        },
        LicenseCategory.WEAK_COPYLEFT: {
            LicenseCategory.PUBLIC_DOMAIN: False,
            LicenseCategory.PERMISSIVE: False,
            LicenseCategory.CREATIVE_COMMONS: False,
            LicenseCategory.WEAK_COPYLEFT: True,
            LicenseCategory.STRONG_COPYLEFT: True,
            LicenseCategory.PROPRIETARY: False,
        },
        LicenseCategory.STRONG_COPYLEFT: {
            LicenseCategory.PUBLIC_DOMAIN: False,
            LicenseCategory.PERMISSIVE: False,
            LicenseCategory.CREATIVE_COMMONS: False,
            LicenseCategory.WEAK_COPYLEFT: False,
            LicenseCategory.STRONG_COPYLEFT: True,
            LicenseCategory.PROPRIETARY: False,
        },
    }

    def check_compatibility(
        self,
        source_licenses: List[Tuple[str, str]],  # List of (entity_id, spdx_id)
        target_license: Optional[str] = None,
    ) -> LicenseCompatibility:
        """
        Check if source licenses allow derivation.

        Args:
            source_licenses: List of (entity_id, license_id) tuples
            target_license: Intended license for derived work (optional)

        Returns:
            LicenseCompatibility result
        """
        if not source_licenses:
            # Nothing to check: trivially compatible, license unknown.
            return LicenseCompatibility(
                compatible=True,
                derived_license=SPDX_LICENSES["unknown"],
            )

        issues: List[str] = []
        warnings: List[str] = []
        attribution_required: List[str] = []

        # Collect every restriction across all sources.
        all_restrictions: Set[LicenseRestriction] = set()
        licenses: List[Tuple[str, SPDXLicense]] = []

        for entity_id, spdx_id in source_licenses:
            lic = get_license(spdx_id)
            licenses.append((entity_id, lic))
            all_restrictions.update(lic.restrictions)

            # Track attribution requirements
            if lic.requires_attribution():
                attribution_required.append(entity_id)

        # Rule 1: No-Derivatives sources block derivation outright.
        for entity_id, lic in licenses:
            if LicenseRestriction.NO_DERIVATIVES in lic.restrictions:
                issues.append(
                    f"Cannot derive from '{entity_id}': license '{lic.id}' prohibits derivatives"
                )

        if issues:
            return LicenseCompatibility(
                compatible=False,
                issues=issues,
                warnings=warnings,
                attribution_required=attribution_required,
            )

        # Determine the license the derived work must carry.
        derived = self._compute_derived_license(licenses, all_restrictions)

        # Rule 2: an explicitly requested target license must be reachable.
        if target_license:
            target = get_license(target_license)
            if not self._can_relicense(derived, target):
                issues.append(
                    f"Cannot license derived work as '{target.id}': "
                    f"must use '{derived.id}' or compatible license"
                )

        # Non-blocking notices about propagated restrictions.
        if LicenseRestriction.NON_COMMERCIAL in all_restrictions:
            warnings.append("Derived work restricted to non-commercial use only")

        if LicenseRestriction.SHARE_ALIKE in all_restrictions:
            warnings.append(f"Derived work must use ShareAlike-compatible license: {derived.id}")

        if LicenseRestriction.COPYLEFT in all_restrictions:
            warnings.append(f"Derived work must use copyleft license: {derived.id}")

        return LicenseCompatibility(
            compatible=len(issues) == 0,
            derived_license=derived,
            issues=issues,
            warnings=warnings,
            attribution_required=attribution_required,
        )

    def _compute_derived_license(
        self,
        licenses: List[Tuple[str, SPDXLicense]],
        all_restrictions: Set[LicenseRestriction],
    ) -> SPDXLicense:
        """
        Compute the most restrictive license for derived work.

        The derived license is the "lowest common denominator" that
        satisfies all source license requirements.
        Priority: Strong Copyleft > Weak Copyleft > CC-SA > CC-NC >
        Permissive (attribution) > Public Domain.
        """
        has_strong_copyleft = any(
            lic.category == LicenseCategory.STRONG_COPYLEFT
            for _, lic in licenses
        )
        has_weak_copyleft = any(
            lic.category == LicenseCategory.WEAK_COPYLEFT
            for _, lic in licenses
        )
        has_share_alike = LicenseRestriction.SHARE_ALIKE in all_restrictions
        has_non_commercial = LicenseRestriction.NON_COMMERCIAL in all_restrictions

        # Strong copyleft dominates: reuse the first such source license.
        if has_strong_copyleft:
            for _, lic in licenses:
                if lic.category == LicenseCategory.STRONG_COPYLEFT:
                    return lic

        # Weak copyleft next
        if has_weak_copyleft:
            for _, lic in licenses:
                if lic.category == LicenseCategory.WEAK_COPYLEFT:
                    return lic

        # CC with restrictions
        if has_share_alike and has_non_commercial:
            return SPDX_LICENSES["CC-BY-NC-SA-4.0"]
        elif has_share_alike:
            return SPDX_LICENSES["CC-BY-SA-4.0"]
        elif has_non_commercial:
            return SPDX_LICENSES["CC-BY-NC-4.0"]

        # Most permissive with attribution
        if LicenseRestriction.ATTRIBUTION in all_restrictions:
            # Check if any source requires a specific CC license
            for _, lic in licenses:
                if lic.category == LicenseCategory.CREATIVE_COMMONS:
                    return lic
            return SPDX_LICENSES["CC-BY-4.0"]

        # Public domain
        return SPDX_LICENSES["CC0-1.0"]

    def _can_relicense(self, source: SPDXLicense, target: SPDXLicense) -> bool:
        """Check if source license allows relicensing to target.

        Each propagating restriction present on the source must also be
        present on the target.
        """
        # Same license is always OK
        if source.id == target.id:
            return True

        # No relicensing from copyleft to non-copyleft
        if LicenseRestriction.COPYLEFT in source.restrictions:
            if LicenseRestriction.COPYLEFT not in target.restrictions:
                return False

        # No relicensing from share-alike to non-share-alike
        if LicenseRestriction.SHARE_ALIKE in source.restrictions:
            if LicenseRestriction.SHARE_ALIKE not in target.restrictions:
                return False

        # Non-commercial must propagate
        if LicenseRestriction.NON_COMMERCIAL in source.restrictions:
            if LicenseRestriction.NON_COMMERCIAL not in target.restrictions:
                return False

        return True

    def generate_attribution(
        self,
        sources: List[Tuple[str, str, str]],  # (entity_id, license_id, name)
    ) -> str:
        """
        Generate attribution text for derived work.

        Args:
            sources: List of (entity_id, license_id, name) tuples

        Returns:
            Attribution markdown, or "" if no source requires attribution
        """
        attribution_lines: List[str] = []
        for entity_id, license_id, name in sources:
            lic = get_license(license_id)
            if lic.requires_attribution():
                line = f"- **{name}** (`{entity_id}`)"
                if lic.url:
                    line += f" - Licensed under [{lic.id}]({lic.url})"
                else:
                    line += f" - Licensed under {lic.id}"
                attribution_lines.append(line)

        # Nothing requires attribution → no section at all.
        if not attribution_lines:
            return ""

        lines = [
            "## Attribution",
            "",
            "This dataset is derived from the following sources:",
            "",
            *attribution_lines,
            "",
        ]
        return "\n".join(lines)


# Singleton analyzer
_analyzer = LicenseAnalyzer()


def check_license_compatibility(
    sources: List[Tuple[str, str]],
    target: Optional[str] = None,
) -> LicenseCompatibility:
    """
    Convenience function to check license compatibility.

    Args:
        sources: List of (entity_id, license_id) tuples
        target: Intended license for derived work

    Returns:
        LicenseCompatibility result
    """
    return _analyzer.check_compatibility(sources, target)


def get_derived_license(sources: List[str]) -> SPDXLicense:
    """
    Get the appropriate license for a work derived from given sources.

    Args:
        sources: List of SPDX license identifiers

    Returns:
        SPDXLicense for the derived work
    """
    result = _analyzer.check_compatibility([
        (f"source_{i}", lic) for i, lic in enumerate(sources)
    ])
    return result.derived_license or SPDX_LICENSES["unknown"]
+ + Args: + sources: List of SPDX license identifiers + + Returns: + SPDXLicense for the derived work + """ + result = _analyzer.check_compatibility([ + (f"source_{i}", lic) for i, lic in enumerate(sources) + ]) + return result.derived_license or SPDX_LICENSES["unknown"] diff --git a/cascade/data/live.py b/cascade/data/live.py new file mode 100644 index 0000000000000000000000000000000000000000..00b201a5853ce97c9bf5dddd528b9b8ca54fe0a3 --- /dev/null +++ b/cascade/data/live.py @@ -0,0 +1,844 @@ +""" +Live Document Tracer + +Real-time streaming of document-centric provenance events. +This is the LIVE version of what the export system freezes. + +Instead of: Model runs → Process → Export frozen provenance +We do: Model runs → STREAM events → View live document highlights + +Same data model as the observer/exporter, just streamed in real-time +with document snippet context attached. + +Usage: + # Create observer with live streaming + observer = DatasetObserver("my_pipeline") + tracer = LiveDocumentTracer(observer) + + # Subscribe to events + tracer.on_event(my_handler) + + # Or stream to async consumer + async for event in tracer.stream(): + render_highlight(event) +""" + +import asyncio +import json +import time +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple +from queue import Queue +from threading import Lock +from pathlib import Path + + +class TraceEventType(Enum): + """Types of document trace events.""" + # Data flow events + DOCUMENT_TOUCHED = "document_touched" # Model accessed this document/record + SPAN_HIGHLIGHTED = "span_highlighted" # Specific text span being processed + ASSOCIATION_CREATED = "association_created" # Link between two spans/documents + + # Activity events + ACTIVITY_STARTED = "activity_started" + ACTIVITY_PROGRESS = "activity_progress" + ACTIVITY_COMPLETED = "activity_completed" + + # Entity events + ENTITY_CREATED = "entity_created" + 
ENTITY_DERIVED = "entity_derived" + + # Relationship events + LINK_CREATED = "link_created" + + +@dataclass +class DocumentSpan: + """ + A span within a document being traced. + + This is the atomic unit of live visualization - + the specific text/content the model is touching. + """ + document_id: str # Entity or record ID + document_name: str # Human-readable name + field_name: str = "" # Column/field if applicable + row_index: int = -1 # Row if applicable + + # The actual content span + text: str = "" # The snippet text + start_char: int = -1 # Start position in full text + end_char: int = -1 # End position in full text + + # Visual hints + highlight_type: str = "default" # "source", "target", "match", "attention" + confidence: float = 1.0 # For attention/relevance visualization + + # Metadata + metadata: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + return { + "document_id": self.document_id, + "document_name": self.document_name, + "field_name": self.field_name, + "row_index": self.row_index, + "text": self.text, + "start_char": self.start_char, + "end_char": self.end_char, + "highlight_type": self.highlight_type, + "confidence": self.confidence, + "metadata": self.metadata, + } + + +@dataclass +class DocumentAssociation: + """ + An association between two document spans. + + Represents the model saying "this connects to that". + """ + source: DocumentSpan + target: DocumentSpan + association_type: str = "related" # "match", "derived", "similar", "references" + confidence: float = 1.0 + + # Why this association was made + reason: str = "" + + def to_dict(self) -> Dict[str, Any]: + return { + "source": self.source.to_dict(), + "target": self.target.to_dict(), + "association_type": self.association_type, + "confidence": self.confidence, + "reason": self.reason, + } + + +@dataclass +class TraceEvent: + """ + A single trace event for live document visualization. + + This is what gets streamed to the UI in real-time. 
+ """ + event_type: TraceEventType + timestamp: float = field(default_factory=time.time) + + # Activity context + activity_id: Optional[str] = None + activity_name: Optional[str] = None + activity_type: Optional[str] = None + + # Document spans involved + spans: List[DocumentSpan] = field(default_factory=list) + + # Association if this event creates one + association: Optional[DocumentAssociation] = None + + # Progress for long operations + progress: Optional[float] = None # 0.0 to 1.0 + progress_message: Optional[str] = None + + # Raw provenance data (for export compatibility) + entity_id: Optional[str] = None + relationship_type: Optional[str] = None + + # Metadata + metadata: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + return { + "event_type": self.event_type.value, + "timestamp": self.timestamp, + "activity_id": self.activity_id, + "activity_name": self.activity_name, + "activity_type": self.activity_type, + "spans": [s.to_dict() for s in self.spans], + "association": self.association.to_dict() if self.association else None, + "progress": self.progress, + "progress_message": self.progress_message, + "entity_id": self.entity_id, + "metadata": self.metadata, + } + + def to_json(self) -> str: + return json.dumps(self.to_dict(), default=str) + + +class LiveDocumentTracer: + """ + Real-time document tracing for live visualization. + + Hooks into DatasetObserver to stream events as they happen, + enriched with document snippet context for visualization. + + This is the LIVE version of what CroissantExporter freezes. + + NEW: Now writes all events to a tape file (JSONL) for buffered playback! + """ + + def __init__(self, observer=None, buffer_size: int = 1000, log_dir: str = "./logs"): + """ + Initialize tracer. 
+ + Args: + observer: DatasetObserver to hook into (optional) + buffer_size: Max events to buffer for replay + log_dir: Directory for tape files (JSONL logs) + """ + self.observer = observer + self.buffer_size = buffer_size + + # Event subscribers + self._handlers: List[Callable[[TraceEvent], None]] = [] + self._async_handlers: List[Callable[[TraceEvent], Any]] = [] + + # Event buffer for replay/late subscribers + self._buffer: List[TraceEvent] = [] + self._buffer_lock = Lock() + + # Async queue for streaming + self._async_queue: Optional[asyncio.Queue] = None + + # Current activity context + self._current_activity_id: Optional[str] = None + self._current_activity_name: Optional[str] = None + self._current_activity_type: Optional[str] = None + + # Document context cache + self._document_cache: Dict[str, Dict[str, Any]] = {} + + # === TAPE FILE FOR PLAYBACK === + self._log_dir = Path(log_dir) + self._log_dir.mkdir(parents=True, exist_ok=True) + self._session_id = int(time.time()) + self._tape_path = self._log_dir / f"unity_tape_{self._session_id}.jsonl" + self._tape_file = None + self._tape_lock = Lock() + self._event_count = 0 + + # ═══════════════════════════════════════════════════════════════════════════ + # SUBSCRIPTION + # ═══════════════════════════════════════════════════════════════════════════ + + def on_event(self, handler: Callable[[TraceEvent], None]): + """Subscribe to trace events (sync handler).""" + self._handlers.append(handler) + return self # Allow chaining + + def on_event_async(self, handler: Callable[[TraceEvent], Any]): + """Subscribe to trace events (async handler).""" + self._async_handlers.append(handler) + return self + + def remove_handler(self, handler): + """Unsubscribe a handler.""" + if handler in self._handlers: + self._handlers.remove(handler) + if handler in self._async_handlers: + self._async_handlers.remove(handler) + + # ═══════════════════════════════════════════════════════════════════════════ + # EVENT EMISSION + # 
═══════════════════════════════════════════════════════════════════════════ + + def emit(self, event: TraceEvent): + """ + Emit a trace event to all subscribers. + + Called internally when provenance events occur. + Also writes to tape file for buffered playback! + """ + self._event_count += 1 + + # Add to buffer + with self._buffer_lock: + self._buffer.append(event) + if len(self._buffer) > self.buffer_size: + self._buffer.pop(0) + + # === WRITE TO TAPE (JSONL) === + self._write_to_tape(event) + + # Call sync handlers + for handler in self._handlers: + try: + handler(event) + except Exception as e: + print(f"Handler error: {e}") + + # Queue for async handlers + if self._async_queue: + try: + self._async_queue.put_nowait(event) + except asyncio.QueueFull: + pass # Drop if queue full + + def _write_to_tape(self, event: TraceEvent): + """Write event to tape file for later playback.""" + try: + with self._tape_lock: + # Lazy open the file + if self._tape_file is None: + self._tape_file = open(self._tape_path, "a", encoding="utf-8") + print(f"[CASCADE] 📼 Unity tape started: {self._tape_path}") + + # Build tape record with full context + record = { + "seq": self._event_count, + "event": event.to_dict(), + "session_id": self._session_id, + } + + json_line = json.dumps(record, default=str) + "\n" + self._tape_file.write(json_line) + self._tape_file.flush() + + # Debug: Log first few events + if self._event_count <= 3: + print(f"[CASCADE] 📝 Wrote event {self._event_count} to tape: {event.event_type}") + except Exception as e: + # Don't let tape errors break the main flow + print(f"[CASCADE] ⚠️ Tape write error: {e}") + pass + + def _write_raw_to_tape(self, record: Dict[str, Any]): + """Write a raw record to tape file (for docspace events).""" + try: + with self._tape_lock: + # Lazy open the file + if self._tape_file is None: + self._tape_file = open(self._tape_path, "a", encoding="utf-8") + print(f"[CASCADE] 📼 Unity tape started: {self._tape_path}") + + 
self._tape_file.write(json.dumps(record, default=str) + "\n") + self._tape_file.flush() + except Exception: + pass + + # ═══════════════════════════════════════════════════════════════════════════ + # DOCUMENT SPACE EVENTS (for polling iframe) + # ═══════════════════════════════════════════════════════════════════════════ + + def emit_entity(self, entity_id: str, source: str, text: str, index: int, side: str = "a"): + """ + Emit an entity for Document Space visualization. + + Args: + entity_id: Unique ID for the entity + source: Source dataset name + text: Preview text (truncated) + index: Row index in dataset + side: "a" or "b" to indicate which dataset + """ + self._event_count += 1 + record = { + "seq": self._event_count, + "type": "docspace_entity", + "side": side, + "data": { + "id": entity_id, + "source": source, + "text": text[:200], + "index": index, + }, + "session_id": self._session_id, + } + self._write_raw_to_tape(record) + + def emit_match(self, doc_a_id: str, doc_b_id: str, score: float): + """ + Emit a match for Document Space visualization. + + Args: + doc_a_id: ID of entity from dataset A + doc_b_id: ID of entity from dataset B + score: Similarity score (0-1) + """ + self._event_count += 1 + record = { + "seq": self._event_count, + "type": "docspace_match", + "data": { + "docA": doc_a_id, + "docB": doc_b_id, + "score": float(score), + }, + "session_id": self._session_id, + } + self._write_raw_to_tape(record) + + def emit_phase(self, phase: str, progress: float, message: str = ""): + """ + Emit a phase update for Document Space. 
+ + Args: + phase: Current phase (embedding_a, embedding_b, comparing, complete) + progress: Progress 0-1 + message: Status message + """ + self._event_count += 1 + record = { + "seq": self._event_count, + "type": "docspace_phase", + "data": { + "phase": phase, + "progress": float(progress), + "message": message, + }, + "session_id": self._session_id, + } + self._write_raw_to_tape(record) + + def close_tape(self): + """Close the tape file (call when session ends).""" + with self._tape_lock: + if self._tape_file: + self._tape_file.close() + self._tape_file = None + print(f"[CASCADE] 📼 Unity tape closed: {self._event_count} events → {self._tape_path}") + + def get_tape_path(self) -> Optional[Path]: + """Get the path to the current tape file (whether open or not).""" + return self._tape_path + + @staticmethod + def load_tape(tape_path: str) -> List[Dict[str, Any]]: + """ + Load events from a tape file for playback. + + Args: + tape_path: Path to the .jsonl tape file + + Returns: + List of event records in chronological order + """ + events = [] + with open(tape_path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + try: + events.append(json.loads(line)) + except json.JSONDecodeError: + pass # Skip malformed lines + return events + + async def stream(self) -> Generator[TraceEvent, None, None]: + """ + Async generator for streaming events. 
+ + Usage: + async for event in tracer.stream(): + await render(event) + """ + self._async_queue = asyncio.Queue(maxsize=self.buffer_size) + + # Replay buffer first + with self._buffer_lock: + for event in self._buffer: + yield event + + # Then stream new events + while True: + event = await self._async_queue.get() + yield event + + def get_buffer(self) -> List[TraceEvent]: + """Get buffered events for replay.""" + with self._buffer_lock: + return list(self._buffer) + + def clear_buffer(self): + """Clear the event buffer.""" + with self._buffer_lock: + self._buffer.clear() + + # ═══════════════════════════════════════════════════════════════════════════ + # TRACING API - Call these to emit events + # ═══════════════════════════════════════════════════════════════════════════ + + def start_activity( + self, + activity_id: str, + activity_name: str, + activity_type: str = "transform", + ): + """Signal start of an activity (for context).""" + self._current_activity_id = activity_id + self._current_activity_name = activity_name + self._current_activity_type = activity_type + + self.emit(TraceEvent( + event_type=TraceEventType.ACTIVITY_STARTED, + activity_id=activity_id, + activity_name=activity_name, + activity_type=activity_type, + )) + + def end_activity(self, activity_id: str = None): + """Signal end of an activity.""" + self.emit(TraceEvent( + event_type=TraceEventType.ACTIVITY_COMPLETED, + activity_id=activity_id or self._current_activity_id, + activity_name=self._current_activity_name, + activity_type=self._current_activity_type, + )) + self._current_activity_id = None + self._current_activity_name = None + self._current_activity_type = None + + def report_progress( + self, + progress: float, + message: str = "", + activity_id: str = None, + ): + """Report progress on current activity.""" + self.emit(TraceEvent( + event_type=TraceEventType.ACTIVITY_PROGRESS, + activity_id=activity_id or self._current_activity_id, + activity_name=self._current_activity_name, + 
progress=progress, + progress_message=message, + )) + + def touch_document( + self, + document_id: str, + document_name: str, + snippet: str = "", + field_name: str = "", + row_index: int = -1, + highlight_type: str = "default", + confidence: float = 1.0, + **metadata, + ): + """ + Signal that the model touched a document/record. + + This creates a highlight in the live view. + """ + span = DocumentSpan( + document_id=document_id, + document_name=document_name, + field_name=field_name, + row_index=row_index, + text=snippet, + highlight_type=highlight_type, + confidence=confidence, + metadata=metadata, + ) + + self.emit(TraceEvent( + event_type=TraceEventType.DOCUMENT_TOUCHED, + activity_id=self._current_activity_id, + activity_name=self._current_activity_name, + activity_type=self._current_activity_type, + spans=[span], + entity_id=document_id, + metadata=metadata, + )) + + return span + + def highlight_span( + self, + document_id: str, + document_name: str, + text: str, + start_char: int = -1, + end_char: int = -1, + field_name: str = "", + row_index: int = -1, + highlight_type: str = "attention", + confidence: float = 1.0, + **metadata, + ): + """ + Highlight a specific span within a document. + + For showing exactly where in the text the model is focusing. 
+ """ + span = DocumentSpan( + document_id=document_id, + document_name=document_name, + field_name=field_name, + row_index=row_index, + text=text, + start_char=start_char, + end_char=end_char, + highlight_type=highlight_type, + confidence=confidence, + metadata=metadata, + ) + + self.emit(TraceEvent( + event_type=TraceEventType.SPAN_HIGHLIGHTED, + activity_id=self._current_activity_id, + activity_name=self._current_activity_name, + activity_type=self._current_activity_type, + spans=[span], + metadata=metadata, + )) + + return span + + def create_association( + self, + source_doc_id: str, + source_doc_name: str, + source_text: str, + target_doc_id: str, + target_doc_name: str, + target_text: str, + association_type: str = "related", + confidence: float = 1.0, + reason: str = "", + **metadata, + ): + """ + Create an association between two document spans. + + This is the "A connects to B" visualization. + """ + source = DocumentSpan( + document_id=source_doc_id, + document_name=source_doc_name, + text=source_text, + highlight_type="source", + confidence=confidence, + ) + + target = DocumentSpan( + document_id=target_doc_id, + document_name=target_doc_name, + text=target_text, + highlight_type="target", + confidence=confidence, + ) + + association = DocumentAssociation( + source=source, + target=target, + association_type=association_type, + confidence=confidence, + reason=reason, + ) + + self.emit(TraceEvent( + event_type=TraceEventType.ASSOCIATION_CREATED, + activity_id=self._current_activity_id, + activity_name=self._current_activity_name, + activity_type=self._current_activity_type, + spans=[source, target], + association=association, + metadata=metadata, + )) + + return association + + def entity_created( + self, + entity_id: str, + entity_name: str, + record_count: int = None, + **metadata, + ): + """Signal that a new entity was created in provenance.""" + self.emit(TraceEvent( + event_type=TraceEventType.ENTITY_CREATED, + activity_id=self._current_activity_id, 
+ activity_name=self._current_activity_name, + entity_id=entity_id, + metadata={"name": entity_name, "record_count": record_count, **metadata}, + )) + + def entity_derived( + self, + derived_id: str, + derived_name: str, + source_ids: List[str], + **metadata, + ): + """Signal that an entity was derived from others.""" + self.emit(TraceEvent( + event_type=TraceEventType.ENTITY_DERIVED, + activity_id=self._current_activity_id, + activity_name=self._current_activity_name, + entity_id=derived_id, + metadata={"name": derived_name, "sources": source_ids, **metadata}, + )) + + def link_created( + self, + source_id: str, + target_id: str, + relationship_type: str, + **metadata, + ): + """Signal that a provenance link was created.""" + self.emit(TraceEvent( + event_type=TraceEventType.LINK_CREATED, + activity_id=self._current_activity_id, + activity_name=self._current_activity_name, + relationship_type=relationship_type, + metadata={"source": source_id, "target": target_id, **metadata}, + )) + + # ═══════════════════════════════════════════════════════════════════════════ + # EXPORT (Freeze the live state) + # ═══════════════════════════════════════════════════════════════════════════ + + def export_session(self) -> Dict[str, Any]: + """ + Export the trace session as frozen data. + + This is the bridge between live and export - + same data, just frozen at a point in time. 
+ """ + with self._buffer_lock: + return { + "events": [e.to_dict() for e in self._buffer], + "event_count": len(self._buffer), + "exported_at": time.time(), + } + + def export_associations(self) -> List[Dict[str, Any]]: + """Export just the associations for visualization.""" + associations = [] + with self._buffer_lock: + for event in self._buffer: + if event.association: + associations.append(event.association.to_dict()) + return associations + + def export_timeline(self) -> List[Dict[str, Any]]: + """Export events as a timeline.""" + timeline = [] + with self._buffer_lock: + for event in self._buffer: + timeline.append({ + "timestamp": event.timestamp, + "type": event.event_type.value, + "activity": event.activity_name, + "spans": len(event.spans), + "has_association": event.association is not None, + }) + return timeline + + +# ═══════════════════════════════════════════════════════════════════════════════ +# CONSOLE RENDERER - Simple text-based live view +# ═══════════════════════════════════════════════════════════════════════════════ + +class ConsoleTraceRenderer: + """ + Simple console renderer for live document traces. + + Good for debugging and terminal-based workflows. 
+ """ + + def __init__(self, show_snippets: bool = True, max_snippet_len: int = 80): + self.show_snippets = show_snippets + self.max_snippet_len = max_snippet_len + + def render(self, event: TraceEvent): + """Render event to console.""" + timestamp = time.strftime("%H:%M:%S", time.localtime(event.timestamp)) + + if event.event_type == TraceEventType.ACTIVITY_STARTED: + print(f"\n[{timestamp}] ▶ {event.activity_name} ({event.activity_type})") + print("─" * 60) + + elif event.event_type == TraceEventType.ACTIVITY_COMPLETED: + print("─" * 60) + print(f"[{timestamp}] ✓ {event.activity_name} completed") + + elif event.event_type == TraceEventType.ACTIVITY_PROGRESS: + pct = int((event.progress or 0) * 100) + bar = "█" * (pct // 5) + "░" * (20 - pct // 5) + msg = event.progress_message or "" + print(f"\r[{timestamp}] [{bar}] {pct}% {msg}", end="", flush=True) + if pct >= 100: + print() + + elif event.event_type == TraceEventType.DOCUMENT_TOUCHED: + for span in event.spans: + snippet = self._truncate(span.text) + print(f"[{timestamp}] 📄 {span.document_name}", end="") + if span.field_name: + print(f"[{span.field_name}]", end="") + if span.row_index >= 0: + print(f" row={span.row_index}", end="") + if self.show_snippets and snippet: + print(f"\n └─ \"{snippet}\"") + else: + print() + + elif event.event_type == TraceEventType.SPAN_HIGHLIGHTED: + for span in event.spans: + snippet = self._truncate(span.text) + conf = f"{span.confidence:.0%}" if span.confidence < 1.0 else "" + print(f"[{timestamp}] 🔍 [{span.highlight_type}] {conf}") + if self.show_snippets and snippet: + print(f" └─ \"{snippet}\"") + + elif event.event_type == TraceEventType.ASSOCIATION_CREATED: + assoc = event.association + if assoc: + src = self._truncate(assoc.source.text, 40) + tgt = self._truncate(assoc.target.text, 40) + print(f"[{timestamp}] 🔗 {assoc.association_type} ({assoc.confidence:.0%})") + print(f" ├─ \"{src}\"") + print(f" └─ \"{tgt}\"") + if assoc.reason: + print(f" ({assoc.reason})") + + elif 
event.event_type == TraceEventType.ENTITY_CREATED: + name = event.metadata.get("name", event.entity_id) + count = event.metadata.get("record_count", "?") + print(f"[{timestamp}] ✦ Entity created: {name} ({count} records)") + + elif event.event_type == TraceEventType.ENTITY_DERIVED: + name = event.metadata.get("name", event.entity_id) + sources = event.metadata.get("sources", []) + print(f"[{timestamp}] ⤵ Entity derived: {name} ← {len(sources)} sources") + + def _truncate(self, text: str, max_len: int = None) -> str: + max_len = max_len or self.max_snippet_len + if not text: + return "" + text = text.replace("\n", " ").strip() + if len(text) > max_len: + return text[:max_len-3] + "..." + return text + + +# ═══════════════════════════════════════════════════════════════════════════════ +# CONVENIENCE +# ═══════════════════════════════════════════════════════════════════════════════ + +def create_live_tracer(observer=None, console: bool = False) -> LiveDocumentTracer: + """ + Create a live document tracer. + + Args: + observer: DatasetObserver to hook into + console: If True, attach console renderer + + Returns: + Configured LiveDocumentTracer + """ + tracer = LiveDocumentTracer(observer) + + if console: + renderer = ConsoleTraceRenderer() + tracer.on_event(renderer.render) + + return tracer diff --git a/cascade/data/observer.py b/cascade/data/observer.py new file mode 100644 index 0000000000000000000000000000000000000000..cd707825c9377b73f0b2ffa738ac6e5706474348 --- /dev/null +++ b/cascade/data/observer.py @@ -0,0 +1,666 @@ +""" +Dataset Observer + +The main interface for observing datasets. +Provides context managers for tracking ingest, transform, and consume operations. 
+""" + +import hashlib +import time +from contextlib import contextmanager +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, Generator, List, Optional, Union + +from .entities import ( + DatasetEntity, Activity, Agent, Relationship, RelationType, + ActivityType, AgentType, create_system_agent, create_model_agent, create_user_agent +) +from .provenance import ProvenanceGraph +from .schema import SchemaObserver, DatasetSchema, hash_content + + +@dataclass +class ObservationContext: + """ + Context for an ongoing observation. + + Used within context managers to track inputs/outputs. + """ + activity: Activity + observer: "DatasetObserver" + + _inputs: List[DatasetEntity] = field(default_factory=list) + _outputs: List[DatasetEntity] = field(default_factory=list) + + def input(self, dataset, name: str = None, **kwargs) -> DatasetEntity: + """ + Register an input dataset. + + Args: + dataset: HuggingFace Dataset, DatasetDict, or entity ID + name: Optional name override + **kwargs: Additional entity attributes + + Returns: + The created or retrieved DatasetEntity + """ + # If string, assume it's an existing entity ID + if isinstance(dataset, str): + entity = self.observer.graph.get_entity(dataset) + if entity: + self._inputs.append(entity) + self.activity.add_input(entity.id) + self.observer.graph.link_usage(self.activity.id, entity.id) + return entity + else: + raise ValueError(f"Entity not found: {dataset}") + + # Otherwise, observe the dataset + entity = self.observer.observe_dataset(dataset, name=name, **kwargs) + self._inputs.append(entity) + self.activity.add_input(entity.id) + self.observer.graph.link_usage(self.activity.id, entity.id) + + return entity + + def output(self, dataset, name: str = None, **kwargs) -> DatasetEntity: + """ + Register an output dataset. 
+ + Args: + dataset: HuggingFace Dataset, DatasetDict, or dict + name: Optional name override + **kwargs: Additional entity attributes + + Returns: + The created DatasetEntity + """ + entity = self.observer.observe_dataset(dataset, name=name, **kwargs) + self._outputs.append(entity) + self.activity.add_output(entity.id) + + # Link generation + self.observer.graph.link_generation(entity.id, self.activity.id) + + # Link derivation from all inputs + for input_entity in self._inputs: + self.observer.graph.link_derivation(entity.id, input_entity.id) + + return entity + + @property + def inputs(self) -> List[DatasetEntity]: + return self._inputs + + @property + def outputs(self) -> List[DatasetEntity]: + return self._outputs + + +class DatasetObserver: + """ + Observer for dataset operations. + + Tracks: + - Dataset loading (ingest) + - Transformations (filter, map, join, etc.) + - Consumption (training, inference) + + Example: + observer = DatasetObserver() + + with observer.observe_ingest("squad") as ctx: + ds = load_dataset("squad") + ctx.output(ds) + + with observer.observe_transform("filter_english") as ctx: + ctx.input(ds) + filtered = ds.filter(lambda x: x["lang"] == "en") + ctx.output(filtered) + + chain = observer.export_provenance() + """ + + def __init__( + self, + name: str = "default", + agent: Agent = None, + ): + """ + Initialize observer. 
+ + Args: + name: Name for the provenance graph + agent: Default agent for activities (defaults to graph's system agent) + """ + self.graph = ProvenanceGraph(name=name) + self.schema_observer = SchemaObserver() + + # Use provided agent or the graph's default system agent + if agent: + self._default_agent = agent + self.graph.add_agent(agent) + else: + # Use the graph's already-created system agent + self._default_agent = self.graph._system_agent + + # Entity counter for unique IDs + self._counter = 0 + + def _next_id(self, prefix: str) -> str: + """Generate unique ID.""" + self._counter += 1 + return f"{prefix}:{int(time.time() * 1000)}:{self._counter:04d}" + + # ═══════════════════════════════════════════════════════════════════════════ + # DATASET OBSERVATION + # ═══════════════════════════════════════════════════════════════════════════ + + def observe_dataset( + self, + dataset, + name: str = None, + source_type: str = None, + source_uri: str = None, + version: str = None, + license_id: str = None, + license_url: str = None, + **kwargs, + ) -> DatasetEntity: + """ + Observe a dataset and create an entity. + + Args: + dataset: HuggingFace Dataset, DatasetDict, DataFrame, or dict + name: Name for the entity + source_type: Type of source (hf_hub, local, etc.) 
+ source_uri: URI of the source + version: Version string + license_id: SPDX license identifier (e.g., "MIT", "CC-BY-4.0") + license_url: URL to the license text + **kwargs: Additional attributes + + Returns: + DatasetEntity representing the dataset + """ + # Infer name if not provided + if name is None: + if hasattr(dataset, 'info') and hasattr(dataset.info, 'dataset_name'): + name = dataset.info.dataset_name + elif hasattr(dataset, 'config_name'): + name = dataset.config_name + else: + name = f"dataset_{self._counter + 1}" + + # Try to extract license from HuggingFace dataset info + if license_id is None and hasattr(dataset, 'info'): + info = dataset.info + if hasattr(info, 'license') and info.license: + license_id = info.license + + # Observe schema + schema = self._observe_schema(dataset) + + # Compute content hash + content_hash = self._compute_content_hash(dataset) + + # Get record count and splits + record_count, splits = self._get_counts(dataset) + + # Infer source + if source_type is None: + source_type = self._infer_source_type(dataset) + + # Create entity + entity = DatasetEntity( + id=self._next_id("entity"), + name=name, + content_hash=content_hash, + schema_hash=schema.hash() if schema else None, + version=version, + source_type=source_type, + source_uri=source_uri, + license_id=license_id, + license_url=license_url, + record_count=record_count, + splits=splits, + attributes={ + "schema": schema.to_dict() if schema else None, + **kwargs, + }, + ) + + # Add to graph + self.graph.add_entity(entity) + + return entity + + def register_agent(self, name: str, agent_type: str = "software", version: str = None) -> Agent: + """ + Register a new agent in the provenance graph. + + Args: + name: Name of the agent + agent_type: Type of agent (software, model, person, etc.) 
+ version: Optional version string + + Returns: + The created Agent + """ + if agent_type == "model": + agent = create_model_agent(name, version=version) + elif agent_type == "system": + agent = create_system_agent(name, version=version) + elif agent_type == "person": + agent = create_user_agent(name) + else: + # Default to software agent or generic + try: + type_enum = AgentType(agent_type) + except ValueError: + type_enum = AgentType.SOFTWARE + + agent = Agent( + id=f"agent:{type_enum.value}:{name.replace(' ', '_').lower()}", + agent_type=type_enum, + name=name, + version=version + ) + + self.graph.add_agent(agent) + return agent + + def _observe_schema(self, dataset) -> Optional[DatasetSchema]: + """Extract schema from dataset.""" + try: + # HuggingFace Dataset + if hasattr(dataset, 'features'): + return self.schema_observer.observe_hf_dataset(dataset) + + # Pandas DataFrame + if hasattr(dataset, 'dtypes') and hasattr(dataset, 'columns'): + return self.schema_observer.observe_pandas(dataset) + + # Dict + if isinstance(dataset, dict): + # Check if it's columnar (dict of lists) + if all(isinstance(v, list) for v in dataset.values()): + return self.schema_observer.observe_dict(dataset) + + return None + except Exception as e: + # Don't fail observation if schema extraction fails + print(f"Warning: Could not extract schema: {e}") + return None + + def _compute_content_hash(self, dataset) -> str: + """Compute content hash of dataset.""" + try: + return hash_content(dataset) + except Exception: + # Fallback to timestamp-based hash + return hashlib.sha256(str(time.time()).encode()).hexdigest() + + def _get_counts(self, dataset) -> tuple: + """Get record count and split counts.""" + record_count = None + splits = {} + + try: + # HuggingFace DatasetDict + if hasattr(dataset, 'keys') and hasattr(dataset, '__getitem__'): + for split_name in dataset.keys(): + split_ds = dataset[split_name] + if hasattr(split_ds, '__len__'): + splits[split_name] = len(split_ds) + 
record_count = sum(splits.values()) if splits else None + + # Single dataset + elif hasattr(dataset, '__len__'): + record_count = len(dataset) + + except Exception: + pass + + return record_count, splits + + def _infer_source_type(self, dataset) -> str: + """Infer source type from dataset.""" + # HuggingFace Dataset + if hasattr(dataset, '_info'): + return "hf_dataset" + + # Pandas + if hasattr(dataset, 'dtypes'): + return "pandas" + + # Dict + if isinstance(dataset, dict): + return "dict" + + return "unknown" + + # ═══════════════════════════════════════════════════════════════════════════ + # CONTEXT MANAGERS + # ═══════════════════════════════════════════════════════════════════════════ + + @contextmanager + def observe_ingest( + self, + name: str, + source_uri: str = None, + agent: Agent = None, + **kwargs, + ) -> Generator[ObservationContext, None, None]: + """ + Observe a dataset ingest operation. + + Args: + name: Name of the ingest operation + source_uri: URI of the data source + agent: Agent performing the ingest + **kwargs: Additional activity parameters + + Yields: + ObservationContext for registering inputs/outputs + + Example: + with observer.observe_ingest("load_squad", source_uri="hf://squad") as ctx: + ds = load_dataset("squad") + ctx.output(ds, name="squad") + """ + activity = Activity( + id=self._next_id("activity"), + activity_type=ActivityType.INGEST, + name=name, + agent_id=(agent or self._default_agent).id, + parameters={"source_uri": source_uri, **kwargs}, + ) + activity.start() + + ctx = ObservationContext(activity=activity, observer=self) + + try: + yield ctx + finally: + activity.end() + self.graph.add_activity(activity) + self.graph.link_association(activity.id, activity.agent_id) + + @contextmanager + def observe_transform( + self, + name: str, + transform_type: str = None, + agent: Agent = None, + **kwargs, + ) -> Generator[ObservationContext, None, None]: + """ + Observe a dataset transformation. 
+ + Args: + name: Name of the transform + transform_type: Type of transform (filter, map, join, etc.) + agent: Agent performing the transform + **kwargs: Additional activity parameters + + Yields: + ObservationContext for registering inputs/outputs + + Example: + with observer.observe_transform("filter_english") as ctx: + ctx.input(ds) + filtered = ds.filter(lambda x: x["lang"] == "en") + ctx.output(filtered) + """ + activity = Activity( + id=self._next_id("activity"), + activity_type=ActivityType.TRANSFORM, + name=name, + agent_id=(agent or self._default_agent).id, + parameters={"transform_type": transform_type, **kwargs}, + ) + activity.start() + + ctx = ObservationContext(activity=activity, observer=self) + + try: + yield ctx + finally: + activity.end() + self.graph.add_activity(activity) + self.graph.link_association(activity.id, activity.agent_id) + + @contextmanager + def observe_consume( + self, + name: str, + model_id: str = None, + consume_type: str = "train", + agent: Agent = None, + **kwargs, + ) -> Generator[ObservationContext, None, None]: + """ + Observe dataset consumption (training, inference). + + Args: + name: Name of the consumption operation + model_id: ID of the model consuming the data + consume_type: Type of consumption (train, evaluate, inference) + agent: Agent performing the consumption + **kwargs: Additional activity parameters + + Yields: + ObservationContext for registering inputs/outputs + + Example: + with observer.observe_consume("train_qa_model", model_id="bert-base") as ctx: + ctx.input(train_ds) + model = train(train_ds) + # Model provenance now links to data provenance! 
+ """ + # Create model agent if model_id provided + if model_id and agent is None: + agent = create_model_agent(model_id) + self.graph.add_agent(agent) + + activity_type = { + "train": ActivityType.TRAIN, + "evaluate": ActivityType.EVALUATE, + "inference": ActivityType.INFERENCE, + }.get(consume_type, ActivityType.TRAIN) + + activity = Activity( + id=self._next_id("activity"), + activity_type=activity_type, + name=name, + agent_id=(agent or self._default_agent).id, + parameters={"model_id": model_id, "consume_type": consume_type, **kwargs}, + ) + activity.start() + + ctx = ObservationContext(activity=activity, observer=self) + + try: + yield ctx + finally: + activity.end() + self.graph.add_activity(activity) + self.graph.link_association(activity.id, activity.agent_id) + + @contextmanager + def observe_entity_resolution( + self, + name: str, + model_id: str = None, + threshold: float = None, + agent: Agent = None, + **kwargs, + ) -> Generator[ObservationContext, None, None]: + """ + Observe entity resolution / data unity operation. 
+ + Args: + name: Name of the operation + model_id: Embedding model used + threshold: Similarity threshold + agent: Agent performing the operation + **kwargs: Additional parameters + + Example: + with observer.observe_entity_resolution("match_patients_claims") as ctx: + ctx.input(patients_ds) + ctx.input(claims_ds) + unified = run_unity(patients_ds, claims_ds) + ctx.output(unified) + """ + if model_id and agent is None: + agent = create_model_agent(model_id) + self.graph.add_agent(agent) + + activity = Activity( + id=self._next_id("activity"), + activity_type=ActivityType.ENTITY_RESOLUTION, + name=name, + agent_id=(agent or self._default_agent).id, + parameters={ + "model_id": model_id, + "threshold": threshold, + **kwargs, + }, + ) + activity.start() + + ctx = ObservationContext(activity=activity, observer=self) + + try: + yield ctx + finally: + activity.end() + self.graph.add_activity(activity) + self.graph.link_association(activity.id, activity.agent_id) + + # ═══════════════════════════════════════════════════════════════════════════ + # EXPORT + # ═══════════════════════════════════════════════════════════════════════════ + + def export_provenance(self) -> ProvenanceGraph: + """Export the provenance graph.""" + return self.graph + + def to_dict(self) -> Dict[str, Any]: + """Export observation state to dictionary.""" + return { + "graph": self.graph.to_dict(), + "counter": self._counter, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "DatasetObserver": + """Load observer from dictionary.""" + observer = cls() + observer.graph = ProvenanceGraph.from_dict(data["graph"]) + observer._counter = data.get("counter", 0) + return observer + + # ═══════════════════════════════════════════════════════════════════════════ + # STATISTICS + # ═══════════════════════════════════════════════════════════════════════════ + + @property + def stats(self) -> Dict[str, Any]: + """Get observer statistics.""" + return { + "graph": self.graph.stats, + "root_hash": 
self.graph.root_hash, + } + + # ═══════════════════════════════════════════════════════════════════════════ + # LICENSE TRACKING + # ═══════════════════════════════════════════════════════════════════════════ + + def check_license_compatibility( + self, + entity_ids: List[str], + target_license: str = None, + ): + """ + Check license compatibility for deriving from entities. + + Args: + entity_ids: List of source entity IDs + target_license: Intended SPDX license for derived work + + Returns: + LicenseCompatibility result + + Example: + result = observer.check_license_compatibility( + ["entity:123", "entity:456"], + target_license="MIT" + ) + if not result.compatible: + print(f"Issues: {result.issues}") + """ + from .license import check_license_compatibility + + sources = [] + for entity_id in entity_ids: + entity = self.graph.get_entity(entity_id) + if entity: + license_id = entity.license_id or "unknown" + sources.append((entity_id, license_id)) + + return check_license_compatibility(sources, target_license) + + def get_derived_license(self, entity_ids: List[str]): + """ + Get the appropriate license for a work derived from entities. + + Args: + entity_ids: List of source entity IDs + + Returns: + SPDXLicense for the derived work + """ + from .license import get_derived_license + + licenses = [] + for entity_id in entity_ids: + entity = self.graph.get_entity(entity_id) + if entity and entity.license_id: + licenses.append(entity.license_id) + + return get_derived_license(licenses) if licenses else None + + def generate_attribution(self, entity_ids: List[str] = None) -> str: + """ + Generate attribution text for entities. 
+ + Args: + entity_ids: List of entity IDs (defaults to all entities) + + Returns: + Markdown attribution text + """ + from .license import LicenseAnalyzer + + analyzer = LicenseAnalyzer() + + if entity_ids is None: + entities = self.graph.list_entities() + else: + entities = [ + self.graph.get_entity(eid) for eid in entity_ids + if self.graph.get_entity(eid) + ] + + sources = [ + (e.id, e.license_id or "unknown", e.name) + for e in entities + ] + + return analyzer.generate_attribution(sources) + + def __repr__(self) -> str: + return f"DatasetObserver({self.graph})" diff --git a/cascade/data/pii.py b/cascade/data/pii.py new file mode 100644 index 0000000000000000000000000000000000000000..82e5bf51a17a50315ee07d58b7af91b5f338dcc8 --- /dev/null +++ b/cascade/data/pii.py @@ -0,0 +1,748 @@ +""" +PII Detection for CASCADE + +Industry standard PII (Personally Identifiable Information) detection +based on Microsoft Presidio patterns and common PII taxonomies. + +References: +- Microsoft Presidio: https://github.com/microsoft/presidio +- NIST PII Guide: https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-122.pdf +- GDPR Article 4 (personal data definition) + +PII Categories: +1. Direct Identifiers: Name, SSN, passport, driver's license +2. Quasi-Identifiers: Age, ZIP code, gender, dates +3. Sensitive Data: Health, financial, biometric + +Detection Methods: +- Regex patterns (fast, high precision for structured PII) +- Context-aware detection (surrounding words improve accuracy) +- Checksum validation (SSN, credit cards, etc.) 
+""" + +import re +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Pattern, Set, Tuple + + +class PIIType(Enum): + """Types of PII that can be detected.""" + # Direct Identifiers + PERSON_NAME = "PERSON_NAME" + EMAIL = "EMAIL" + PHONE_NUMBER = "PHONE_NUMBER" + SSN = "SSN" # Social Security Number + CREDIT_CARD = "CREDIT_CARD" + IBAN = "IBAN" # International Bank Account Number + IP_ADDRESS = "IP_ADDRESS" + MAC_ADDRESS = "MAC_ADDRESS" + PASSPORT = "PASSPORT" + DRIVERS_LICENSE = "DRIVERS_LICENSE" + + # Quasi-Identifiers + DATE_OF_BIRTH = "DATE_OF_BIRTH" + AGE = "AGE" + ZIPCODE = "ZIPCODE" + ADDRESS = "ADDRESS" + + # Sensitive Data + MEDICAL_RECORD = "MEDICAL_RECORD" + API_KEY = "API_KEY" + AWS_KEY = "AWS_KEY" + PASSWORD = "PASSWORD" + CRYPTO_WALLET = "CRYPTO_WALLET" + + # Location + GPS_COORDINATES = "GPS_COORDINATES" + + # URLs and IDs + URL = "URL" + USERNAME = "USERNAME" + + +class PIISeverity(Enum): + """Severity levels for PII findings.""" + CRITICAL = "critical" # Direct identifier, immediate re-identification risk + HIGH = "high" # Sensitive data, significant privacy risk + MEDIUM = "medium" # Quasi-identifier, re-identification when combined + LOW = "low" # Minimal risk, contextual sensitivity + + +@dataclass +class PIIMatch: + """A detected PII instance.""" + pii_type: PIIType + severity: PIISeverity + value: str # The matched text (may be redacted for display) + start: int # Start position in text + end: int # End position in text + confidence: float # 0.0 to 1.0 + context: str = "" # Surrounding text for context + field_name: str = "" # Column/field where found + row_index: int = -1 # Row index if applicable + + def to_dict(self) -> Dict[str, Any]: + return { + "type": self.pii_type.value, + "severity": self.severity.value, + "value_preview": self._redact(self.value), + "start": self.start, + "end": self.end, + "confidence": self.confidence, + "field_name": self.field_name, + 
"row_index": self.row_index, + } + + def _redact(self, value: str, show_chars: int = 4) -> str: + """Partially redact the value for display.""" + if len(value) <= show_chars: + return "*" * len(value) + return value[:show_chars] + "*" * (len(value) - show_chars) + + +@dataclass +class PIIPattern: + """A pattern for detecting PII.""" + pii_type: PIIType + severity: PIISeverity + pattern: Pattern + confidence: float = 0.85 + validator: Optional[Callable[[str], bool]] = None # Additional validation + context_patterns: List[str] = field(default_factory=list) # Boost confidence if context matches + + +@dataclass +class PIIScanResult: + """Result of scanning content for PII.""" + total_matches: int = 0 + matches_by_type: Dict[str, int] = field(default_factory=dict) + matches_by_severity: Dict[str, int] = field(default_factory=dict) + matches_by_field: Dict[str, int] = field(default_factory=dict) + sample_matches: List[PIIMatch] = field(default_factory=list) # First N matches + fields_with_pii: Set[str] = field(default_factory=set) + high_risk_fields: Set[str] = field(default_factory=set) # Fields with CRITICAL/HIGH PII + + def to_dict(self) -> Dict[str, Any]: + return { + "total_matches": self.total_matches, + "matches_by_type": self.matches_by_type, + "matches_by_severity": self.matches_by_severity, + "matches_by_field": self.matches_by_field, + "fields_with_pii": list(self.fields_with_pii), + "high_risk_fields": list(self.high_risk_fields), + "sample_matches": [m.to_dict() for m in self.sample_matches[:10]], + } + + def has_critical_pii(self) -> bool: + """Check if any critical PII was found.""" + return self.matches_by_severity.get("critical", 0) > 0 + + def has_high_risk_pii(self) -> bool: + """Check if any high-risk PII was found.""" + return ( + self.matches_by_severity.get("critical", 0) > 0 or + self.matches_by_severity.get("high", 0) > 0 + ) + + @property + def summary(self) -> str: + """Human-readable summary.""" + if self.total_matches == 0: + return "No PII 
# ═══════════════════════════════════════════════════════════════════════════════
# VALIDATION FUNCTIONS
# ═══════════════════════════════════════════════════════════════════════════════

def validate_luhn(card_number: str) -> bool:
    """
    Validate a credit card number using the Luhn checksum algorithm.

    Used by Visa, MasterCard, American Express, etc. Non-digit characters
    (spaces, dashes) are stripped before validation.

    Args:
        card_number: Candidate card number, possibly formatted.

    Returns:
        True if the digit count is 13-19 and the Luhn checksum passes.
    """
    digits = [int(d) for d in re.sub(r'\D', '', card_number)]
    if len(digits) < 13 or len(digits) > 19:
        return False

    # Luhn: double every second digit from the right, subtract 9 when the
    # doubled value exceeds 9, then require the total to be divisible by 10.
    checksum = 0
    for i, digit in enumerate(reversed(digits)):
        if i % 2 == 1:
            digit *= 2
            if digit > 9:
                digit -= 9
        checksum += digit

    return checksum % 10 == 0


def validate_ssn(ssn: str) -> bool:
    """
    Validate US Social Security Number format.

    SSN format: AAA-BB-CCCC
    - AAA: Area number (001-899, excluding 666)
    - BB: Group number (01-99)
    - CCCC: Serial number (0001-9999)

    Args:
        ssn: Candidate SSN, possibly containing separators.

    Returns:
        True if the 9 digits form a structurally valid, non-reserved SSN.
    """
    clean = re.sub(r'\D', '', ssn)
    if len(clean) != 9:
        return False

    area = int(clean[:3])
    group = int(clean[3:5])
    serial = int(clean[5:])

    # Reserved / never-issued ranges.
    if area == 0 or area == 666 or area >= 900:
        return False
    if group == 0:
        return False
    if serial == 0:
        return False

    # Known invalid SSNs (advertising, testing)
    invalid_ssns = {
        "078051120",  # Woolworth promotional
        "219099999",  # Advertising
    }
    if clean in invalid_ssns:
        return False

    return True


def validate_iban(iban: str) -> bool:
    """
    Validate IBAN using the ISO 13616 MOD-97 checksum.

    Args:
        iban: Candidate IBAN, possibly containing whitespace.

    Returns:
        True if the IBAN has a plausible length, contains only letters and
        digits after whitespace removal, and its MOD-97 checksum equals 1.
    """
    clean = re.sub(r'\s', '', iban).upper()
    if len(clean) < 15 or len(clean) > 34:
        return False

    # Bug fix: a character outside [A-Z0-9] (e.g. a '-' separator) previously
    # produced a negative fragment like "-10" in the numeric string, making
    # int() raise an uncaught ValueError. Reject such input instead.
    if not clean.isalnum():
        return False

    # Move country code and check digits to end
    rearranged = clean[4:] + clean[:4]

    # Convert letters to numbers (A=10, B=11, etc.)
    numeric = ""
    for char in rearranged:
        if char.isdigit():
            numeric += char
        else:
            numeric += str(ord(char) - ord('A') + 10)

    # MOD 97 check
    return int(numeric) % 97 == 1
# Visa + | + 5[1-5][0-9]{14} # MasterCard + | + 3[47][0-9]{13} # American Express + | + 6(?:011|5[0-9]{2})[0-9]{12} # Discover + | + (?:2131|1800|35\d{3})\d{11} # JCB + )\b + | + \b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b # Spaced format + ''', + re.VERBOSE + ), + confidence=0.90, + validator=validate_luhn, + context_patterns=["card", "credit", "visa", "mastercard", "amex", "payment"], + ), + + # IP Address - IPv4 + PIIPattern( + pii_type=PIIType.IP_ADDRESS, + severity=PIISeverity.MEDIUM, + pattern=re.compile( + r'\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b' + ), + confidence=0.90, + context_patterns=["ip", "address", "server", "host", "client"], + ), + + # IP Address - IPv6 + PIIPattern( + pii_type=PIIType.IP_ADDRESS, + severity=PIISeverity.MEDIUM, + pattern=re.compile( + r'\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\b' + ), + confidence=0.90, + ), + + # MAC Address + PIIPattern( + pii_type=PIIType.MAC_ADDRESS, + severity=PIISeverity.LOW, + pattern=re.compile( + r'\b(?:[0-9A-Fa-f]{2}[:-]){5}[0-9A-Fa-f]{2}\b' + ), + confidence=0.95, + ), + + # IBAN - International Bank Account Number + PIIPattern( + pii_type=PIIType.IBAN, + severity=PIISeverity.CRITICAL, + pattern=re.compile( + r'\b[A-Z]{2}\d{2}[A-Z0-9]{4}\d{7}(?:[A-Z0-9]?){0,16}\b', + re.IGNORECASE + ), + confidence=0.85, + validator=validate_iban, + context_patterns=["iban", "bank", "account", "transfer"], + ), + + # API Key patterns + PIIPattern( + pii_type=PIIType.API_KEY, + severity=PIISeverity.CRITICAL, + pattern=re.compile( + r''' + (?: + sk[-_]live[-_][a-zA-Z0-9]{24,} # Stripe + | + sk[-_]test[-_][a-zA-Z0-9]{24,} # Stripe test + | + pk[-_]live[-_][a-zA-Z0-9]{24,} # Stripe public + | + ghp_[a-zA-Z0-9]{36} # GitHub PAT + | + gho_[a-zA-Z0-9]{36} # GitHub OAuth + | + github_pat_[a-zA-Z0-9]{22}_[a-zA-Z0-9]{59} # GitHub fine-grained + | + xox[baprs]-[a-zA-Z0-9-]{10,} # Slack + | + ya29\.[a-zA-Z0-9_-]+ # Google OAuth + ) + ''', + re.VERBOSE + ), + confidence=0.95, + 
context_patterns=["api", "key", "token", "secret", "auth"], + ), + + # AWS Access Key + PIIPattern( + pii_type=PIIType.AWS_KEY, + severity=PIISeverity.CRITICAL, + pattern=re.compile( + r'\b(?:AKIA|ABIA|ACCA|ASIA)[A-Z0-9]{16}\b' + ), + confidence=0.95, + context_patterns=["aws", "amazon", "key", "access"], + ), + + # Crypto Wallet - Bitcoin + PIIPattern( + pii_type=PIIType.CRYPTO_WALLET, + severity=PIISeverity.HIGH, + pattern=re.compile( + r'\b(?:bc1|[13])[a-zA-HJ-NP-Z0-9]{25,39}\b' + ), + confidence=0.80, + context_patterns=["bitcoin", "btc", "wallet", "crypto"], + ), + + # Crypto Wallet - Ethereum + PIIPattern( + pii_type=PIIType.CRYPTO_WALLET, + severity=PIISeverity.HIGH, + pattern=re.compile( + r'\b0x[a-fA-F0-9]{40}\b' + ), + confidence=0.80, + context_patterns=["ethereum", "eth", "wallet", "crypto"], + ), + + # GPS Coordinates + PIIPattern( + pii_type=PIIType.GPS_COORDINATES, + severity=PIISeverity.MEDIUM, + pattern=re.compile( + r'[-+]?(?:[1-8]?\d(?:\.\d+)?|90(?:\.0+)?)\s*,\s*[-+]?(?:180(?:\.0+)?|(?:(?:1[0-7]\d)|(?:[1-9]?\d))(?:\.\d+)?)' + ), + confidence=0.70, + context_patterns=["location", "coordinates", "lat", "lng", "gps"], + ), + + # Date of Birth patterns + PIIPattern( + pii_type=PIIType.DATE_OF_BIRTH, + severity=PIISeverity.MEDIUM, + pattern=re.compile( + r'\b(?:0?[1-9]|1[0-2])[/\-.](?:0?[1-9]|[12]\d|3[01])[/\-.](?:19|20)\d{2}\b' + ), + confidence=0.60, # Low base - needs context + context_patterns=["birth", "dob", "born", "birthday", "date of birth"], + ), + + # US ZIP Code + PIIPattern( + pii_type=PIIType.ZIPCODE, + severity=PIISeverity.LOW, + pattern=re.compile( + r'\b\d{5}(?:-\d{4})?\b' + ), + confidence=0.50, # Low - needs context + context_patterns=["zip", "postal", "address", "code"], + ), + + # URL (can contain sensitive info in path/query) + PIIPattern( + pii_type=PIIType.URL, + severity=PIISeverity.LOW, + pattern=re.compile( + r'https?://[^\s<>"{}|\\^`\[\]]+', + re.IGNORECASE + ), + confidence=0.70, + ), +] + + +class PIIScanner: + """ + 
Scanner for detecting PII in text and datasets. + + Uses regex patterns with optional validation and context boosting. + """ + + def __init__( + self, + patterns: List[PIIPattern] = None, + min_confidence: float = 0.5, + context_boost: float = 0.1, + ): + """ + Initialize scanner. + + Args: + patterns: Custom patterns (defaults to PII_PATTERNS) + min_confidence: Minimum confidence to report (0.0-1.0) + context_boost: Confidence boost when context matches + """ + self.patterns = patterns or PII_PATTERNS + self.min_confidence = min_confidence + self.context_boost = context_boost + + def scan_text( + self, + text: str, + field_name: str = "", + row_index: int = -1, + ) -> List[PIIMatch]: + """ + Scan text for PII. + + Args: + text: Text to scan + field_name: Optional field name for tracking + row_index: Optional row index for tracking + + Returns: + List of PIIMatch objects + """ + if not text or not isinstance(text, str): + return [] + + matches = [] + text_lower = text.lower() + + for pattern in self.patterns: + for match in pattern.pattern.finditer(text): + value = match.group() + confidence = pattern.confidence + + # Validate if validator provided + if pattern.validator: + if not pattern.validator(value): + continue + + # Context boost + if pattern.context_patterns: + for ctx in pattern.context_patterns: + if ctx in text_lower: + confidence = min(1.0, confidence + self.context_boost) + break + + # Apply minimum confidence filter + if confidence >= self.min_confidence: + # Get surrounding context (50 chars each side) + start = max(0, match.start() - 50) + end = min(len(text), match.end() + 50) + context = text[start:end] + + matches.append(PIIMatch( + pii_type=pattern.pii_type, + severity=pattern.severity, + value=value, + start=match.start(), + end=match.end(), + confidence=confidence, + context=context, + field_name=field_name, + row_index=row_index, + )) + + return matches + + def scan_dict( + self, + data: Dict[str, List[Any]], + sample_size: int = 1000, + ) -> 
PIIScanResult: + """ + Scan a columnar dict for PII. + + Args: + data: Dict of column_name -> values + sample_size: Max rows to scan per column + + Returns: + PIIScanResult with aggregated findings + """ + result = PIIScanResult() + + for field_name, values in data.items(): + if not values: + continue + + # Sample values + sample = values[:sample_size] + + for row_idx, value in enumerate(sample): + if not isinstance(value, str): + value = str(value) if value is not None else "" + + matches = self.scan_text(value, field_name, row_idx) + + for match in matches: + result.total_matches += 1 + + # Count by type + type_name = match.pii_type.value + result.matches_by_type[type_name] = result.matches_by_type.get(type_name, 0) + 1 + + # Count by severity + sev = match.severity.value + result.matches_by_severity[sev] = result.matches_by_severity.get(sev, 0) + 1 + + # Count by field + result.matches_by_field[field_name] = result.matches_by_field.get(field_name, 0) + 1 + + # Track fields + result.fields_with_pii.add(field_name) + if match.severity in [PIISeverity.CRITICAL, PIISeverity.HIGH]: + result.high_risk_fields.add(field_name) + + # Keep samples + if len(result.sample_matches) < 100: + result.sample_matches.append(match) + + return result + + def scan_dataset( + self, + dataset, + sample_size: int = 1000, + ) -> PIIScanResult: + """ + Scan a HuggingFace Dataset or DatasetDict for PII. 
+ + Args: + dataset: HuggingFace Dataset or DatasetDict + sample_size: Max rows to scan + + Returns: + PIIScanResult with aggregated findings + """ + # Handle DatasetDict (multiple splits) + if hasattr(dataset, 'keys') and callable(dataset.keys): + combined = PIIScanResult() + for split_name in dataset.keys(): + split_result = self.scan_dataset(dataset[split_name], sample_size) + # Merge results + combined.total_matches += split_result.total_matches + for k, v in split_result.matches_by_type.items(): + combined.matches_by_type[k] = combined.matches_by_type.get(k, 0) + v + for k, v in split_result.matches_by_severity.items(): + combined.matches_by_severity[k] = combined.matches_by_severity.get(k, 0) + v + for k, v in split_result.matches_by_field.items(): + combined.matches_by_field[k] = combined.matches_by_field.get(k, 0) + v + combined.fields_with_pii.update(split_result.fields_with_pii) + combined.high_risk_fields.update(split_result.high_risk_fields) + combined.sample_matches.extend(split_result.sample_matches[:20]) + return combined + + # Single Dataset + result = PIIScanResult() + + # Get column names + if hasattr(dataset, 'features'): + columns = list(dataset.features.keys()) + elif hasattr(dataset, 'column_names'): + columns = dataset.column_names + else: + return result + + # Sample rows + num_rows = len(dataset) if hasattr(dataset, '__len__') else sample_size + sample_indices = range(min(sample_size, num_rows)) + + for idx in sample_indices: + row = dataset[idx] + for col in columns: + value = row.get(col) if isinstance(row, dict) else getattr(row, col, None) + if not isinstance(value, str): + value = str(value) if value is not None else "" + + matches = self.scan_text(value, col, idx) + + for match in matches: + result.total_matches += 1 + + type_name = match.pii_type.value + result.matches_by_type[type_name] = result.matches_by_type.get(type_name, 0) + 1 + + sev = match.severity.value + result.matches_by_severity[sev] = 
# Module-level default scanner instance.
_scanner = PIIScanner()


def scan_for_pii(
    data,
    sample_size: int = 1000,
    min_confidence: float = 0.5,
) -> PIIScanResult:
    """
    Convenience function to scan data for PII.

    Args:
        data: Text, dict, or HuggingFace Dataset
        sample_size: Max rows to scan
        min_confidence: Minimum confidence threshold

    Returns:
        PIIScanResult with findings
    """
    scanner = PIIScanner(min_confidence=min_confidence)

    if isinstance(data, str):
        # Plain text: run the text scanner and aggregate counts inline.
        found = scanner.scan_text(data)
        result = PIIScanResult(
            total_matches=len(found),
            sample_matches=found,
        )
        for hit in found:
            type_key = hit.pii_type.value
            sev_key = hit.severity.value
            result.matches_by_type[type_key] = result.matches_by_type.get(type_key, 0) + 1
            result.matches_by_severity[sev_key] = result.matches_by_severity.get(sev_key, 0) + 1
        return result

    if isinstance(data, dict):
        return scanner.scan_dict(data, sample_size)

    # Assume HuggingFace Dataset
    return scanner.scan_dataset(data, sample_size)


def quick_pii_check(data, sample_size: int = 100) -> bool:
    """
    Quick check if data contains any PII.

    Returns True if PII is found, False otherwise.
    """
    findings = scan_for_pii(data, sample_size=sample_size, min_confidence=0.7)
    return findings.total_matches > 0
@dataclass
class ProvenanceNode:
    """A node in the provenance graph with hash chain."""
    node_id: str
    node_type: str  # entity, activity, agent
    data: Dict[str, Any]

    # Hash chain
    node_hash: str = ""
    parent_hashes: List[str] = field(default_factory=list)

    def __post_init__(self):
        # Keep a caller-supplied hash; otherwise derive one from content.
        self.node_hash = self.node_hash or self._compute_hash()

    def _compute_hash(self) -> str:
        """Compute hash including parent hashes (Merkle-style)."""
        payload = {
            "id": self.node_id,
            "type": self.node_type,
            "data": self.data,
            "parents": sorted(self.parent_hashes),
        }
        serialized = json.dumps(payload, sort_keys=True, default=str)
        return hashlib.sha256(serialized.encode()).hexdigest()
+ - Hash chain for integrity verification + - Export to PROV-O and Croissant formats + """ + + def __init__(self, name: str = "default"): + self.name = name + self.created_at = time.time() + + # Storage + self._entities: Dict[str, DatasetEntity] = {} + self._activities: Dict[str, Activity] = {} + self._agents: Dict[str, Agent] = {} + self._relationships: List[Relationship] = [] + + # Hash chain + self._nodes: Dict[str, ProvenanceNode] = {} + self._root_hash: Optional[str] = None + + # Default system agent + self._system_agent = create_system_agent("cascade-data-observatory") + self.add_agent(self._system_agent) + + # ═══════════════════════════════════════════════════════════════════════════ + # ENTITY MANAGEMENT + # ═══════════════════════════════════════════════════════════════════════════ + + def add_entity(self, entity: DatasetEntity) -> str: + """Add a dataset entity to the graph.""" + self._entities[entity.id] = entity + + # Create provenance node + node = ProvenanceNode( + node_id=entity.id, + node_type="entity", + data=entity.to_dict(), + ) + self._nodes[entity.id] = node + self._update_root_hash() + + return entity.id + + def get_entity(self, entity_id: str) -> Optional[DatasetEntity]: + """Get entity by ID.""" + return self._entities.get(entity_id) + + def list_entities(self) -> List[DatasetEntity]: + """List all entities.""" + return list(self._entities.values()) + + # ═══════════════════════════════════════════════════════════════════════════ + # ACTIVITY MANAGEMENT + # ═══════════════════════════════════════════════════════════════════════════ + + def add_activity(self, activity: Activity) -> str: + """Add an activity to the graph.""" + self._activities[activity.id] = activity + + # Link to agent + if not activity.agent_id: + activity.agent_id = self._system_agent.id + + # Create provenance node with parent hashes from inputs + parent_hashes = [] + for input_id in activity.inputs: + if input_id in self._nodes: + 
parent_hashes.append(self._nodes[input_id].node_hash) + + node = ProvenanceNode( + node_id=activity.id, + node_type="activity", + data=activity.to_dict(), + parent_hashes=parent_hashes, + ) + self._nodes[activity.id] = node + self._update_root_hash() + + return activity.id + + def get_activity(self, activity_id: str) -> Optional[Activity]: + """Get activity by ID.""" + return self._activities.get(activity_id) + + def list_activities(self) -> List[Activity]: + """List all activities.""" + return list(self._activities.values()) + + # ═══════════════════════════════════════════════════════════════════════════ + # AGENT MANAGEMENT + # ═══════════════════════════════════════════════════════════════════════════ + + def add_agent(self, agent: Agent) -> str: + """Add an agent to the graph.""" + self._agents[agent.id] = agent + + node = ProvenanceNode( + node_id=agent.id, + node_type="agent", + data=agent.to_dict(), + ) + self._nodes[agent.id] = node + + return agent.id + + def get_agent(self, agent_id: str) -> Optional[Agent]: + """Get agent by ID.""" + return self._agents.get(agent_id) + + def list_agents(self) -> List[Agent]: + """List all agents.""" + return list(self._agents.values()) + + def list_relationships(self) -> List[Relationship]: + """List all relationships.""" + return list(self._relationships) + + # ═══════════════════════════════════════════════════════════════════════════ + # RELATIONSHIP MANAGEMENT + # ═══════════════════════════════════════════════════════════════════════════ + + def add_relationship( + self, + relation_type: RelationType, + source_id: str, + target_id: str, + attributes: Dict[str, Any] = None, + timestamp: float = None, + ) -> Relationship: + """Add a relationship between nodes.""" + rel = Relationship( + relation_type=relation_type, + source_id=source_id, + target_id=target_id, + timestamp=timestamp if timestamp is not None else time.time(), + attributes=attributes or {}, + ) + self._relationships.append(rel) + return rel + + def 
link_derivation(self, derived_id: str, source_id: str) -> Relationship: + """Record that derived entity came from source entity.""" + return self.add_relationship( + RelationType.WAS_DERIVED_FROM, + source_id=derived_id, + target_id=source_id, + ) + + def link_generation(self, entity_id: str, activity_id: str) -> Relationship: + """Record that entity was generated by activity.""" + return self.add_relationship( + RelationType.WAS_GENERATED_BY, + source_id=entity_id, + target_id=activity_id, + ) + + def link_usage(self, activity_id: str, entity_id: str) -> Relationship: + """Record that activity used entity as input.""" + return self.add_relationship( + RelationType.USED, + source_id=activity_id, + target_id=entity_id, + ) + + def link_attribution(self, entity_id: str, agent_id: str) -> Relationship: + """Record that entity was attributed to agent.""" + return self.add_relationship( + RelationType.WAS_ATTRIBUTED_TO, + source_id=entity_id, + target_id=agent_id, + ) + + def link_association(self, activity_id: str, agent_id: str) -> Relationship: + """Record that activity was associated with agent.""" + return self.add_relationship( + RelationType.WAS_ASSOCIATED_WITH, + source_id=activity_id, + target_id=agent_id, + ) + + # ═══════════════════════════════════════════════════════════════════════════ + # LINEAGE QUERIES + # ═══════════════════════════════════════════════════════════════════════════ + + def get_lineage(self, entity_id: str, direction: str = "upstream") -> List[str]: + """ + Get lineage for an entity. 
+ + Args: + entity_id: The entity to trace + direction: "upstream" (what produced this) or "downstream" (what this produced) + + Returns: + List of entity IDs in lineage order + """ + visited: Set[str] = set() + lineage: List[str] = [] + + def trace(current_id: str): + if current_id in visited: + return + visited.add(current_id) + + for rel in self._relationships: + if direction == "upstream": + # Follow wasDerivedFrom backwards + if rel.relation_type == RelationType.WAS_DERIVED_FROM: + if rel.source_id == current_id: + lineage.append(rel.target_id) + trace(rel.target_id) + else: + # Follow wasDerivedFrom forwards + if rel.relation_type == RelationType.WAS_DERIVED_FROM: + if rel.target_id == current_id: + lineage.append(rel.source_id) + trace(rel.source_id) + + trace(entity_id) + return lineage + + def get_activities_for_entity(self, entity_id: str) -> List[Activity]: + """Get activities that generated or used this entity.""" + activity_ids = set() + + for rel in self._relationships: + if rel.relation_type == RelationType.WAS_GENERATED_BY: + if rel.source_id == entity_id: + activity_ids.add(rel.target_id) + elif rel.relation_type == RelationType.USED: + if rel.target_id == entity_id: + activity_ids.add(rel.source_id) + + return [self._activities[aid] for aid in activity_ids if aid in self._activities] + + def get_inputs_for_activity(self, activity_id: str) -> List[DatasetEntity]: + """Get entities that were inputs to an activity.""" + entity_ids = set() + + for rel in self._relationships: + if rel.relation_type == RelationType.USED: + if rel.source_id == activity_id: + entity_ids.add(rel.target_id) + + return [self._entities[eid] for eid in entity_ids if eid in self._entities] + + def get_outputs_for_activity(self, activity_id: str) -> List[DatasetEntity]: + """Get entities that were outputs of an activity.""" + entity_ids = set() + + for rel in self._relationships: + if rel.relation_type == RelationType.WAS_GENERATED_BY: + if rel.target_id == activity_id: + 
entity_ids.add(rel.source_id) + + return [self._entities[eid] for eid in entity_ids if eid in self._entities] + + # ═══════════════════════════════════════════════════════════════════════════ + # HASH CHAIN + # ═══════════════════════════════════════════════════════════════════════════ + + def _update_root_hash(self): + """Update the Merkle root hash.""" + if not self._nodes: + self._root_hash = None + return + + # Compute root from all node hashes + all_hashes = sorted([n.node_hash for n in self._nodes.values()]) + combined = "".join(all_hashes) + self._root_hash = hashlib.sha256(combined.encode()).hexdigest() + + @property + def root_hash(self) -> Optional[str]: + """Get the current Merkle root hash.""" + return self._root_hash + + def verify_integrity(self) -> Tuple[bool, List[str]]: + """ + Verify integrity of the provenance graph. + + Returns: + (is_valid, list of invalid node IDs) + """ + invalid = [] + + for node_id, node in self._nodes.items(): + expected_hash = node._compute_hash() + if expected_hash != node.node_hash: + invalid.append(node_id) + + return len(invalid) == 0, invalid + + # ═══════════════════════════════════════════════════════════════════════════ + # EXPORT + # ═══════════════════════════════════════════════════════════════════════════ + + def to_dict(self) -> Dict[str, Any]: + """Export graph to dictionary.""" + return { + "name": self.name, + "created_at": self.created_at, + "root_hash": self._root_hash, + "entities": {k: v.to_dict() for k, v in self._entities.items()}, + "activities": {k: v.to_dict() for k, v in self._activities.items()}, + "agents": {k: v.to_dict() for k, v in self._agents.items()}, + "relationships": [r.to_dict() for r in self._relationships], + } + + def to_prov_n(self) -> str: + """Export as PROV-N notation.""" + lines = [ + f"document", + f" prefix cascade ", + f" prefix prov ", + f"", + ] + + # Entities + for entity in self._entities.values(): + lines.append(f" {entity.to_prov_n()}") + + lines.append("") + + # 
Activities + for activity in self._activities.values(): + lines.append(f" {activity.to_prov_n()}") + + lines.append("") + + # Agents + for agent in self._agents.values(): + lines.append(f" {agent.to_prov_n()}") + + lines.append("") + + # Relationships + for rel in self._relationships: + lines.append(f" {rel.to_prov_n()}") + + lines.append("") + lines.append("endDocument") + + return "\n".join(lines) + + def to_prov_jsonld(self) -> Dict[str, Any]: + """Export as PROV-O JSON-LD.""" + return { + "@context": { + "prov": "http://www.w3.org/ns/prov#", + "cascade": "https://cascade.ai/ns/", + "xsd": "http://www.w3.org/2001/XMLSchema#", + }, + "@graph": [ + *[e.to_dict() for e in self._entities.values()], + *[a.to_dict() for a in self._activities.values()], + *[a.to_dict() for a in self._agents.values()], + ], + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ProvenanceGraph": + """Load graph from dictionary.""" + graph = cls(name=data.get("name", "default")) + graph.created_at = data.get("created_at", time.time()) + + # Load entities + for entity_data in data.get("entities", {}).values(): + entity = DatasetEntity( + id=entity_data["@id"], + name=entity_data["name"], + content_hash=entity_data.get("content_hash"), + schema_hash=entity_data.get("schema_hash"), + version=entity_data.get("version"), + previous_version=entity_data.get("previous_version"), + source_type=entity_data.get("source_type", "unknown"), + source_uri=entity_data.get("source_uri"), + record_count=entity_data.get("record_count"), + size_bytes=entity_data.get("size_bytes"), + splits=entity_data.get("splits", {}), + attributes=entity_data.get("attributes", {}), + created_at=entity_data.get("created_at", time.time()), + ) + graph.add_entity(entity) + + # Load activities + for activity_data in data.get("activities", {}).values(): + activity = Activity( + id=activity_data["@id"], + activity_type=ActivityType(activity_data["activity_type"]), + name=activity_data["name"], + 
started_at=activity_data.get("started_at"), + ended_at=activity_data.get("ended_at"), + inputs=activity_data.get("inputs", []), + outputs=activity_data.get("outputs", []), + agent_id=activity_data.get("agent_id"), + parameters=activity_data.get("parameters", {}), + attributes=activity_data.get("attributes", {}), + ) + graph.add_activity(activity) + + # Load agents + for agent_data in data.get("agents", {}).values(): + agent = Agent( + id=agent_data["@id"], + agent_type=AgentType(agent_data["agent_type"]), + name=agent_data["name"], + version=agent_data.get("version"), + parent_agent_id=agent_data.get("parent_agent_id"), + identifier=agent_data.get("identifier"), + attributes=agent_data.get("attributes", {}), + created_at=agent_data.get("created_at", time.time()), + ) + graph.add_agent(agent) + + # Load relationships + for rel_data in data.get("relationships", []): + graph.add_relationship( + relation_type=RelationType(rel_data["type"]), + source_id=rel_data["source"], + target_id=rel_data["target"], + attributes=rel_data.get("attributes", {}), + timestamp=rel_data.get("timestamp"), + ) + + return graph + + # ═══════════════════════════════════════════════════════════════════════════ + # STATISTICS + # ═══════════════════════════════════════════════════════════════════════════ + + @property + def stats(self) -> Dict[str, int]: + """Get graph statistics.""" + return { + "entities": len(self._entities), + "activities": len(self._activities), + "agents": len(self._agents), + "relationships": len(self._relationships), + } + + def __repr__(self) -> str: + stats = self.stats + return ( + f"ProvenanceGraph(name='{self.name}', " + f"entities={stats['entities']}, " + f"activities={stats['activities']}, " + f"relationships={stats['relationships']})" + ) diff --git a/cascade/data/schema.py b/cascade/data/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..1994683a427126e69f13fc910986dff3ad6fcd56 --- /dev/null +++ b/cascade/data/schema.py @@ -0,0 
"""
Schema Observer

Observes and hashes dataset schemas/features.
Works with HuggingFace datasets Features, Pandas DataFrames, and raw dicts.
"""

import hashlib
import json
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union


@dataclass
class FieldSchema:
    """Schema for a single field/column."""
    name: str
    dtype: str  # Normalized type name (see SchemaObserver.TYPE_MAP)

    # Type details
    nullable: bool = True
    is_list: bool = False
    list_inner_type: Optional[str] = None

    # For ClassLabel / categorical columns
    is_categorical: bool = False
    categories: Optional[List[str]] = None
    num_categories: Optional[int] = None

    # For nested structures (struct-like features)
    nested_fields: Optional[Dict[str, "FieldSchema"]] = None

    # For arrays/tensors
    shape: Optional[tuple] = None

    # Constraints (currently informational only)
    min_value: Optional[float] = None
    max_value: Optional[float] = None
    pattern: Optional[str] = None  # Regex for strings

    # Metadata
    description: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a dict; optional facets are only emitted when set."""
        result = {
            "name": self.name,
            "dtype": self.dtype,
            "nullable": self.nullable,
        }
        if self.is_list:
            result["is_list"] = True
            result["list_inner_type"] = self.list_inner_type
        if self.is_categorical:
            result["is_categorical"] = True
            result["categories"] = self.categories
            result["num_categories"] = self.num_categories
        if self.nested_fields:
            result["nested_fields"] = {
                k: v.to_dict() for k, v in self.nested_fields.items()
            }
        if self.shape:
            result["shape"] = self.shape
        if self.description:
            result["description"] = self.description
        return result

    def hash(self) -> str:
        """Hash this field's structure (16-hex-char truncated SHA-256)."""
        content = json.dumps(self.to_dict(), sort_keys=True)
        return hashlib.sha256(content.encode()).hexdigest()[:16]


@dataclass
class DatasetSchema:
    """Complete schema for a dataset."""
    fields: Dict[str, FieldSchema] = field(default_factory=dict)

    # Dataset-level metadata
    primary_key: Optional[List[str]] = None
    foreign_keys: Dict[str, str] = field(default_factory=dict)  # field → target

    # Source info
    source_format: Optional[str] = None  # arrow, parquet, csv, pandas, dict, ...

    def add_field(self, field_schema: FieldSchema):
        """Add (or replace) a field in the schema, keyed by name."""
        self.fields[field_schema.name] = field_schema

    @property
    def field_names(self) -> List[str]:
        return list(self.fields.keys())

    @property
    def num_fields(self) -> int:
        return len(self.fields)

    def to_dict(self) -> Dict[str, Any]:
        return {
            "fields": {k: v.to_dict() for k, v in self.fields.items()},
            "primary_key": self.primary_key,
            "foreign_keys": self.foreign_keys,
            "source_format": self.source_format,
        }

    def hash(self) -> str:
        """Compute schema hash - identifies structure regardless of content.

        Fields are sorted by name so the hash is insertion-order independent.
        """
        ordered_fields = sorted(self.fields.keys())
        content = json.dumps({
            "fields": [self.fields[k].to_dict() for k in ordered_fields],
            "primary_key": self.primary_key,
        }, sort_keys=True)
        return hashlib.sha256(content.encode()).hexdigest()

    def diff(self, other: "DatasetSchema") -> Dict[str, Any]:
        """Compare this schema against `other` (self → old, other → new).

        "compatible" means purely additive: nothing removed or modified;
        added fields are allowed.
        """
        added = set(other.field_names) - set(self.field_names)
        removed = set(self.field_names) - set(other.field_names)

        modified = {}
        for name in set(self.field_names) & set(other.field_names):
            if self.fields[name].hash() != other.fields[name].hash():
                modified[name] = {
                    "old": self.fields[name].to_dict(),
                    "new": other.fields[name].to_dict(),
                }

        return {
            "added": list(added),
            "removed": list(removed),
            "modified": modified,
            "compatible": len(removed) == 0 and len(modified) == 0,
        }


class SchemaObserver:
    """
    Observes and extracts schemas from various data sources.
    """

    # Type mapping from various sources to normalized types
    TYPE_MAP = {
        # Python types
        "str": "string",
        "int": "int64",
        "float": "float64",
        "bool": "bool",
        "bytes": "binary",

        # NumPy types
        "int8": "int8",
        "int16": "int16",
        "int32": "int32",
        "int64": "int64",
        "uint8": "uint8",
        "uint16": "uint16",
        "uint32": "uint32",
        "uint64": "uint64",
        "float16": "float16",
        "float32": "float32",
        "float64": "float64",

        # Arrow types
        "string": "string",
        "large_string": "string",
        "binary": "binary",
        "large_binary": "binary",

        # HuggingFace special types
        "Image": "image",
        "Audio": "audio",
        "ClassLabel": "categorical",
    }

    def observe_hf_dataset(self, dataset) -> DatasetSchema:
        """
        Extract schema from HuggingFace Dataset.

        Args:
            dataset: A HuggingFace datasets.Dataset or DatasetDict

        Returns:
            DatasetSchema with all fields

        Raises:
            ValueError: If `dataset` exposes neither `features` nor splits.
        """
        schema = DatasetSchema(source_format="arrow")

        # Get features (works for both Dataset and DatasetDict)
        if hasattr(dataset, 'features'):
            features = dataset.features
        elif hasattr(dataset, '__iter__'):
            # DatasetDict - get features from first split
            first_split = next(iter(dataset.values()))
            features = first_split.features
        else:
            raise ValueError(f"Cannot extract features from {type(dataset)}")

        for name, feature in features.items():
            schema.add_field(self._parse_hf_feature(name, feature))

        return schema

    def _parse_hf_feature(self, name: str, feature) -> FieldSchema:
        """Parse a HuggingFace Feature into FieldSchema.

        Falls back to dtype "unknown" when the `datasets` package is not
        installed, and to the feature's class name for unrecognized types.
        """
        # Import here to avoid hard dependency on `datasets`.
        try:
            from datasets import (
                Value, ClassLabel, Sequence,
                Array2D, Array3D, Array4D, Array5D,
                Image, Audio
            )
        except ImportError:
            return FieldSchema(name=name, dtype="unknown")

        # Value type (primitives)
        if isinstance(feature, Value):
            return FieldSchema(
                name=name,
                dtype=self.TYPE_MAP.get(feature.dtype, feature.dtype),
            )

        # ClassLabel (categorical)
        if isinstance(feature, ClassLabel):
            return FieldSchema(
                name=name,
                dtype="categorical",
                is_categorical=True,
                categories=feature.names,
                num_categories=feature.num_classes,
            )

        # Sequence (list) — recurse to type the inner element
        if isinstance(feature, Sequence):
            inner = self._parse_hf_feature(f"{name}_inner", feature.feature)
            return FieldSchema(
                name=name,
                dtype="list",
                is_list=True,
                list_inner_type=inner.dtype,
            )

        # Fixed-shape arrays
        if isinstance(feature, (Array2D, Array3D, Array4D, Array5D)):
            return FieldSchema(
                name=name,
                dtype=self.TYPE_MAP.get(feature.dtype, feature.dtype),
                shape=feature.shape,
            )

        if isinstance(feature, Image):
            return FieldSchema(name=name, dtype="image")

        if isinstance(feature, Audio):
            return FieldSchema(name=name, dtype="audio")

        # Dict/nested structure
        if isinstance(feature, dict):
            nested = {
                k: self._parse_hf_feature(k, v) for k, v in feature.items()
            }
            return FieldSchema(
                name=name,
                dtype="struct",
                nested_fields=nested,
            )

        # Fallback: record the feature's class name as its dtype.
        return FieldSchema(
            name=name,
            dtype=str(type(feature).__name__),
        )

    def observe_pandas(self, df) -> DatasetSchema:
        """
        Extract schema from Pandas DataFrame.

        Args:
            df: A pandas DataFrame

        Returns:
            DatasetSchema with all fields
        """
        schema = DatasetSchema(source_format="pandas")

        for col in df.columns:
            dtype = str(df[col].dtype)
            normalized = self.TYPE_MAP.get(dtype, dtype)

            if dtype == "category":
                schema.add_field(FieldSchema(
                    name=col,
                    dtype="categorical",
                    is_categorical=True,
                    categories=list(df[col].cat.categories),
                    num_categories=len(df[col].cat.categories),
                ))
            else:
                schema.add_field(FieldSchema(
                    name=col,
                    dtype=normalized,
                    nullable=df[col].isna().any(),
                ))

        return schema

    def observe_dict(self, data: Dict[str, Any], sample_size: int = 100) -> DatasetSchema:
        """
        Extract schema from a dict of lists (columnar format).

        Args:
            data: Dict mapping column names to lists of values
            sample_size: Number of values to sample for type inference

        Returns:
            DatasetSchema with all fields; columns with mixed value types
            get dtype "mixed", empty columns get "unknown".
        """
        schema = DatasetSchema(source_format="dict")

        for col, values in data.items():
            if not values:
                schema.add_field(FieldSchema(name=col, dtype="unknown"))
                continue

            # Sample values for type inference; None is treated as a
            # nullability marker, not a type.
            sample = values[:sample_size]
            types = set(type(v).__name__ for v in sample if v is not None)

            if len(types) == 0:
                dtype = "null"
            elif len(types) == 1:
                dtype = self.TYPE_MAP.get(types.pop(), "unknown")
            else:
                dtype = "mixed"

            nullable = any(v is None for v in sample)

            schema.add_field(FieldSchema(
                name=col,
                dtype=dtype,
                nullable=nullable,
            ))

        return schema

    def observe_arrow(self, table) -> DatasetSchema:
        """
        Extract schema from PyArrow Table.

        Args:
            table: A pyarrow.Table

        Returns:
            DatasetSchema with all fields
        """
        schema = DatasetSchema(source_format="arrow")

        for arrow_field in table.schema:
            dtype = str(arrow_field.type)
            normalized = self.TYPE_MAP.get(dtype, dtype)

            schema.add_field(FieldSchema(
                name=arrow_field.name,
                dtype=normalized,
                nullable=arrow_field.nullable,
            ))

        return schema


def hash_content(data, sample_size: int = 10000) -> str:
    """
    Compute a deterministic content hash of a dataset.

    For large datasets, rows are sampled for efficiency.  FIX: the original
    sampled with an unseeded `random.sample`, so the "content hash" of the
    same data differed on every call — useless for provenance.  Sampling is
    now a deterministic even stride over the index range.

    Args:
        data: dict (hashed whole), list (first `sample_size` items), or any
            object with `__len__`/`__getitem__` (strided sample).
        sample_size: Max number of rows folded into the hash.

    Returns:
        Hex SHA-256 digest.  Unsupported types hash to the empty digest,
        matching the original fall-through behavior.
    """
    hasher = hashlib.sha256()

    # Handle dict first (dict also has __iter__ and __len__)
    if isinstance(data, dict):
        content = json.dumps(data, sort_keys=True, default=str)
        hasher.update(content.encode())

    # Handle list: deterministic prefix sample
    elif isinstance(data, list):
        for item in data[:sample_size]:
            item_str = json.dumps(item, sort_keys=True, default=str)
            hasher.update(item_str.encode())

    # Handle HuggingFace Dataset or other indexables with __len__
    elif hasattr(data, '__iter__') and hasattr(data, '__len__'):
        n = len(data)
        if n > sample_size:
            # Evenly-spaced deterministic sample (replaces random.sample).
            stride = max(1, n // sample_size)
            sample = [data[i] for i in range(0, n, stride)][:sample_size]
        else:
            sample = list(data)

        for row in sample:
            row_str = json.dumps(row, sort_keys=True, default=str)
            hasher.update(row_str.encode())

    return hasher.hexdigest()
[ESC] Release hold, return to AI sovereignty + + In HOLD modes: + [W] Main Engine (thrust up) + [A] Left Engine (rotate) + [D] Right Engine (rotate) + [S] No-op / Accept AI decision +""" + +import sys +import subprocess +from pathlib import Path + + +def check_demo_dependencies(): + """Check if demo dependencies are installed.""" + missing = [] + + try: + import gymnasium + except ImportError: + missing.append("gymnasium") + + try: + import pygame + except ImportError: + missing.append("pygame") + + try: + import stable_baselines3 + except ImportError: + missing.append("stable-baselines3") + + try: + import box2d + except ImportError: + missing.append("box2d-py") + + return missing + + +def main(): + """Launch the interactive CASCADE-LATTICE demo.""" + print(""" +╔═══════════════════════════════════════════════════════════════════════════════╗ +║ ║ +║ ██████╗ █████╗ ███████╗ ██████╗ █████╗ ██████╗ ███████╗ ║ +║ ██╔════╝██╔══██╗██╔════╝██╔════╝██╔══██╗██╔══██╗██╔════╝ ║ +║ ██║ ███████║███████╗██║ ███████║██║ ██║█████╗ ║ +║ ██║ ██╔══██║╚════██║██║ ██╔══██║██║ ██║██╔══╝ ║ +║ ╚██████╗██║ ██║███████║╚██████╗██║ ██║██████╔╝███████╗ ║ +║ ╚═════╝╚═╝ ╚═╝╚══════╝ ╚═════╝╚═╝ ╚═╝╚═════╝ ╚══════╝ ║ +║ ║ +║ LATTICE DEMO - Sovereign Neural Internetwork Control ║ +║ ║ +╚═══════════════════════════════════════════════════════════════════════════════╝ + """) + + # Check dependencies + missing = check_demo_dependencies() + if missing: + print(f"[!] Missing demo dependencies: {', '.join(missing)}") + print() + print(" Install with:") + print(" pip install cascade-lattice[demo]") + print() + print(" Or manually:") + print(f" pip install {' '.join(missing)}") + sys.exit(1) + + # Check for rl-zoo3 (needed for model download) + try: + import rl_zoo3 + except ImportError: + print("[!] 
Missing rl-zoo3 (needed for pretrained model)") + print(" pip install rl-zoo3") + sys.exit(1) + + print("[CASCADE] Starting LunarLander demo...") + print() + print("Controls:") + print(" [H] HOLD-FREEZE - Pause time, inspect AI decision") + print(" [T] HOLD-TAKEOVER - Continue time, YOU control with WASD") + print(" [ESC] Release hold / Quit") + print() + print("In HOLD modes:") + print(" [W] Main Engine [A] Left Engine [D] Right Engine") + print(" [S] Accept AI choice / No-op") + print() + + # Run the demo + demo_path = Path(__file__).parent.parent / "examples" / "sovereign_lattice_eval.py" + + if not demo_path.exists(): + # Try installed package location + import cascade + package_dir = Path(cascade.__file__).parent + demo_path = package_dir.parent / "examples" / "sovereign_lattice_eval.py" + + if not demo_path.exists(): + # Fallback: run inline demo + print("[!] Demo file not found. Running inline version...") + _run_inline_demo() + return + + # Run the demo script + subprocess.run([sys.executable, str(demo_path)]) + + +def _run_inline_demo(): + """Minimal inline demo if main file not found.""" + import gymnasium as gym + import numpy as np + + from cascade import init + from cascade.hold import Hold + from cascade.store import observe + + init(project="cascade_demo") + hold = Hold.get() + + print("[CASCADE] Running minimal demo (install full package for GUI)") + print() + + env = gym.make("LunarLander-v3") + obs, _ = env.reset() + + for step in range(100): + # Random policy for minimal demo + action_probs = np.array([0.25, 0.25, 0.25, 0.25]) + + resolution = hold.yield_point( + action_probs=action_probs, + value=0.0, + observation={"state": obs.tolist()[:4]}, + brain_id="random_demo", + action_labels=["NOOP", "LEFT", "MAIN", "RIGHT"], + blocking=False + ) + + obs, reward, term, trunc, _ = env.step(resolution.action) + + observe("demo", { + "step": step, + "action": int(resolution.action), + "reward": float(reward), + "merkle": resolution.merkle_root, + }, 
sync=False) + + if term or trunc: + print(f"[CASCADE] Episode ended at step {step}") + break + + env.close() + print("[CASCADE] Demo complete. Check ~/.cascade/lattice for provenance data.") + + +if __name__ == "__main__": + main() diff --git a/cascade/demo_sdk.py b/cascade/demo_sdk.py new file mode 100644 index 0000000000000000000000000000000000000000..509441240e1cbc748a8637226d1cc187a8ed6099 --- /dev/null +++ b/cascade/demo_sdk.py @@ -0,0 +1,114 @@ +""" +CASCADE SDK Demo - Shows automatic observation of calls. + +Run: python -m cascade.demo_sdk +""" + +import os +import sys + +# Add cascade to path if needed +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +def demo_manual_observation(): + """Demo manual observation without any provider installed.""" + print("=" * 60) + print("CASCADE SDK Demo - Manual Observation") + print("=" * 60) + + import cascade + from cascade.sdk import CascadeSDK + + # Initialize with verbose mode + sdk = CascadeSDK() + sdk.init(emit_async=False, verbose=True) + + print("\n[1] Simulating an OpenAI call...") + sdk.observe( + model_id="openai/gpt-4", + input_data="What is the capital of France?", + output_data="The capital of France is Paris.", + metrics={"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18}, + context={"provider": "openai", "endpoint": "chat.completions"} + ) + + print("\n[2] Simulating an Anthropic call...") + sdk.observe( + model_id="anthropic/claude-3-opus-20240229", + input_data="Explain quantum entanglement simply.", + output_data="Quantum entanglement is when two particles become connected...", + metrics={"input_tokens": 6, "output_tokens": 45}, + context={"provider": "anthropic", "endpoint": "messages"} + ) + + print("\n[3] Simulating an Ollama local call...") + sdk.observe( + model_id="ollama/llama2:7b", + input_data="Write a haiku about coding.", + output_data="Fingers on keyboard\nLogic flows like mountain stream\nBugs become features", + metrics={"eval_count": 20, 
"eval_duration": 1.5}, + context={"provider": "ollama", "endpoint": "generate"} + ) + + print("\n" + "=" * 60) + print("Observations saved to lattice/observations/") + print("=" * 60) + + # Show what was saved + from cascade.observation import ObservationManager + manager = ObservationManager() + stats = manager.get_stats() + print(f"\nTotal observations: {stats['total_observations']}") + print(f"Model observations: {stats['model_observations']}") + print(f"Unique models: {stats['unique_models']}") + + +def demo_auto_patch(): + """Demo auto-patching (requires providers to be installed).""" + print("\n" + "=" * 60) + print("CASCADE Auto-Patch Demo") + print("=" * 60) + + import cascade + + # This patches all installed providers + cascade.init(verbose=True) + + print("\nPatched providers. Now any call will emit receipts.") + print("Example usage:") + print(""" + import cascade + cascade.init() + + # OpenAI (if installed) + import openai + client = openai.OpenAI() + response = client.chat.completions.create( + model="gpt-4", + messages=[{"role": "user", "content": "Hello!"}] + ) + # ^^^ Receipt automatically emitted to lattice + + # Anthropic (if installed) + import anthropic + client = anthropic.Anthropic() + response = client.messages.create( + model="claude-3-opus-20240229", + max_tokens=100, + messages=[{"role": "user", "content": "Hello!"}] + ) + # ^^^ Receipt automatically emitted to lattice + + # Ollama (if installed) + import ollama + response = ollama.chat(model="llama2", messages=[ + {"role": "user", "content": "Hello!"} + ]) + # ^^^ Receipt automatically emitted to lattice + """) + + +if __name__ == "__main__": + demo_manual_observation() + demo_auto_patch() diff --git a/cascade/export/__init__.py b/cascade/export/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fd5accd8b4d17991c54b6d0a7d64b81949e2b9e3 --- /dev/null +++ b/cascade/export/__init__.py @@ -0,0 +1,23 @@ +""" +CASCADE Export Module - Tableau and BI Integration +""" 
+ +from .tableau_export import ( + export_for_tableau, + export_events_csv, + export_chains_csv, + export_metrics_csv, + export_hold_events_csv, + export_causation_graph_csv, + TableauExporter, +) + +__all__ = [ + "export_for_tableau", + "export_events_csv", + "export_chains_csv", + "export_metrics_csv", + "export_hold_events_csv", + "export_causation_graph_csv", + "TableauExporter", +] diff --git a/cascade/export/tableau_export.py b/cascade/export/tableau_export.py new file mode 100644 index 0000000000000000000000000000000000000000..c3d95fb4eb09b708b44aeb818b4eef32fbd09820 --- /dev/null +++ b/cascade/export/tableau_export.py @@ -0,0 +1,598 @@ +""" +CASCADE → Tableau Export Pipeline + +Exports Cascade data in Tableau-friendly formats: +- CSV files (universal) +- Hyper files (native Tableau, optional) + +Usage: + from cascade.export import export_for_tableau + + # Export all data to a directory + export_for_tableau("./tableau_data") + + # Then in Tableau: Connect → Text File → select CSVs +""" + +import csv +import json +import os +from pathlib import Path +from datetime import datetime +from typing import Dict, List, Any, Optional +from dataclasses import dataclass, asdict + +# Try to import Hyper API (optional) +try: + from tableauhyperapi import ( + HyperProcess, Telemetry, Connection, CreateMode, + TableDefinition, SqlType, TableName, Inserter + ) + HAS_HYPER = True +except ImportError: + HAS_HYPER = False + + +@dataclass +class EventRow: + """Flattened event for Tableau.""" + event_id: str + timestamp: float + timestamp_iso: str + component: str + event_type: str + data_json: str + # Extracted common fields + loss: Optional[float] = None + accuracy: Optional[float] = None + learning_rate: Optional[float] = None + epoch: Optional[int] = None + step: Optional[int] = None + tokens: Optional[int] = None + latency_ms: Optional[float] = None + error_message: Optional[str] = None + + +@dataclass +class ChainRow: + """Flattened provenance chain for Tableau.""" + 
session_id: str + model_id: str + model_hash: str + input_hash: str + output_hash: Optional[str] + merkle_root: str + created_at: float + created_at_iso: str + record_count: int + external_links_count: int + is_verified: bool + + +@dataclass +class HoldEventRow: + """Flattened HOLD event for Tableau.""" + hold_id: str + timestamp: float + timestamp_iso: str + brain_id: str + state: str # PENDING, ACCEPTED, OVERRIDDEN, TIMEOUT + ai_choice: int + ai_confidence: float + final_action: int + was_override: bool + hold_duration_sec: float + value_estimate: float + action_count: int + override_source: Optional[str] = None + + +@dataclass +class CausationEdgeRow: + """Flattened causation link for Tableau.""" + link_id: str + from_event_id: str + to_event_id: str + causation_type: str # temporal, correlation, threshold, direct + strength: float + timestamp: float + timestamp_iso: str + + +@dataclass +class MetricRow: + """Time-series metric for Tableau.""" + timestamp: float + timestamp_iso: str + metric_name: str + metric_value: float + category: str # TRAINING_DYNAMICS, GRADIENT_HEALTH, etc. + component: str + is_anomaly: bool + anomaly_severity: Optional[str] = None + + +def _ts_to_iso(ts: float) -> str: + """Convert Unix timestamp to ISO string.""" + try: + return datetime.fromtimestamp(ts).isoformat() + except: + return "" + + +def _extract_metric_fields(data: Dict) -> Dict[str, Any]: + """Extract common metric fields from event data.""" + return { + "loss": data.get("loss"), + "accuracy": data.get("accuracy") or data.get("acc"), + "learning_rate": data.get("learning_rate") or data.get("lr"), + "epoch": data.get("epoch"), + "step": data.get("step") or data.get("iter"), + "tokens": data.get("tokens") or data.get("total_tokens"), + "latency_ms": data.get("latency_ms") or data.get("latency"), + "error_message": data.get("error") or data.get("message"), + } + + +class TableauExporter: + """ + Export Cascade data for Tableau visualization. 
+ + Creates a directory with CSV files ready for Tableau import: + - events.csv: All observed events + - chains.csv: Provenance chains + - hold_events.csv: HOLD protocol events + - causation_edges.csv: Graph edges for relationship diagrams + - metrics_timeseries.csv: Metrics over time + + Example: + exporter = TableauExporter() + exporter.add_events(events) + exporter.add_chains(chains) + exporter.export("./tableau_data") + """ + + def __init__(self): + self.events: List[EventRow] = [] + self.chains: List[ChainRow] = [] + self.hold_events: List[HoldEventRow] = [] + self.causation_edges: List[CausationEdgeRow] = [] + self.metrics: List[MetricRow] = [] + + def add_event(self, event) -> None: + """Add a Cascade Event.""" + data = event.data if hasattr(event, 'data') else {} + extracted = _extract_metric_fields(data) + + row = EventRow( + event_id=event.event_id, + timestamp=event.timestamp, + timestamp_iso=_ts_to_iso(event.timestamp), + component=event.component, + event_type=event.event_type, + data_json=json.dumps(data), + **extracted + ) + self.events.append(row) + + def add_events(self, events) -> None: + """Add multiple events.""" + for e in events: + self.add_event(e) + + def add_chain(self, chain, is_verified: bool = True) -> None: + """Add a ProvenanceChain.""" + row = ChainRow( + session_id=chain.session_id, + model_id=chain.model_id, + model_hash=chain.model_hash, + input_hash=chain.input_hash, + output_hash=chain.output_hash, + merkle_root=chain.merkle_root or "", + created_at=chain.created_at, + created_at_iso=_ts_to_iso(chain.created_at), + record_count=len(chain.records), + external_links_count=len(chain.external_roots), + is_verified=is_verified, + ) + self.chains.append(row) + + def add_chains(self, chains) -> None: + """Add multiple chains.""" + for c in chains: + self.add_chain(c) + + def add_hold_event(self, hold_point, resolution) -> None: + """Add a HOLD event with its resolution.""" + import numpy as np + + probs = hold_point.action_probs + if 
isinstance(probs, np.ndarray): + ai_choice = int(np.argmax(probs)) + ai_confidence = float(np.max(probs)) + action_count = len(probs) + else: + ai_choice = 0 + ai_confidence = 0.0 + action_count = 0 + + row = HoldEventRow( + hold_id=getattr(hold_point, 'hold_id', f"hold_{hold_point.timestamp}"), + timestamp=hold_point.timestamp if hasattr(hold_point, 'timestamp') else 0, + timestamp_iso=_ts_to_iso(hold_point.timestamp) if hasattr(hold_point, 'timestamp') else "", + brain_id=hold_point.brain_id, + state=resolution.state.value if hasattr(resolution.state, 'value') else str(resolution.state), + ai_choice=ai_choice, + ai_confidence=ai_confidence, + final_action=resolution.action, + was_override=resolution.was_override, + hold_duration_sec=resolution.hold_duration if hasattr(resolution, 'hold_duration') else 0, + value_estimate=hold_point.value, + action_count=action_count, + override_source=resolution.override_source if hasattr(resolution, 'override_source') else None, + ) + self.hold_events.append(row) + + def add_causation_link(self, link) -> None: + """Add a causation graph edge.""" + row = CausationEdgeRow( + link_id=link.link_id if hasattr(link, 'link_id') else f"{link.from_event}_{link.to_event}", + from_event_id=link.from_event, + to_event_id=link.to_event, + causation_type=link.causation_type, + strength=link.strength, + timestamp=link.timestamp if hasattr(link, 'timestamp') else 0, + timestamp_iso=_ts_to_iso(link.timestamp) if hasattr(link, 'timestamp') else "", + ) + self.causation_edges.append(row) + + def add_causation_links(self, links) -> None: + """Add multiple causation links.""" + for link in links: + self.add_causation_link(link) + + def add_metric(self, name: str, value: float, timestamp: float, + category: str = "OTHER", component: str = "default", + is_anomaly: bool = False, anomaly_severity: str = None) -> None: + """Add a time-series metric point.""" + row = MetricRow( + timestamp=timestamp, + timestamp_iso=_ts_to_iso(timestamp), + 
metric_name=name, + metric_value=value, + category=category, + component=component, + is_anomaly=is_anomaly, + anomaly_severity=anomaly_severity, + ) + self.metrics.append(row) + + def add_metrics_from_event(self, event, category_map: Dict[str, str] = None) -> None: + """Extract and add all metrics from an event.""" + if category_map is None: + category_map = { + "loss": "TRAINING_DYNAMICS", + "accuracy": "TRAINING_DYNAMICS", + "lr": "TRAINING_DYNAMICS", + "learning_rate": "TRAINING_DYNAMICS", + "grad_norm": "GRADIENT_HEALTH", + "weight_norm": "WEIGHT_DYNAMICS", + "tokens": "MEMORY_COMPUTE", + "latency": "MEMORY_COMPUTE", + } + + data = event.data if hasattr(event, 'data') else {} + for key, value in data.items(): + if isinstance(value, (int, float)) and not isinstance(value, bool): + self.add_metric( + name=key, + value=float(value), + timestamp=event.timestamp, + category=category_map.get(key, "OTHER"), + component=event.component, + ) + + def _write_csv(self, path: Path, rows: List, fieldnames: List[str]) -> None: + """Write rows to CSV.""" + with open(path, 'w', newline='', encoding='utf-8') as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + for row in rows: + writer.writerow(asdict(row) if hasattr(row, '__dataclass_fields__') else row) + + def export(self, output_dir: str) -> Dict[str, str]: + """ + Export all data to CSV files. 
+ + Args: + output_dir: Directory to write CSV files + + Returns: + Dict mapping data type to file path + """ + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + files = {} + + # Events + if self.events: + events_path = output_path / "events.csv" + self._write_csv(events_path, self.events, list(EventRow.__dataclass_fields__.keys())) + files["events"] = str(events_path) + print(f"✓ Exported {len(self.events)} events to {events_path}") + + # Chains + if self.chains: + chains_path = output_path / "chains.csv" + self._write_csv(chains_path, self.chains, list(ChainRow.__dataclass_fields__.keys())) + files["chains"] = str(chains_path) + print(f"✓ Exported {len(self.chains)} chains to {chains_path}") + + # HOLD events + if self.hold_events: + hold_path = output_path / "hold_events.csv" + self._write_csv(hold_path, self.hold_events, list(HoldEventRow.__dataclass_fields__.keys())) + files["hold_events"] = str(hold_path) + print(f"✓ Exported {len(self.hold_events)} HOLD events to {hold_path}") + + # Causation edges + if self.causation_edges: + edges_path = output_path / "causation_edges.csv" + self._write_csv(edges_path, self.causation_edges, list(CausationEdgeRow.__dataclass_fields__.keys())) + files["causation_edges"] = str(edges_path) + print(f"✓ Exported {len(self.causation_edges)} causation edges to {edges_path}") + + # Metrics time series + if self.metrics: + metrics_path = output_path / "metrics_timeseries.csv" + self._write_csv(metrics_path, self.metrics, list(MetricRow.__dataclass_fields__.keys())) + files["metrics"] = str(metrics_path) + print(f"✓ Exported {len(self.metrics)} metric points to {metrics_path}") + + # Write a manifest + manifest_path = output_path / "manifest.json" + manifest = { + "exported_at": datetime.now().isoformat(), + "files": files, + "counts": { + "events": len(self.events), + "chains": len(self.chains), + "hold_events": len(self.hold_events), + "causation_edges": len(self.causation_edges), + "metrics": 
len(self.metrics), + } + } + with open(manifest_path, 'w') as f: + json.dump(manifest, f, indent=2) + + print(f"\n📊 Tableau export complete: {output_path}") + print(f" Open Tableau → Connect → Text File → Select CSVs") + + return files + + def export_hyper(self, output_path: str) -> Optional[str]: + """ + Export to Tableau Hyper format (native, fastest). + + Requires: pip install tableauhyperapi + """ + if not HAS_HYPER: + print("⚠️ Hyper API not installed. Run: pip install tableauhyperapi") + return None + + hyper_path = Path(output_path) + + with HyperProcess(telemetry=Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU) as hyper: + with Connection(hyper.endpoint, str(hyper_path), CreateMode.CREATE_AND_REPLACE) as conn: + + # Create events table + if self.events: + events_table = TableDefinition( + TableName("events"), + [ + ("event_id", SqlType.text()), + ("timestamp", SqlType.double()), + ("timestamp_iso", SqlType.text()), + ("component", SqlType.text()), + ("event_type", SqlType.text()), + ("loss", SqlType.double()), + ("accuracy", SqlType.double()), + ("tokens", SqlType.int()), + ] + ) + conn.catalog.create_table(events_table) + + with Inserter(conn, events_table) as inserter: + for e in self.events: + inserter.add_row([ + e.event_id, e.timestamp, e.timestamp_iso, + e.component, e.event_type, + e.loss, e.accuracy, e.tokens + ]) + inserter.execute() + + print(f"✓ Exported Hyper file: {hyper_path}") + return str(hyper_path) + + +# ============================================================================= +# Convenience Functions +# ============================================================================= + +def export_for_tableau(output_dir: str = "./tableau_export", + include_sample_data: bool = True) -> Dict[str, str]: + """ + One-line export of all Cascade data for Tableau. 
+ + Args: + output_dir: Where to write CSV files + include_sample_data: Generate sample data if no real data + + Returns: + Dict of exported file paths + """ + exporter = TableauExporter() + + # Try to load real data from Cascade store + try: + from cascade.store import query, stats + from cascade.observation import ObservationManager + + # Get observations + manager = ObservationManager() + observations = manager.get_recent(limit=1000) + + for obs in observations: + # Create mock event from observation + class MockEvent: + def __init__(self, o): + self.event_id = o.get('cid', '') + self.timestamp = o.get('timestamp', 0) + self.component = o.get('model_id', 'unknown') + self.event_type = 'inference' + self.data = o.get('data', {}) + + exporter.add_event(MockEvent(obs)) + exporter.add_metrics_from_event(MockEvent(obs)) + + print(f"Loaded {len(observations)} observations from Cascade store") + + except Exception as e: + print(f"Note: Could not load Cascade store ({e})") + if include_sample_data: + print("Generating sample data for demo...") + _add_sample_data(exporter) + + return exporter.export(output_dir) + + +def _add_sample_data(exporter: TableauExporter) -> None: + """Add sample data for demonstration.""" + import time + import random + + base_time = time.time() - 3600 # 1 hour ago + + # Sample events + models = ["gpt-4", "claude-3-opus", "llama-3-8b", "mistral-7b"] + event_types = ["inference", "training_step", "error", "checkpoint"] + + for i in range(200): + class SampleEvent: + def __init__(self, idx): + self.event_id = f"evt_{idx:06d}" + self.timestamp = base_time + (idx * 18) # 18 sec apart + self.component = random.choice(models) + self.event_type = random.choice(event_types) + self.data = { + "loss": 2.5 - (idx * 0.01) + random.uniform(-0.1, 0.1), + "accuracy": min(0.95, 0.5 + (idx * 0.002) + random.uniform(-0.02, 0.02)), + "tokens": random.randint(100, 2000), + "latency_ms": random.uniform(50, 500), + "step": idx, + } + + event = SampleEvent(i) + 
exporter.add_event(event) + exporter.add_metrics_from_event(event) + + # Sample HOLD events + for i in range(20): + class SampleHoldPoint: + def __init__(self, idx): + import numpy as np + self.hold_id = f"hold_{idx:04d}" + self.timestamp = base_time + (idx * 180) + self.brain_id = random.choice(models) + self.action_probs = np.random.dirichlet([1, 1, 1, 1]) + self.value = random.uniform(0.3, 0.9) + + class SampleResolution: + def __init__(self, override=False): + self.state = type('State', (), {'value': 'OVERRIDDEN' if override else 'ACCEPTED'})() + self.action = random.randint(0, 3) + self.was_override = override + self.hold_duration = random.uniform(0.5, 10.0) + self.override_source = "human" if override else None + + hold = SampleHoldPoint(i) + resolution = SampleResolution(override=random.random() < 0.25) + exporter.add_hold_event(hold, resolution) + + # Sample causation edges + for i in range(50): + class SampleLink: + def __init__(self, idx): + self.link_id = f"link_{idx:04d}" + self.from_event = f"evt_{idx:06d}" + self.to_event = f"evt_{idx+1:06d}" + self.causation_type = random.choice(["temporal", "correlation", "threshold", "direct"]) + self.strength = random.uniform(0.5, 1.0) + self.timestamp = base_time + (idx * 18) + + exporter.add_causation_link(SampleLink(i)) + + # Sample chains + for i in range(10): + class SampleChain: + def __init__(self, idx): + self.session_id = f"session_{idx:04d}" + self.model_id = random.choice(models) + self.model_hash = f"{random.randint(0, 0xFFFFFFFF):08x}" + self.input_hash = f"{random.randint(0, 0xFFFFFFFF):08x}" + self.output_hash = f"{random.randint(0, 0xFFFFFFFF):08x}" + self.merkle_root = f"{random.randint(0, 0xFFFFFFFFFFFFFFFF):016x}" + self.created_at = base_time + (idx * 360) + self.records = [None] * random.randint(5, 50) + self.external_roots = [f"root_{j}" for j in range(random.randint(0, 3))] + + exporter.add_chain(SampleChain(i)) + + +def export_events_csv(events, output_path: str) -> str: + """Export events 
to CSV.""" + exporter = TableauExporter() + exporter.add_events(events) + files = exporter.export(str(Path(output_path).parent)) + return files.get("events", "") + + +def export_chains_csv(chains, output_path: str) -> str: + """Export chains to CSV.""" + exporter = TableauExporter() + exporter.add_chains(chains) + files = exporter.export(str(Path(output_path).parent)) + return files.get("chains", "") + + +def export_metrics_csv(events, output_path: str) -> str: + """Export metrics time series to CSV.""" + exporter = TableauExporter() + for e in events: + exporter.add_metrics_from_event(e) + files = exporter.export(str(Path(output_path).parent)) + return files.get("metrics", "") + + +def export_hold_events_csv(hold_pairs, output_path: str) -> str: + """Export HOLD events to CSV. hold_pairs = [(hold_point, resolution), ...]""" + exporter = TableauExporter() + for hold, res in hold_pairs: + exporter.add_hold_event(hold, res) + files = exporter.export(str(Path(output_path).parent)) + return files.get("hold_events", "") + + +def export_causation_graph_csv(links, output_path: str) -> str: + """Export causation edges to CSV.""" + exporter = TableauExporter() + exporter.add_causation_links(links) + files = exporter.export(str(Path(output_path).parent)) + return files.get("causation_edges", "") + + +if __name__ == "__main__": + # Quick test + print("Exporting sample data for Tableau...") + export_for_tableau("./tableau_export", include_sample_data=True) diff --git a/cascade/forensics/__init__.py b/cascade/forensics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..149d3b61b63953858552b9e778a9a612e694b6bf --- /dev/null +++ b/cascade/forensics/__init__.py @@ -0,0 +1,53 @@ +""" +CASCADE Forensics - Read the Ghost in the Data + +Every dataset is a confession. It remembers what happened to it. +This module reads those memories. 
+ +GHOST LOG: Inferred processing history from data artifacts +SKELETON: Probable system architecture +DNA: Technology fingerprints +SOUL: Behavioral predictions + +Usage: + from cascade.forensics import DataForensics + + forensics = DataForensics() + report = forensics.analyze(dataframe) + + print(report.ghost_log) # Inferred operations + print(report.skeleton) # System architecture + print(report.fingerprints) # Technology hints +""" + +from cascade.forensics.analyzer import ( + DataForensics, + ForensicsReport, + GhostLog, + InferredOperation, +) + +from cascade.forensics.artifacts import ( + ArtifactDetector, + TimestampArtifacts, + IDPatternArtifacts, + TextArtifacts, + NumericArtifacts, + NullPatternArtifacts, + SchemaArtifacts, +) + +from cascade.forensics.fingerprints import ( + TechFingerprinter, + Fingerprint, +) + +__all__ = [ + "DataForensics", + "ForensicsReport", + "GhostLog", + "InferredOperation", + "ArtifactDetector", + "TechFingerprinter", + "Fingerprint", +] diff --git a/cascade/forensics/analyzer.py b/cascade/forensics/analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..156c39e47b114691ee07736d22c141af274787e5 --- /dev/null +++ b/cascade/forensics/analyzer.py @@ -0,0 +1,464 @@ +""" +CASCADE Forensics - Main Analyzer + +The data remembers. This module reads those memories. 
@dataclass
class InferredOperation:
    """One step of the reconstructed (inferred) processing history."""
    sequence: int
    operation: str
    description: str
    confidence: float
    evidence: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize using the abbreviated keys expected by the ghost-log JSON."""
        keys = ("seq", "op", "desc", "confidence", "evidence")
        values = (self.sequence, self.operation, self.description,
                  self.confidence, self.evidence)
        return dict(zip(keys, values))
+ """ + operations: List[InferredOperation] = field(default_factory=list) + + # Provenance + analysis_timestamp: float = field(default_factory=time.time) + data_hash: str = "" + ghost_hash: str = "" + + def add_operation(self, op: str, desc: str, confidence: float, evidence: List[str] = None): + """Add an inferred operation to the ghost log.""" + self.operations.append(InferredOperation( + sequence=len(self.operations) + 1, + operation=op, + description=desc, + confidence=confidence, + evidence=evidence or [], + )) + + def finalize(self) -> str: + """Compute hash of the ghost log for provenance.""" + content = json.dumps([op.to_dict() for op in self.operations], sort_keys=True) + self.ghost_hash = hashlib.sha256(content.encode()).hexdigest()[:16] + return self.ghost_hash + + def to_dict(self) -> Dict[str, Any]: + return { + "operations": [op.to_dict() for op in self.operations], + "analysis_timestamp": self.analysis_timestamp, + "data_hash": self.data_hash, + "ghost_hash": self.ghost_hash, + } + + def to_narrative(self) -> str: + """Generate human-readable narrative of inferred processing.""" + if not self.operations: + return "No processing artifacts detected." + + lines = ["## Ghost Log - Inferred Processing History\n"] + lines.append("*Based on artifacts left in the data, this is what probably happened:*\n") + + for op in self.operations: + conf_str = "●" * int(op.confidence * 5) + "○" * (5 - int(op.confidence * 5)) + lines.append(f"**{op.sequence}. 
{op.operation}** [{conf_str}]") + lines.append(f" {op.description}") + if op.evidence: + lines.append(f" *Evidence: {', '.join(op.evidence[:3])}*") + lines.append("") + + return "\n".join(lines) + + +@dataclass +class ForensicsReport: + """Complete forensics analysis report.""" + + # Artifacts detected + artifacts: List[Artifact] = field(default_factory=list) + + # Inferred processing + ghost_log: GhostLog = field(default_factory=GhostLog) + + # Technology fingerprints + fingerprints: List[Fingerprint] = field(default_factory=list) + + # Synthesized architecture + likely_stack: Dict[str, Any] = field(default_factory=dict) + + # Security concerns + security_concerns: List[Dict[str, Any]] = field(default_factory=list) + + # Metadata + analysis_timestamp: float = field(default_factory=time.time) + row_count: int = 0 + column_count: int = 0 + data_hash: str = "" + + def to_dict(self) -> Dict[str, Any]: + return { + "artifacts": [a.to_dict() for a in self.artifacts], + "ghost_log": self.ghost_log.to_dict(), + "fingerprints": [f.to_dict() for f in self.fingerprints], + "likely_stack": self.likely_stack, + "security_concerns": self.security_concerns, + "metadata": { + "timestamp": self.analysis_timestamp, + "rows": self.row_count, + "columns": self.column_count, + "data_hash": self.data_hash, + } + } + + def summary(self) -> Dict[str, Any]: + """Generate summary for display.""" + return { + "artifacts_found": len(self.artifacts), + "operations_inferred": len(self.ghost_log.operations), + "technologies_identified": len(self.fingerprints), + "security_concerns": len(self.security_concerns), + "top_fingerprints": [f.technology for f in self.fingerprints[:5]], + "data_hash": self.data_hash, + "ghost_hash": self.ghost_log.ghost_hash, + } + + +class DataForensics: + """ + Main forensics analyzer. 
+ + Usage: + forensics = DataForensics() + report = forensics.analyze(df) + + print(report.ghost_log.to_narrative()) + print(report.likely_stack) + """ + + def __init__(self): + self.detectors = [ + TimestampArtifacts(), + IDPatternArtifacts(), + TextArtifacts(), + NumericArtifacts(), + NullPatternArtifacts(), + SchemaArtifacts(), + ] + self.fingerprinter = TechFingerprinter() + + def analyze(self, df) -> ForensicsReport: + """ + Analyze a dataframe for processing artifacts. + + Args: + df: Pandas DataFrame to analyze + + Returns: + ForensicsReport with all findings + """ + report = ForensicsReport() + report.row_count = len(df) + report.column_count = len(df.columns) + + # Compute data hash + try: + # Sample hash for large datasets + if len(df) > 10000: + sample = df.sample(10000, random_state=42) + else: + sample = df + content = sample.to_json() + report.data_hash = hashlib.sha256(content.encode()).hexdigest()[:16] + except: + report.data_hash = "unknown" + + # Run all detectors + all_artifacts = [] + + for detector in self.detectors: + try: + # Some detectors analyze all columns at once + if hasattr(detector, 'detect_all'): + artifacts = detector.detect_all(df) + all_artifacts.extend(artifacts) + + # Column-by-column analysis + for col in df.columns: + artifacts = detector.detect(df, col) + all_artifacts.extend(artifacts) + except Exception as e: + # Don't let one detector crash the whole analysis + pass + + report.artifacts = all_artifacts + + # Build ghost log from artifacts + report.ghost_log = self._build_ghost_log(all_artifacts, df) + report.ghost_log.data_hash = report.data_hash + report.ghost_log.finalize() + + # Generate technology fingerprints + report.fingerprints = self.fingerprinter.analyze(all_artifacts) + report.likely_stack = self.fingerprinter.get_likely_stack() + report.security_concerns = self.fingerprinter.get_security_concerns() + + return report + + def _build_ghost_log(self, artifacts: List[Artifact], df) -> GhostLog: + """ + Build 
inferred processing history from artifacts. + + This is where we reconstruct the sequence of operations + that probably created this data. + """ + ghost = GhostLog() + + # Group artifacts by type for logical ordering + by_type = {} + for a in artifacts: + if a.artifact_type not in by_type: + by_type[a.artifact_type] = [] + by_type[a.artifact_type].append(a) + + # Infer operations in logical order + + # 1. Data sourcing (schema artifacts come first) + if "framework_fingerprint" in by_type: + for a in by_type["framework_fingerprint"]: + ghost.add_operation( + "DATA_SOURCE", + f"Data originated from {a.details.get('framework', 'database')}: {a.evidence}", + a.confidence, + [a.evidence] + ) + + if "naming_convention" in by_type: + for a in by_type["naming_convention"]: + ghost.add_operation( + "SCHEMA_ORIGIN", + f"Schema follows {a.details.get('convention', 'unknown')} convention", + a.confidence, + [a.evidence] + ) + + # 2. Merging (if multiple sources detected) + if "mixed_conventions" in by_type or "id_prefix" in by_type: + ghost.add_operation( + "DATA_MERGE", + "Multiple data sources were merged together", + 0.75, + [a.evidence for a in by_type.get("mixed_conventions", []) + by_type.get("id_prefix", [])] + ) + + # 3. ID generation + if "uuid_version" in by_type: + for a in by_type["uuid_version"]: + ghost.add_operation( + "ID_GENERATION", + f"IDs generated using {a.details.get('meaning', 'UUID')}", + a.confidence, + [a.evidence] + ) + + if "hash_id" in by_type: + for a in by_type["hash_id"]: + ghost.add_operation( + "ID_GENERATION", + f"IDs are {a.details.get('probable_algorithm', 'hash')}-based (content-addressed)", + a.confidence, + [a.evidence] + ) + + # 4. 
Processing / Transformation + if "case_normalization" in by_type: + for a in by_type["case_normalization"]: + ghost.add_operation( + "TEXT_NORMALIZATION", + f"Text converted to {a.details.get('case', 'normalized')} case", + a.confidence, + [a.evidence] + ) + + if "whitespace_trimming" in by_type: + ghost.add_operation( + "TEXT_CLEANING", + "Whitespace trimmed from text fields", + 0.70, + [a.evidence for a in by_type["whitespace_trimming"]] + ) + + if "truncation" in by_type: + for a in by_type["truncation"]: + ghost.add_operation( + "FIELD_TRUNCATION", + f"Text truncated at {a.details.get('max_length', '?')} characters", + a.confidence, + [a.evidence] + ) + + if "numeric_rounding" in by_type: + for a in by_type["numeric_rounding"]: + ghost.add_operation( + "NUMERIC_ROUNDING", + f"Numbers rounded: {a.evidence}", + a.confidence, + [a.evidence] + ) + + # 5. Filtering / Deletion + if "sequential_id_gaps" in by_type: + for a in by_type["sequential_id_gaps"]: + gap_ratio = a.details.get('gap_ratio', 0) + ghost.add_operation( + "RECORD_FILTERING", + f"~{gap_ratio*100:.0f}% of records were filtered or deleted", + a.confidence, + [a.evidence] + ) + + if "hard_cutoff" in by_type: + for a in by_type["hard_cutoff"]: + ghost.add_operation( + "VALUE_CAPPING", + f"Values capped at {a.details.get('cutoff', '?')}", + a.confidence, + [a.evidence] + ) + + # 6. 
Batch processing patterns + if "timestamp_rounding" in by_type: + for a in by_type["timestamp_rounding"]: + ghost.add_operation( + "BATCH_PROCESSING", + f"Data processed in batches: {a.evidence}", + a.confidence, + [a.evidence] + ) + + if "regular_intervals" in by_type: + for a in by_type["regular_intervals"]: + ghost.add_operation( + "SCHEDULED_JOB", + f"Regular processing schedule detected: {a.details.get('interval_desc', 'unknown')}", + a.confidence, + [a.evidence] + ) + + if "temporal_clustering" in by_type: + ghost.add_operation( + "BURST_PROCESSING", + "Event-driven or burst batch processing detected", + 0.75, + [a.evidence for a in by_type["temporal_clustering"]] + ) + + # 7. Data quality issues + if "encoding_artifact" in by_type: + for a in by_type["encoding_artifact"]: + ghost.add_operation( + "ENCODING_ERROR", + f"Character encoding conversion failed: {a.evidence}", + a.confidence, + [a.evidence] + ) + + if "sentinel_value" in by_type: + for a in by_type["sentinel_value"]: + ghost.add_operation( + "NULL_HANDLING", + f"NULLs represented as sentinel value {a.details.get('sentinel', '?')}", + a.confidence, + [a.evidence] + ) + + if "high_null_rate" in by_type: + for a in by_type["high_null_rate"]: + ghost.add_operation( + "OPTIONAL_FIELD", + f"Column {a.column} is optional or had ETL issues ({a.details.get('null_rate', 0)*100:.0f}% null)", + a.confidence, + [a.evidence] + ) + + # 8. Export (often the last step) + if any("PANDAS" in a.inferred_operation for a in artifacts): + ghost.add_operation( + "DATA_EXPORT", + "Data exported via Pandas to CSV", + 0.90, + ["Unnamed column artifact"] + ) + + return ghost + + def analyze_file(self, filepath: str) -> ForensicsReport: + """ + Analyze a data file. 
def analyze_dataframe(df) -> ForensicsReport:
    """Convenience wrapper: run a full forensics pass over *df*."""
    return DataForensics().analyze(df)


def analyze_file(filepath: str) -> ForensicsReport:
    """Convenience wrapper: load *filepath* and run a full forensics pass."""
    return DataForensics().analyze_file(filepath)
+""" + +import re +import hashlib +from dataclasses import dataclass, field +from typing import List, Dict, Any, Optional, Set, Tuple +from datetime import datetime +from collections import Counter +import statistics + + +@dataclass +class Artifact: + """A single detected artifact - evidence of processing.""" + artifact_type: str + column: str + evidence: str + confidence: float # 0.0 to 1.0 + inferred_operation: str + details: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + return { + "type": self.artifact_type, + "column": self.column, + "evidence": self.evidence, + "confidence": self.confidence, + "inferred_op": self.inferred_operation, + "details": self.details, + } + + +class ArtifactDetector: + """Base class for artifact detection.""" + + name: str = "base" + + def detect(self, df, column: str) -> List[Artifact]: + """Detect artifacts in a column. Override in subclasses.""" + return [] + + def detect_all(self, df) -> List[Artifact]: + """Detect artifacts across all applicable columns.""" + artifacts = [] + for col in df.columns: + artifacts.extend(self.detect(df, col)) + return artifacts + + +class TimestampArtifacts(ArtifactDetector): + """ + Detect timestamp patterns that reveal processing behavior. 
+ + Artifacts detected: + - Rounding to minute/hour/day (batch processing intervals) + - Regular intervals (scheduled jobs) + - Temporal clustering (burst processing) + - Timezone artifacts + - Future/past anomalies + """ + + name = "timestamp" + + def detect(self, df, column: str) -> List[Artifact]: + artifacts = [] + + # Check if column looks like timestamps + if not self._is_timestamp_column(df, column): + return artifacts + + try: + timestamps = self._parse_timestamps(df, column) + if len(timestamps) < 2: + return artifacts + + # Check for rounding patterns + rounding = self._detect_rounding(timestamps) + if rounding: + artifacts.append(rounding) + + # Check for regular intervals + intervals = self._detect_intervals(timestamps) + if intervals: + artifacts.append(intervals) + + # Check for clustering + clustering = self._detect_clustering(timestamps) + if clustering: + artifacts.append(clustering) + + # Check for timezone issues + tz_artifacts = self._detect_timezone_artifacts(timestamps) + artifacts.extend(tz_artifacts) + + except Exception: + pass + + return artifacts + + def _is_timestamp_column(self, df, column: str) -> bool: + """Heuristic to detect timestamp columns.""" + col_lower = column.lower() + timestamp_hints = ['time', 'date', 'created', 'updated', 'modified', 'timestamp', '_at', '_on'] + if any(hint in col_lower for hint in timestamp_hints): + return True + + # Check data type + dtype = str(df[column].dtype) + if 'datetime' in dtype or 'time' in dtype: + return True + + # Sample and check format + sample = df[column].dropna().head(5).astype(str).tolist() + date_patterns = [ + r'\d{4}-\d{2}-\d{2}', + r'\d{2}/\d{2}/\d{4}', + r'\d{10,13}', # Unix timestamp + ] + for val in sample: + for pattern in date_patterns: + if re.search(pattern, val): + return True + + return False + + def _parse_timestamps(self, df, column: str) -> List[datetime]: + """Parse column to datetime objects.""" + import pandas as pd + + try: + # Try pandas datetime conversion + 
parsed = pd.to_datetime(df[column], errors='coerce') + return [ts.to_pydatetime() for ts in parsed.dropna()] + except: + return [] + + def _detect_rounding(self, timestamps: List[datetime]) -> Optional[Artifact]: + """Detect if timestamps are rounded to specific intervals.""" + if len(timestamps) < 10: + return None + + # Check seconds + seconds = [ts.second for ts in timestamps] + unique_seconds = set(seconds) + + # All zeros = minute rounding + if unique_seconds == {0}: + # Check minutes + minutes = [ts.minute for ts in timestamps] + unique_minutes = set(minutes) + + if unique_minutes == {0}: + return Artifact( + artifact_type="timestamp_rounding", + column="timestamps", + evidence=f"All timestamps rounded to hour (0 minutes, 0 seconds)", + confidence=0.95, + inferred_operation="BATCH_HOURLY", + details={"interval": "hour", "sample_size": len(timestamps)} + ) + elif all(m % 15 == 0 for m in minutes): + return Artifact( + artifact_type="timestamp_rounding", + column="timestamps", + evidence=f"Timestamps rounded to 15-minute intervals", + confidence=0.90, + inferred_operation="BATCH_15MIN", + details={"interval": "15min", "unique_minutes": list(unique_minutes)} + ) + elif all(m % 5 == 0 for m in minutes): + return Artifact( + artifact_type="timestamp_rounding", + column="timestamps", + evidence=f"Timestamps rounded to 5-minute intervals", + confidence=0.85, + inferred_operation="BATCH_5MIN", + details={"interval": "5min"} + ) + else: + return Artifact( + artifact_type="timestamp_rounding", + column="timestamps", + evidence=f"Timestamps rounded to minute (0 seconds)", + confidence=0.85, + inferred_operation="BATCH_MINUTE", + details={"interval": "minute"} + ) + + # Check if seconds cluster on specific values + second_counts = Counter(seconds) + most_common = second_counts.most_common(1)[0] + if most_common[1] > len(timestamps) * 0.8: + return Artifact( + artifact_type="timestamp_rounding", + column="timestamps", + evidence=f"{most_common[1]/len(timestamps)*100:.0f}% 
of timestamps have second={most_common[0]}", + confidence=0.70, + inferred_operation="SYSTEMATIC_TIMESTAMP_ASSIGNMENT", + details={"dominant_second": most_common[0], "percentage": most_common[1]/len(timestamps)} + ) + + return None + + def _detect_intervals(self, timestamps: List[datetime]) -> Optional[Artifact]: + """Detect regular time intervals suggesting scheduled jobs.""" + if len(timestamps) < 10: + return None + + sorted_ts = sorted(timestamps) + deltas = [(sorted_ts[i+1] - sorted_ts[i]).total_seconds() for i in range(len(sorted_ts)-1)] + + if not deltas: + return None + + # Check for consistent intervals + median_delta = statistics.median(deltas) + if median_delta == 0: + return None + + # Count how many deltas are close to median + tolerance = median_delta * 0.1 # 10% tolerance + consistent = sum(1 for d in deltas if abs(d - median_delta) < tolerance) + consistency_ratio = consistent / len(deltas) + + if consistency_ratio > 0.7: + # Describe the interval + interval_desc = self._describe_interval(median_delta) + return Artifact( + artifact_type="regular_intervals", + column="timestamps", + evidence=f"{consistency_ratio*100:.0f}% of records have ~{interval_desc} intervals", + confidence=min(0.95, consistency_ratio), + inferred_operation=f"SCHEDULED_JOB_{interval_desc.upper().replace(' ', '_')}", + details={ + "median_seconds": median_delta, + "interval_desc": interval_desc, + "consistency": consistency_ratio + } + ) + + return None + + def _describe_interval(self, seconds: float) -> str: + """Human-readable interval description.""" + if seconds < 60: + return f"{seconds:.0f}s" + elif seconds < 3600: + return f"{seconds/60:.0f}min" + elif seconds < 86400: + return f"{seconds/3600:.1f}hr" + else: + return f"{seconds/86400:.1f}day" + + def _detect_clustering(self, timestamps: List[datetime]) -> Optional[Artifact]: + """Detect temporal clustering (burst processing).""" + if len(timestamps) < 20: + return None + + sorted_ts = sorted(timestamps) + + # Look for 
bursts: many records in short time, then gaps + deltas = [(sorted_ts[i+1] - sorted_ts[i]).total_seconds() for i in range(len(sorted_ts)-1)] + + if not deltas: + return None + + median_delta = statistics.median(deltas) + if median_delta == 0: + return None + + # Count "burst" deltas (much smaller than median) vs "gap" deltas (much larger) + bursts = sum(1 for d in deltas if d < median_delta * 0.1) + gaps = sum(1 for d in deltas if d > median_delta * 5) + + if bursts > len(deltas) * 0.3 and gaps > len(deltas) * 0.05: + return Artifact( + artifact_type="temporal_clustering", + column="timestamps", + evidence=f"Burst pattern: {bursts} rapid records, {gaps} long gaps", + confidence=0.75, + inferred_operation="BATCH_BURST_PROCESSING", + details={ + "burst_count": bursts, + "gap_count": gaps, + "median_delta_seconds": median_delta + } + ) + + return None + + def _detect_timezone_artifacts(self, timestamps: List[datetime]) -> List[Artifact]: + """Detect timezone-related artifacts.""" + artifacts = [] + + # Check for hour distribution anomalies (e.g., no records 0-7 UTC = US business hours) + hours = [ts.hour for ts in timestamps] + hour_counts = Counter(hours) + + # Check for gaps suggesting business hours in a specific timezone + zero_hours = [h for h in range(24) if hour_counts.get(h, 0) == 0] + + if len(zero_hours) >= 6 and len(zero_hours) <= 12: + # Contiguous gap? + zero_hours_sorted = sorted(zero_hours) + if zero_hours_sorted[-1] - zero_hours_sorted[0] == len(zero_hours) - 1: + artifacts.append(Artifact( + artifact_type="business_hours", + column="timestamps", + evidence=f"No records during hours {min(zero_hours)}-{max(zero_hours)} UTC", + confidence=0.70, + inferred_operation="BUSINESS_HOURS_ONLY", + details={"quiet_hours": zero_hours} + )) + + return artifacts + + +class IDPatternArtifacts(ArtifactDetector): + """ + Detect ID patterns that reveal data lineage. 
+ + Artifacts detected: + - Sequential IDs with gaps (deletions/filtering) + - UUID versions (generation method) + - Prefixes (source identification) + - Hash patterns (deterministic generation) + """ + + name = "id_patterns" + + def detect(self, df, column: str) -> List[Artifact]: + artifacts = [] + + if not self._is_id_column(df, column): + return artifacts + + try: + values = df[column].dropna().astype(str).tolist() + if len(values) < 5: + return artifacts + + # Check for sequential integers with gaps + gaps = self._detect_sequential_gaps(values) + if gaps: + artifacts.append(gaps) + + # Check for UUID patterns + uuid_artifact = self._detect_uuid_patterns(values) + if uuid_artifact: + artifacts.append(uuid_artifact) + + # Check for prefixes + prefix = self._detect_prefixes(values) + if prefix: + artifacts.append(prefix) + + # Check for hash patterns + hash_artifact = self._detect_hash_patterns(values) + if hash_artifact: + artifacts.append(hash_artifact) + + except Exception: + pass + + return artifacts + + def _is_id_column(self, df, column: str) -> bool: + """Heuristic to detect ID columns.""" + col_lower = column.lower() + id_hints = ['id', 'key', 'uuid', 'guid', 'pk', '_id', 'identifier'] + return any(hint in col_lower for hint in id_hints) + + def _detect_sequential_gaps(self, values: List[str]) -> Optional[Artifact]: + """Detect sequential IDs with gaps indicating deletions.""" + # Try to parse as integers + try: + ints = sorted([int(v) for v in values if v.isdigit()]) + if len(ints) < 10: + return None + + # Check for gaps + expected_count = ints[-1] - ints[0] + 1 + actual_count = len(set(ints)) + gap_count = expected_count - actual_count + gap_ratio = gap_count / expected_count if expected_count > 0 else 0 + + if gap_ratio > 0.05: # More than 5% missing + return Artifact( + artifact_type="sequential_id_gaps", + column=values[0] if values else "id", + evidence=f"Sequential IDs with {gap_ratio*100:.1f}% gaps ({gap_count} missing)", + confidence=0.85, + 
inferred_operation="FILTERING_OR_DELETION", + details={ + "min_id": ints[0], + "max_id": ints[-1], + "expected": expected_count, + "actual": actual_count, + "gap_ratio": gap_ratio + } + ) + except: + pass + + return None + + def _detect_uuid_patterns(self, values: List[str]) -> Optional[Artifact]: + """Detect UUID version from patterns.""" + uuid_pattern = re.compile(r'^[0-9a-f]{8}-[0-9a-f]{4}-([0-9a-f])[0-9a-f]{3}-[0-9a-f]{4}-[0-9a-f]{12}$', re.I) + + versions = [] + for v in values[:100]: # Sample + match = uuid_pattern.match(v) + if match: + versions.append(match.group(1)) + + if len(versions) < len(values[:100]) * 0.5: + return None + + version_counts = Counter(versions) + dominant = version_counts.most_common(1)[0] + + version_meanings = { + '1': 'TIME_BASED_MAC', # Reveals generation time + machine + '2': 'DCE_SECURITY', + '3': 'MD5_HASH', # Deterministic from input + '4': 'RANDOM', # Crypto random + '5': 'SHA1_HASH', # Deterministic from input + '6': 'SORTABLE_TIME', # Modern time-sortable + '7': 'UNIX_TIME_RANDOM', # Time-ordered with randomness + } + + return Artifact( + artifact_type="uuid_version", + column="id", + evidence=f"UUIDs are version {dominant[0]} ({version_meanings.get(dominant[0], 'UNKNOWN')})", + confidence=0.90, + inferred_operation=f"UUID_GENERATION_V{dominant[0]}", + details={ + "version": dominant[0], + "meaning": version_meanings.get(dominant[0], 'unknown'), + "sample_count": len(versions) + } + ) + + def _detect_prefixes(self, values: List[str]) -> Optional[Artifact]: + """Detect common prefixes indicating source systems.""" + if len(values) < 10: + return None + + # Find common prefix + prefix_len = 0 + for i in range(1, min(20, min(len(v) for v in values[:100]))): + prefixes = set(v[:i] for v in values[:100]) + if len(prefixes) <= 3: # Allow up to 3 different prefixes + prefix_len = i + else: + break + + if prefix_len >= 2: + prefixes = Counter(v[:prefix_len] for v in values) + top_prefixes = prefixes.most_common(3) + + return 
Artifact( + artifact_type="id_prefix", + column="id", + evidence=f"IDs have systematic prefix: {top_prefixes}", + confidence=0.80, + inferred_operation="MULTI_SOURCE_MERGE" if len(top_prefixes) > 1 else "SOURCE_IDENTIFICATION", + details={ + "prefixes": dict(top_prefixes), + "prefix_length": prefix_len + } + ) + + return None + + def _detect_hash_patterns(self, values: List[str]) -> Optional[Artifact]: + """Detect if IDs look like hashes.""" + hex_pattern = re.compile(r'^[0-9a-f]+$', re.I) + + hex_lengths = [] + for v in values[:100]: + if hex_pattern.match(v): + hex_lengths.append(len(v)) + + if len(hex_lengths) < len(values[:100]) * 0.8: + return None + + # Check for consistent hash lengths + length_counts = Counter(hex_lengths) + dominant = length_counts.most_common(1)[0] + + hash_types = { + 32: 'MD5', + 40: 'SHA1', + 64: 'SHA256', + 128: 'SHA512', + 16: 'SHORT_HASH', + } + + if dominant[1] > len(hex_lengths) * 0.9: + hash_type = hash_types.get(dominant[0], f'{dominant[0]}-char hash') + return Artifact( + artifact_type="hash_id", + column="id", + evidence=f"IDs are {hash_type} hashes ({dominant[0]} hex chars)", + confidence=0.85, + inferred_operation=f"DETERMINISTIC_ID_GENERATION_{hash_type}", + details={ + "hash_length": dominant[0], + "probable_algorithm": hash_type + } + ) + + return None + + +class TextArtifacts(ArtifactDetector): + """ + Detect text processing artifacts. 
+ + Artifacts detected: + - Truncation (field length limits) + - Encoding issues (charset conversion) + - Case normalization + - Whitespace patterns + - Sanitization patterns + """ + + name = "text" + + def detect(self, df, column: str) -> List[Artifact]: + artifacts = [] + + dtype = str(df[column].dtype) + if 'object' not in dtype and 'str' not in dtype: + return artifacts + + try: + values = df[column].dropna().astype(str).tolist() + if len(values) < 5: + return artifacts + + # Truncation + trunc = self._detect_truncation(values) + if trunc: + artifacts.append(trunc) + + # Encoding issues + encoding = self._detect_encoding_artifacts(values) + if encoding: + artifacts.append(encoding) + + # Case patterns + case = self._detect_case_patterns(values, column) + if case: + artifacts.append(case) + + # Whitespace + ws = self._detect_whitespace_patterns(values) + if ws: + artifacts.append(ws) + + except Exception: + pass + + return artifacts + + def _detect_truncation(self, values: List[str]) -> Optional[Artifact]: + """Detect truncation at specific lengths.""" + lengths = [len(v) for v in values] + max_len = max(lengths) + + # Count values at max length + at_max = sum(1 for l in lengths if l == max_len) + + # If many values hit the max, likely truncation + if at_max > len(values) * 0.1 and max_len > 10: + # Check if values at max look truncated (end mid-word, etc.) 
+            max_values = [v for v in values if len(v) == max_len]
+            truncated_looking = sum(1 for v in max_values if not v.endswith(('.', '!', '?', ' ')))
+
+            if truncated_looking > len(max_values) * 0.5:
+                return Artifact(
+                    artifact_type="truncation",
+                    column=str(values[0])[:20] if values else "text",
+                    evidence=f"{at_max} values ({at_max/len(values)*100:.1f}%) truncated at {max_len} chars",
+                    confidence=0.80,
+                    inferred_operation=f"FIELD_LENGTH_LIMIT_{max_len}",
+                    details={
+                        "max_length": max_len,
+                        "truncated_count": at_max,
+                        "truncated_ratio": at_max / len(values)
+                    }
+                )
+
+        return None
+
+    def _detect_encoding_artifacts(self, values: List[str]) -> Optional[Artifact]:
+        """Detect encoding/charset conversion issues."""
+        # Common mojibake patterns
+        mojibake_patterns = [
+            r'é',  # é misencoded
+            r'è',  # è
+            r'à ',  # à
+            r'’',  # ' smart quote
+            r'â€"',  # — em dash
+            r'ö',  # ö
+            r'ü',  # ü
+            r'\ufeff',  # BOM
+            r'\\x[0-9a-f]{2}',  # Raw hex escapes
+            r'&amp;|&lt;|&gt;',  # HTML entities
+        ]
+
+        issue_count = 0
+        patterns_found = set()
+
+        for v in values[:500]:  # Sample
+            for pattern in mojibake_patterns:
+                if re.search(pattern, v):
+                    issue_count += 1
+                    patterns_found.add(pattern)
+                    break
+
+        if issue_count > 5:
+            return Artifact(
+                artifact_type="encoding_artifact",
+                column="text",
+                evidence=f"{issue_count} values have encoding issues (patterns: {patterns_found})",
+                confidence=0.85,
+                inferred_operation="CHARSET_CONVERSION_ERROR",
+                details={
+                    "issue_count": issue_count,
+                    "patterns": list(patterns_found)
+                }
+            )
+
+        return None
+
+    def _detect_case_patterns(self, values: List[str], column: str) -> Optional[Artifact]:
+        """Detect case normalization."""
+        # Skip obviously non-text columns
+        sample = values[:100]
+
+        all_lower = all(v == v.lower() for v in sample if v.strip())
+        all_upper = all(v == v.upper() for v in sample if v.strip())
+
+        if all_lower:
+            return Artifact(
+                artifact_type="case_normalization",
+                column=column,
+                evidence="All values are lowercase",
+                confidence=0.90,
+ inferred_operation="LOWERCASE_NORMALIZATION", + details={"case": "lower"} + ) + elif all_upper: + return Artifact( + artifact_type="case_normalization", + column=column, + evidence="All values are UPPERCASE", + confidence=0.90, + inferred_operation="UPPERCASE_NORMALIZATION", + details={"case": "upper"} + ) + + return None + + def _detect_whitespace_patterns(self, values: List[str]) -> Optional[Artifact]: + """Detect whitespace handling patterns.""" + # Check for leading/trailing whitespace + has_leading = sum(1 for v in values if v and v[0] == ' ') + has_trailing = sum(1 for v in values if v and v[-1] == ' ') + + # No whitespace at all = trimmed + if has_leading == 0 and has_trailing == 0: + # Verify there's text that COULD have whitespace + has_spaces = sum(1 for v in values if ' ' in v.strip()) + if has_spaces > len(values) * 0.3: + return Artifact( + artifact_type="whitespace_trimming", + column="text", + evidence="No leading/trailing whitespace (data was trimmed)", + confidence=0.70, + inferred_operation="WHITESPACE_TRIM", + details={"trimmed": True} + ) + + return None + + +class NumericArtifacts(ArtifactDetector): + """ + Detect numeric processing artifacts. 
+ + Artifacts detected: + - Rounding patterns (precision limits) + - Outlier presence/absence (filtering) + - Distribution anomalies (sampling) + - Sentinel values (nulls represented as -1, 0, 9999) + """ + + name = "numeric" + + def detect(self, df, column: str) -> List[Artifact]: + artifacts = [] + + # Check if numeric + try: + values = df[column].dropna() + if len(values) < 10: + return artifacts + + # Try to get numeric values + numeric_values = values.astype(float).tolist() + + # Rounding + rounding = self._detect_rounding(numeric_values, column) + if rounding: + artifacts.append(rounding) + + # Sentinel values + sentinel = self._detect_sentinel_values(numeric_values, column) + if sentinel: + artifacts.append(sentinel) + + # Distribution + dist = self._detect_distribution_artifacts(numeric_values, column) + if dist: + artifacts.append(dist) + + except (ValueError, TypeError): + pass + + return artifacts + + def _detect_rounding(self, values: List[float], column: str) -> Optional[Artifact]: + """Detect systematic rounding.""" + # Check decimal places + decimal_places = [] + for v in values[:500]: + if v != int(v): + str_v = f"{v:.10f}".rstrip('0') + if '.' in str_v: + decimal_places.append(len(str_v.split('.')[1])) + + if not decimal_places: + # All integers - check for rounding to 10, 100, etc. 
+ int_values = [int(v) for v in values] + + divisible_by_100 = sum(1 for v in int_values if v % 100 == 0) + divisible_by_10 = sum(1 for v in int_values if v % 10 == 0) + + if divisible_by_100 > len(int_values) * 0.9: + return Artifact( + artifact_type="numeric_rounding", + column=column, + evidence="Values rounded to nearest 100", + confidence=0.85, + inferred_operation="ROUND_TO_100", + details={"rounding": 100} + ) + elif divisible_by_10 > len(int_values) * 0.9: + return Artifact( + artifact_type="numeric_rounding", + column=column, + evidence="Values rounded to nearest 10", + confidence=0.80, + inferred_operation="ROUND_TO_10", + details={"rounding": 10} + ) + else: + # Check for consistent decimal places + max_decimals = max(decimal_places) + at_max = sum(1 for d in decimal_places if d == max_decimals) + + if at_max < len(decimal_places) * 0.3 and max_decimals <= 2: + return Artifact( + artifact_type="numeric_rounding", + column=column, + evidence=f"Values appear rounded to {max_decimals} decimal places", + confidence=0.75, + inferred_operation=f"ROUND_TO_{max_decimals}_DECIMALS", + details={"decimal_places": max_decimals} + ) + + return None + + def _detect_sentinel_values(self, values: List[float], column: str) -> Optional[Artifact]: + """Detect sentinel values representing nulls.""" + sentinels = [-1, -999, -9999, 0, 9999, 99999] + + value_counts = Counter(values) + + for sentinel in sentinels: + if sentinel in value_counts: + count = value_counts[sentinel] + if count > len(values) * 0.01: # More than 1% + return Artifact( + artifact_type="sentinel_value", + column=column, + evidence=f"{count} occurrences of {sentinel} (likely NULL sentinel)", + confidence=0.70, + inferred_operation=f"NULL_AS_{int(sentinel)}", + details={ + "sentinel": sentinel, + "count": count, + "percentage": count / len(values) * 100 + } + ) + + return None + + def _detect_distribution_artifacts(self, values: List[float], column: str) -> Optional[Artifact]: + """Detect distribution 
anomalies suggesting filtering/sampling.""" + if len(values) < 100: + return None + + # Check for hard cutoffs + sorted_vals = sorted(values) + min_val, max_val = sorted_vals[0], sorted_vals[-1] + + # Round number cutoffs suggest filtering + if max_val == int(max_val) and max_val % 10 == 0: + # Check if there's a cluster at the max + at_max = sum(1 for v in values if v == max_val) + if at_max > len(values) * 0.05: + return Artifact( + artifact_type="hard_cutoff", + column=column, + evidence=f"Hard cutoff at {max_val} ({at_max} values at limit)", + confidence=0.75, + inferred_operation=f"CAP_AT_{int(max_val)}", + details={ + "cutoff": max_val, + "count_at_cutoff": at_max + } + ) + + return None + + +class NullPatternArtifacts(ArtifactDetector): + """ + Detect null/missing value patterns. + + Artifacts detected: + - Systematic nulls (default handling) + - Null correlations (conditional logic) + - Null rates anomalies (ETL errors) + """ + + name = "null_patterns" + + def detect_all(self, df) -> List[Artifact]: + """Analyze null patterns across all columns.""" + artifacts = [] + + # Overall null rates per column + null_rates = {} + for col in df.columns: + null_rate = df[col].isna().mean() + null_rates[col] = null_rate + + # Detect anomalous null rates + rates = list(null_rates.values()) + if len(rates) > 3: + mean_rate = statistics.mean(rates) + + for col, rate in null_rates.items(): + if rate > 0.5 and rate > mean_rate * 3: + artifacts.append(Artifact( + artifact_type="high_null_rate", + column=col, + evidence=f"{rate*100:.1f}% null (vs {mean_rate*100:.1f}% average)", + confidence=0.70, + inferred_operation="OPTIONAL_FIELD_OR_ETL_ERROR", + details={ + "null_rate": rate, + "avg_null_rate": mean_rate + } + )) + + # Detect columns that are null together (conditional logic) + # This is expensive so we sample + if len(df) > 100: + sample = df.sample(min(1000, len(df))) + else: + sample = df + + correlated_nulls = [] + cols = list(df.columns) + for i, col1 in 
enumerate(cols): + for col2 in cols[i+1:]: + both_null = (sample[col1].isna() & sample[col2].isna()).mean() + either_null = (sample[col1].isna() | sample[col2].isna()).mean() + + if either_null > 0.1 and both_null / either_null > 0.8: + correlated_nulls.append((col1, col2, both_null)) + + if correlated_nulls: + artifacts.append(Artifact( + artifact_type="correlated_nulls", + column="multiple", + evidence=f"{len(correlated_nulls)} column pairs have correlated nulls", + confidence=0.75, + inferred_operation="CONDITIONAL_FIELD_POPULATION", + details={ + "pairs": [(c1, c2) for c1, c2, _ in correlated_nulls[:5]] + } + )) + + return artifacts + + def detect(self, df, column: str) -> List[Artifact]: + """Null patterns are analyzed globally, not per-column.""" + return [] + + +class SchemaArtifacts(ArtifactDetector): + """ + Detect schema-level artifacts. + + Artifacts detected: + - Column naming conventions (framework hints) + - Data type patterns (database origin) + - Schema inconsistencies (merged sources) + """ + + name = "schema" + + def detect_all(self, df) -> List[Artifact]: + """Analyze schema patterns.""" + artifacts = [] + + columns = list(df.columns) + + # Naming convention detection + conventions = self._detect_naming_conventions(columns) + if conventions: + artifacts.append(conventions) + + # Framework fingerprints + framework = self._detect_framework_fingerprints(columns) + if framework: + artifacts.append(framework) + + # Mixed conventions (merged sources) + mixed = self._detect_mixed_conventions(columns) + if mixed: + artifacts.append(mixed) + + return artifacts + + def detect(self, df, column: str) -> List[Artifact]: + """Schema patterns are analyzed globally.""" + return [] + + def _detect_naming_conventions(self, columns: List[str]) -> Optional[Artifact]: + """Detect column naming convention.""" + snake_case = sum(1 for c in columns if '_' in c and c == c.lower()) + camel_case = sum(1 for c in columns if re.match(r'^[a-z]+([A-Z][a-z]+)+$', c)) + 
pascal_case = sum(1 for c in columns if re.match(r'^([A-Z][a-z]+)+$', c)) + + total = len(columns) + + if snake_case > total * 0.7: + return Artifact( + artifact_type="naming_convention", + column="schema", + evidence=f"snake_case naming ({snake_case}/{total} columns)", + confidence=0.80, + inferred_operation="PYTHON_OR_SQL_ORIGIN", + details={"convention": "snake_case", "ratio": snake_case/total} + ) + elif camel_case > total * 0.5: + return Artifact( + artifact_type="naming_convention", + column="schema", + evidence=f"camelCase naming ({camel_case}/{total} columns)", + confidence=0.80, + inferred_operation="JAVASCRIPT_OR_JAVA_ORIGIN", + details={"convention": "camelCase", "ratio": camel_case/total} + ) + elif pascal_case > total * 0.5: + return Artifact( + artifact_type="naming_convention", + column="schema", + evidence=f"PascalCase naming ({pascal_case}/{total} columns)", + confidence=0.80, + inferred_operation="DOTNET_OR_JAVA_ORIGIN", + details={"convention": "PascalCase", "ratio": pascal_case/total} + ) + + return None + + def _detect_framework_fingerprints(self, columns: List[str]) -> Optional[Artifact]: + """Detect framework-specific column patterns.""" + col_lower = [c.lower() for c in columns] + + # Django fingerprints + if 'id' in col_lower and 'created_at' in col_lower: + return Artifact( + artifact_type="framework_fingerprint", + column="schema", + evidence="Django/Rails-style auto columns (id, created_at)", + confidence=0.65, + inferred_operation="ORM_GENERATED_SCHEMA", + details={"framework_hints": ["django", "rails", "sqlalchemy"]} + ) + + # Pandas export fingerprints + if 'unnamed: 0' in col_lower or any('unnamed:' in c for c in col_lower): + return Artifact( + artifact_type="framework_fingerprint", + column="schema", + evidence="Pandas index column artifact (Unnamed: 0)", + confidence=0.90, + inferred_operation="PANDAS_CSV_EXPORT", + details={"framework": "pandas"} + ) + + # MongoDB fingerprints + if '_id' in col_lower: + return Artifact( + 
artifact_type="framework_fingerprint", + column="schema", + evidence="MongoDB _id column present", + confidence=0.85, + inferred_operation="MONGODB_EXPORT", + details={"framework": "mongodb"} + ) + + return None + + def _detect_mixed_conventions(self, columns: List[str]) -> Optional[Artifact]: + """Detect mixed naming conventions suggesting merged sources.""" + snake_case = sum(1 for c in columns if '_' in c and c == c.lower()) + camel_case = sum(1 for c in columns if re.match(r'^[a-z]+([A-Z][a-z]+)+$', c)) + + total = len(columns) + + # Both conventions present significantly + if snake_case > total * 0.2 and camel_case > total * 0.2: + return Artifact( + artifact_type="mixed_conventions", + column="schema", + evidence=f"Mixed naming: {snake_case} snake_case, {camel_case} camelCase", + confidence=0.75, + inferred_operation="MERGED_SOURCES", + details={ + "snake_case_count": snake_case, + "camel_case_count": camel_case + } + ) + + return None diff --git a/cascade/forensics/fingerprints.py b/cascade/forensics/fingerprints.py new file mode 100644 index 0000000000000000000000000000000000000000..1f390c2bd38f025a57418b10fb195d617d0917bc --- /dev/null +++ b/cascade/forensics/fingerprints.py @@ -0,0 +1,328 @@ +""" +CASCADE Forensics - Technology Fingerprinting + +Map detected artifacts to likely technologies and tools. +The artifacts are evidence. This module is the detective. 
+""" + +from dataclasses import dataclass, field +from typing import List, Dict, Any, Set +from collections import defaultdict + + +@dataclass +class Fingerprint: + """A technology fingerprint - evidence pointing to specific tools.""" + technology: str + category: str # database, framework, language, tool + confidence: float + evidence: List[str] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + return { + "technology": self.technology, + "category": self.category, + "confidence": self.confidence, + "evidence": self.evidence, + } + + +class TechFingerprinter: + """ + Map artifact patterns to likely technologies. + + This is pattern matching - certain artifact combinations + are strong indicators of specific tools. + """ + + # Artifact patterns -> technology mappings + PATTERNS = { + # Databases + "MONGODB_EXPORT": { + "technology": "MongoDB", + "category": "database", + "weight": 0.9, + }, + "ORM_GENERATED_SCHEMA": { + "technology": "ORM (Django/Rails/SQLAlchemy)", + "category": "framework", + "weight": 0.7, + }, + "PANDAS_CSV_EXPORT": { + "technology": "Pandas", + "category": "tool", + "weight": 0.95, + }, + + # Processing tools + "LOWERCASE_NORMALIZATION": { + "technology": "Text Preprocessing", + "category": "processing", + "weight": 0.6, + }, + "WHITESPACE_TRIM": { + "technology": "String Cleaning", + "category": "processing", + "weight": 0.5, + }, + + # Batch processing + "BATCH_HOURLY": { + "technology": "Scheduled Batch Job (hourly)", + "category": "infrastructure", + "weight": 0.8, + }, + "BATCH_15MIN": { + "technology": "Scheduled Batch Job (15min)", + "category": "infrastructure", + "weight": 0.8, + }, + "BATCH_BURST_PROCESSING": { + "technology": "Event-Driven Batch Processing", + "category": "infrastructure", + "weight": 0.7, + }, + "SCHEDULED_JOB": { + "technology": "Cron/Scheduler", + "category": "infrastructure", + "weight": 0.75, + }, + + # ID generation + "UUID_GENERATION_V4": { + "technology": "Cryptographic UUID Generator", 
+ "category": "tool", + "weight": 0.8, + }, + "UUID_GENERATION_V1": { + "technology": "Time-based UUID (leaks timestamp + MAC)", + "category": "tool", + "weight": 0.85, + }, + "DETERMINISTIC_ID_GENERATION_SHA256": { + "technology": "Content-Addressed Storage", + "category": "architecture", + "weight": 0.8, + }, + "DETERMINISTIC_ID_GENERATION_MD5": { + "technology": "MD5 Hash IDs (legacy system)", + "category": "architecture", + "weight": 0.8, + }, + + # Data quality + "FILTERING_OR_DELETION": { + "technology": "Record Filtering/Deletion Pipeline", + "category": "processing", + "weight": 0.7, + }, + "CHARSET_CONVERSION_ERROR": { + "technology": "Encoding Mismatch (Latin-1 vs UTF-8)", + "category": "bug", + "weight": 0.85, + }, + + # Languages/frameworks + "PYTHON_OR_SQL_ORIGIN": { + "technology": "Python or SQL", + "category": "language", + "weight": 0.6, + }, + "JAVASCRIPT_OR_JAVA_ORIGIN": { + "technology": "JavaScript or Java", + "category": "language", + "weight": 0.6, + }, + + # Source merging + "MERGED_SOURCES": { + "technology": "Multi-Source Data Integration", + "category": "architecture", + "weight": 0.8, + }, + "MULTI_SOURCE_MERGE": { + "technology": "Multi-Source Data Integration", + "category": "architecture", + "weight": 0.85, + }, + } + + # Compound patterns - combinations that strengthen identification + COMPOUND_PATTERNS = [ + { + "requires": ["PANDAS_CSV_EXPORT", "PYTHON_OR_SQL_ORIGIN"], + "suggests": Fingerprint("Pandas Data Pipeline", "tool", 0.95), + }, + { + "requires": ["MONGODB_EXPORT", "JAVASCRIPT_OR_JAVA_ORIGIN"], + "suggests": Fingerprint("Node.js + MongoDB Stack", "stack", 0.85), + }, + { + "requires": ["ORM_GENERATED_SCHEMA", "BATCH_HOURLY"], + "suggests": Fingerprint("Django/Rails Batch Worker", "stack", 0.80), + }, + { + "requires": ["CHARSET_CONVERSION_ERROR", "MERGED_SOURCES"], + "suggests": Fingerprint("Legacy System Migration", "context", 0.85), + }, + { + "requires": ["UUID_GENERATION_V1", "BATCH_BURST_PROCESSING"], + "suggests": 
Fingerprint("Distributed System (pre-2015 design)", "architecture", 0.75), + }, + ] + + def __init__(self): + self.fingerprints: List[Fingerprint] = [] + + def analyze(self, artifacts: List['Artifact']) -> List[Fingerprint]: + """ + Analyze artifacts and return technology fingerprints. + + Args: + artifacts: List of detected artifacts + + Returns: + List of technology fingerprints sorted by confidence + """ + self.fingerprints = [] + + # Get all inferred operations + operations = set(a.inferred_operation for a in artifacts) + + # Match against patterns + tech_evidence = defaultdict(list) + tech_confidence = defaultdict(float) + tech_category = {} + + for op in operations: + # Direct pattern match + if op in self.PATTERNS: + pattern = self.PATTERNS[op] + tech = pattern["technology"] + tech_evidence[tech].append(op) + tech_confidence[tech] = max(tech_confidence[tech], pattern["weight"]) + tech_category[tech] = pattern["category"] + + # Partial match (for patterns with suffixes like SCHEDULED_JOB_24HR) + for pattern_name, pattern in self.PATTERNS.items(): + if op.startswith(pattern_name.split('_')[0] + '_'): + tech = pattern["technology"] + if tech not in tech_evidence or op not in tech_evidence[tech]: + tech_evidence[tech].append(op) + tech_confidence[tech] = max(tech_confidence[tech], pattern["weight"] * 0.9) + tech_category[tech] = pattern["category"] + + # Check compound patterns + for compound in self.COMPOUND_PATTERNS: + required = set(compound["requires"]) + if required.issubset(operations): + fp = compound["suggests"] + tech_evidence[fp.technology].extend(list(required)) + tech_confidence[fp.technology] = max(tech_confidence.get(fp.technology, 0), fp.confidence) + tech_category[fp.technology] = fp.category + + # Build fingerprint objects + for tech, evidence in tech_evidence.items(): + self.fingerprints.append(Fingerprint( + technology=tech, + category=tech_category.get(tech, "unknown"), + confidence=tech_confidence[tech], + evidence=list(set(evidence)), + )) 
+ + # Sort by confidence + self.fingerprints.sort(key=lambda f: f.confidence, reverse=True) + + return self.fingerprints + + def get_likely_stack(self) -> Dict[str, Any]: + """ + Synthesize fingerprints into a likely technology stack. + + Returns: + Dict describing the probable system architecture + """ + if not self.fingerprints: + return {"stack": "Unknown", "components": []} + + # Group by category + by_category = defaultdict(list) + for fp in self.fingerprints: + by_category[fp.category].append(fp) + + stack = { + "database": None, + "framework": None, + "language": None, + "processing": [], + "infrastructure": [], + "architecture_notes": [], + } + + # Pick highest confidence for single-value categories + for cat in ["database", "framework", "language"]: + if cat in by_category: + stack[cat] = by_category[cat][0].technology + + # Aggregate list categories + for cat in ["processing", "infrastructure"]: + if cat in by_category: + stack[cat] = [fp.technology for fp in by_category[cat]] + + # Architecture notes from high-confidence findings + if "architecture" in by_category: + stack["architecture_notes"] = [fp.technology for fp in by_category["architecture"]] + + # Bugs/issues + if "bug" in by_category: + stack["issues"] = [fp.technology for fp in by_category["bug"]] + + return stack + + def get_security_concerns(self) -> List[Dict[str, Any]]: + """ + Identify security-relevant findings. 
+ + Returns: + List of security concerns derived from fingerprints + """ + concerns = [] + + for fp in self.fingerprints: + # UUID v1 leaks info + if "UUID" in fp.technology and "V1" in fp.technology: + concerns.append({ + "severity": "medium", + "issue": "UUID v1 leaks timestamp and MAC address", + "evidence": fp.evidence, + "recommendation": "Use UUID v4 for privacy", + }) + + # MD5 for IDs + if "MD5" in fp.technology: + concerns.append({ + "severity": "low", + "issue": "MD5 used for ID generation (collision risk)", + "evidence": fp.evidence, + "recommendation": "Consider SHA-256 for content addressing", + }) + + # Encoding errors = data loss + if "Encoding" in fp.technology or "charset" in fp.technology.lower(): + concerns.append({ + "severity": "medium", + "issue": "Character encoding errors indicate data corruption", + "evidence": fp.evidence, + "recommendation": "Audit data pipeline for charset handling", + }) + + # Legacy patterns + if "legacy" in fp.technology.lower() or "pre-2015" in fp.technology.lower(): + concerns.append({ + "severity": "info", + "issue": "Legacy system patterns detected", + "evidence": fp.evidence, + "recommendation": "Review for technical debt", + }) + + return concerns diff --git a/cascade/genesis.py b/cascade/genesis.py new file mode 100644 index 0000000000000000000000000000000000000000..82b943fe3e36e58809034a59197e56f343c060a4 --- /dev/null +++ b/cascade/genesis.py @@ -0,0 +1,200 @@ +""" +CASCADE Genesis - The origin node of the neural internetwork. + +Every chain begins here. Systems link to genesis (or to any +descendant of genesis) to join the lattice. + +The chain IS the registry. No separate discovery needed. 
+ +Usage: + # Create genesis (done once, published to well-known location) + genesis = create_genesis() + + # Any system joins by linking to genesis + my_chain.link_external(genesis.merkle_root) + + # Or by linking to any existing node in the lattice + my_chain.link_external(some_other_chain.merkle_root) + + # The lattice grows. Discovery = reading the chain. +""" + +import hashlib +import json +import time +from pathlib import Path +from typing import Optional, Dict, Any + +from cascade.core.provenance import ProvenanceChain, ProvenanceRecord + + +# Well-known genesis identifiers +GENESIS_SESSION_ID = "genesis_0" +GENESIS_MODEL_ID = "cascade_genesis" +GENESIS_INPUT = "In the beginning was the hash, and the hash was with the chain, and the hash was the chain." + + +def create_genesis() -> ProvenanceChain: + """ + Create the genesis chain - origin of the neural internetwork. + + This is deterministic. Anyone running this gets the same genesis. + That's the point - it's the Schelling point for the lattice. + """ + # Deterministic input hash + input_hash = hashlib.sha256(GENESIS_INPUT.encode()).hexdigest()[:16] + + # Deterministic model hash (hash of the genesis concept itself) + model_hash = hashlib.sha256(b"cascade_neural_internetwork_v1").hexdigest()[:16] + + chain = ProvenanceChain( + session_id=GENESIS_SESSION_ID, + model_id=GENESIS_MODEL_ID, + model_hash=model_hash, + input_hash=input_hash, + ) + + # The genesis record - the first node + # Its parent is itself (bootstrap) + genesis_record = ProvenanceRecord( + layer_name="genesis", + layer_idx=0, + state_hash=input_hash, # Self-referential + parent_hashes=[input_hash], # Points to itself + params_hash=model_hash, + shape=[1], + dtype="genesis", + stats={"created": time.time()}, + execution_order=0, + ) + + chain.add_record(genesis_record) + chain.finalize() + + return chain + + +def get_genesis_root() -> str: + """ + Get the genesis merkle root. + + This is a constant - the Schelling point. 
+ Any system can compute it and know they're linking to the same origin. + """ + return create_genesis().merkle_root + + +def save_genesis(path: Path) -> str: + """ + Save genesis chain to file. + + This file can be published to a well-known location + (HuggingFace dataset, IPFS, etc.) + """ + genesis = create_genesis() + + with open(path, 'w') as f: + json.dump(genesis.to_dict(), f, indent=2) + + return genesis.merkle_root + + +def load_genesis(path: Path) -> ProvenanceChain: + """Load genesis from file and verify it's authentic.""" + with open(path, 'r') as f: + data = json.load(f) + + chain = ProvenanceChain.from_dict(data) + + # Verify this is actually genesis + expected_root = get_genesis_root() + if chain.merkle_root != expected_root: + raise ValueError( + f"Invalid genesis: root {chain.merkle_root} != expected {expected_root}" + ) + + return chain + + +def link_to_genesis(chain: ProvenanceChain) -> None: + """ + Link a chain to genesis, joining the neural internetwork. + + This is the simplest way to join - link directly to the origin. + Alternatively, link to any other chain that traces back to genesis. + """ + chain.link_external(get_genesis_root(), source_id="genesis") + + +def verify_lineage_to_genesis(chain: ProvenanceChain, known_chains: Dict[str, ProvenanceChain]) -> bool: + """ + Verify that a chain traces back to genesis through external_roots. + + Args: + chain: The chain to verify + known_chains: Dict mapping merkle_root -> chain for lookup + + Returns: + True if chain traces to genesis, False otherwise + """ + genesis_root = get_genesis_root() + visited = set() + + def trace(root: str) -> bool: + if root in visited: + return False + visited.add(root) + + # Found genesis! 
+ if root == genesis_root: + return True + + # Look up this chain + if root not in known_chains: + return False # Can't verify - chain not known + + c = known_chains[root] + + # Check if any external root leads to genesis + for ext_root in c.external_roots: + if trace(ext_root): + return True + + return False + + # Start from the chain's own root + return trace(chain.merkle_root) or any(trace(r) for r in chain.external_roots) + + +# ============================================================================= +# CLI for genesis operations +# ============================================================================= + +if __name__ == "__main__": + import sys + + genesis = create_genesis() + + print("=" * 60) + print("CASCADE GENESIS") + print("=" * 60) + print(f"Merkle Root: {genesis.merkle_root}") + print(f"Session ID: {genesis.session_id}") + print(f"Model ID: {genesis.model_id}") + print(f"Input Hash: {genesis.input_hash}") + print("=" * 60) + print() + print("This is the origin of the neural internetwork.") + print("Any system can link to this root to join the lattice.") + print() + print("To join:") + print(" from cascade.genesis import get_genesis_root") + print(" my_chain.link_external(get_genesis_root())") + print() + + # Save if requested + if len(sys.argv) > 1 and sys.argv[1] == "--save": + out_path = Path(sys.argv[2]) if len(sys.argv) > 2 else Path("genesis.json") + root = save_genesis(out_path) + print(f"Genesis saved to: {out_path}") + print(f"Root: {root}") diff --git a/cascade/hold/__init__.py b/cascade/hold/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..83f210e7108a2181682afe3c4de285ab827bc5ee --- /dev/null +++ b/cascade/hold/__init__.py @@ -0,0 +1,82 @@ +""" +╔═══════════════════════════════════════════════════════════════════════════════╗ +║ ║ +║ ██╗ ██╗ ██████╗ ██╗ ██████╗ ║ +║ ██║ ██║██╔═══██╗██║ ██╔══██╗ ║ +║ ███████║██║ ██║██║ ██║ ██║ ║ +║ ██╔══██║██║ ██║██║ ██║ ██║ ║ +║ ██║ ██║╚██████╔╝███████╗██████╔╝ ║ +║ 
╚═╝ ╚═╝ ╚═════╝ ╚══════╝╚═════╝ ║ +║ ║ +║ Inference-Level Halt Protocol for CASCADE-LATTICE ║ +║ ║ +║ "Pause the machine. See what it sees. Choose what it chooses." ║ +║ ║ +╚═══════════════════════════════════════════════════════════════════════════════╝ + +HOLD is MODEL-AGNOSTIC. Works with ANY framework: + - PyTorch, JAX, TensorFlow, scikit-learn + - Hugging Face, OpenAI API, Anthropic API + - Stable Baselines3, RLlib, custom RL + - Any function that outputs probabilities + +USAGE: + >>> from cascade.hold import Hold + >>> + >>> # Your model (any framework) + >>> probs = your_model.predict(obs) + >>> + >>> # HOLD at decision point + >>> hold = Hold.get() + >>> resolution = hold.yield_point( + ... action_probs=probs, + ... value=value_estimate, + ... observation=obs, + ... brain_id="my_model", + ... # Optional informational wealth: + ... action_labels=["up", "down", "left", "right"], + ... latent=model.get_latent(), + ... attention=model.get_attention(), + ... features=model.get_features(), + ... imagination=model.imagine_futures(), + ... 
) + >>> + >>> # Use resolved action + >>> action = resolution.action + >>> was_override = resolution.was_override + +CLI: + $ cascade hold # Start HOLD interface + $ cascade hold-status # Show HOLD system status +""" + +# Primitives - the core API +from cascade.hold.primitives import ( + HoldState, + HoldPoint, + HoldResolution, + Hold, + HoldAwareMixin, +) + +# Session Layer - arcade-style history and time travel +from cascade.hold.session import ( + InferenceStep, + HoldSession, + ArcadeFeedback, + CausationHold, +) + +__all__ = [ + # Primitives + "HoldState", + "HoldPoint", + "HoldResolution", + "Hold", + "HoldAwareMixin", + # Session + "InferenceStep", + "HoldSession", + "ArcadeFeedback", + "CausationHold", +] diff --git a/cascade/hold/primitives.py b/cascade/hold/primitives.py new file mode 100644 index 0000000000000000000000000000000000000000..d2165f4b836665481e1638c5872649b530e502d8 --- /dev/null +++ b/cascade/hold/primitives.py @@ -0,0 +1,673 @@ +""" +HOLD Primitives - Core Data Structures and Singleton +═══════════════════════════════════════════════════════════ + +The primitive layer of HOLD: +- HoldPoint: A frozen moment in inference +- HoldResolution: The outcome of a hold +- Hold: Singleton system managing inference-level halts + +HOLD is a CASCADE-LATTICE primitive. +No cascade = No HOLD. 
+""" + +import time +import hashlib +import threading +from typing import Dict, Any, Optional, Callable, List +from dataclasses import dataclass, field +from enum import Enum +import numpy as np + +# CASCADE-LATTICE is REQUIRED +try: + from cascade import sdk_observe + from cascade.core.event import CausationLink + from cascade.core.graph import CausationGraph + HAS_CASCADE = True +except ImportError: + HAS_CASCADE = False + # Stubs for when imported standalone (testing) + def sdk_observe(*args, **kwargs): pass + class CausationLink: + def __init__(self, **kwargs): pass + class CausationGraph: + def add_link(self, link): pass + + +class HoldState(Enum): + """State of a hold point.""" + PENDING = "pending" # Waiting for resolution + ACCEPTED = "accepted" # AI choice was accepted + OVERRIDDEN = "overridden" # Human override + TIMEOUT = "timeout" # Timed out, fell back to AI + CANCELLED = "cancelled" # Hold was cancelled + + +def _sanitize(data: Any) -> Any: + """Recursively convert numpy types to python types.""" + if isinstance(data, dict): + return {k: _sanitize(v) for k, v in data.items()} + elif isinstance(data, (list, tuple)): + return [_sanitize(x) for x in data] + elif isinstance(data, np.generic): + return data.item() + return data + + +@dataclass +class HoldPoint: + """ + A decision point where inference yields for potential human intervention. + + This is the "freeze frame" - the moment before commitment. + The decision matrix is exposed, the merkle chain awaits. + + INFORMATIONAL WEALTH - everything a human needs to understand the decision: + - action_labels: What each action means ("FORWARD", "ATTACK", etc.) 
+ - latent: The model's internal representation (for inspection) + - attention: What the model is attending to + - features: Extracted feature activations + - imagination: Per-action trajectory predictions and expected values + - logits: Raw logits before softmax (for temperature analysis) + - reasoning: Text explanations if available + """ + # Decision matrix + action_probs: np.ndarray # The probability distribution + value: float # Predicted value + + # Context + observation: Dict[str, Any] # What the brain saw + brain_id: str # Which brain is holding + + # === INFORMATIONAL WEALTH === + + # Action labels - CRITICAL for human understanding + action_labels: Optional[List[str]] = None # ["NOOP", "FORWARD", "BACK", ...] + + # Internal state + latent: Optional[np.ndarray] = None # Latent activations (any shape) + attention: Optional[Dict[str, float]] = None # {"position": 0.7, "health": 0.3, ...} + features: Optional[Dict[str, float]] = None # {"spatial_attn": 0.8, "danger": 0.2, ...} + + # Per-action deep data + imagination: Optional[Dict[int, Dict]] = None # {0: {"trajectory": [...], "expected_value": 0.5}, ...} + + # Logits (pre-softmax) + logits: Optional[np.ndarray] = None # Raw logits for each action + + # Reasoning chain (if model provides explanations) + reasoning: Optional[List[str]] = None # ["High reward expected", "Low risk path", ...] 
+ + # World model predictions (if available) + world_prediction: Optional[Dict[str, Any]] = None # {"pos_delta": [1,0,0], "health_delta": -2, ...} + + # === END WEALTH === + + # Identity + id: str = field(default_factory=lambda: hashlib.sha256(str(time.time()).encode()).hexdigest()[:16]) + timestamp: float = field(default_factory=time.time) + + # Merkle linkage + parent_merkle: Optional[str] = None # Previous hold point + merkle_root: Optional[str] = None # Computed on creation + + # State + state: HoldState = HoldState.PENDING + + def __post_init__(self): + """Compute merkle root on creation.""" + if self.merkle_root is None: + data = f"{self.id}:{self.brain_id}:{self.action_probs.tobytes().hex()}:{self.timestamp}" + if self.parent_merkle: + data = f"{self.parent_merkle}:{data}" + self.merkle_root = hashlib.sha256(data.encode()).hexdigest()[:16] + + @property + def ai_choice(self) -> int: + """What the AI would choose.""" + return int(np.argmax(self.action_probs)) + + @property + def ai_confidence(self) -> float: + """Confidence in AI's top choice.""" + return float(np.max(self.action_probs)) + + def to_dict(self) -> Dict[str, Any]: + """Serialize for CASCADE observation - includes full informational wealth.""" + d = { + 'id': self.id, + 'brain_id': self.brain_id, + 'action_probs': self.action_probs.tolist(), + 'ai_choice': self.ai_choice, + 'ai_confidence': self.ai_confidence, + 'value': self.value, + 'timestamp': self.timestamp, + 'merkle_root': self.merkle_root, + 'parent_merkle': self.parent_merkle, + 'state': self.state.value, + 'observation': self.observation, + } + + # Include all available wealth + if self.action_labels is not None: + d['action_labels'] = self.action_labels + if self.latent is not None: + d['latent'] = self.latent.tolist() if hasattr(self.latent, 'tolist') else self.latent + if self.attention is not None: + d['attention'] = self.attention + if self.features is not None: + d['features'] = self.features + if self.imagination is not None: + 
d['imagination'] = self.imagination + if self.logits is not None: + d['logits'] = self.logits.tolist() if hasattr(self.logits, 'tolist') else self.logits + if self.reasoning is not None: + d['reasoning'] = self.reasoning + if self.world_prediction is not None: + d['world_prediction'] = self.world_prediction + + return _sanitize(d) + + +@dataclass +class HoldResolution: + """ + The resolution of a hold point. + + Either the human accepted, overrode, or it timed out. + Links back to the hold point, forming a provenance chain. + """ + hold_point: HoldPoint # The hold that was resolved + action: int # Final action taken + + # Resolution details + was_override: bool # True if human overrode AI + override_source: Optional[str] = None # Who/what overrode ("human", "policy", etc.) + + # Timing + hold_duration: float = 0.0 # How long was held + timestamp: float = field(default_factory=time.time) + + # Merkle linkage + merkle_root: Optional[str] = None + + def __post_init__(self): + """Compute merkle root.""" + if self.merkle_root is None: + data = f"{self.hold_point.merkle_root}:{self.action}:{self.was_override}:{self.timestamp}" + self.merkle_root = hashlib.sha256(data.encode()).hexdigest()[:16] + + def to_dict(self) -> Dict[str, Any]: + """Serialize for CASCADE observation.""" + d = { + 'hold_id': self.hold_point.id, + 'hold_merkle': self.hold_point.merkle_root, + 'action': self.action, + 'ai_choice': self.hold_point.ai_choice, + 'was_override': self.was_override, + 'override_source': self.override_source, + 'hold_duration': self.hold_duration, + 'merkle_root': self.merkle_root, + 'timestamp': self.timestamp, + } + return _sanitize(d) + + +class Hold: + """ + The HOLD system - manages inference-level halts. + + Singleton pattern - one Hold system per process. + + Usage: + hold = Hold.get() + + # Register listeners (for UI, visualization, etc.) 
+ hold.register_listener(my_callback) + + # From within a brain's forward() method: + resolution = hold.yield_point( + action_probs=probs, + value=value, + observation=obs, + brain_id="brain_001" + ) + # Blocks until resolution! + + # From UI/control thread: + hold.accept() # or + hold.override(action=3, source="human") + """ + + _instance = None + _lock = threading.Lock() + + def __new__(cls): + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self): + if self._initialized: + return + + # State + self._current_hold: Optional[HoldPoint] = None + self._resolution_event = threading.Event() + self._resolution: Optional[HoldResolution] = None + + # Chain + self._last_merkle: Optional[str] = None + self._hold_count = 0 + self._override_count = 0 + + # Callbacks - interfaces register here to receive hold points + self._listeners: List[Callable[[HoldPoint], None]] = [] + + # Settings + self.timeout: float = 30.0 # Default timeout (seconds) + self.auto_accept: bool = False # If True, don't block, just observe + + # CASCADE graph for this session + self._causation_graph = CausationGraph() + + self._initialized = True + print("[HOLD] system initialized (cascade-lattice)") + + @classmethod + def get(cls) -> 'Hold': + """Get the singleton instance.""" + return cls() + + def register_listener(self, callback: Callable[[HoldPoint], None]): + """ + Register a listener for hold points. + + The callback receives HoldPoint when inference halts. + Use this to connect visualizations, UIs, etc. 
+ """ + self._listeners.append(callback) + print(f"[REGISTER] Registered HOLD listener: {callback.__name__ if hasattr(callback, '__name__') else callback}") + + def unregister_listener(self, callback: Callable): + """Remove a listener.""" + if callback in self._listeners: + self._listeners.remove(callback) + + def yield_point( + self, + action_probs: np.ndarray, + value: float, + observation: Dict[str, Any], + brain_id: str, + # === INFORMATIONAL WEALTH === + action_labels: Optional[List[str]] = None, + latent: Optional[np.ndarray] = None, + attention: Optional[Dict[str, float]] = None, + features: Optional[Dict[str, float]] = None, + imagination: Optional[Dict[int, Dict]] = None, + logits: Optional[np.ndarray] = None, + reasoning: Optional[List[str]] = None, + world_prediction: Optional[Dict[str, Any]] = None, + # === END WEALTH === + blocking: bool = True, + ) -> HoldResolution: + """ + Create a hold point and yield for resolution. + + This is called from within a brain's forward() method. + Blocks until resolved (or timeout). + + Args: + action_probs: The decision matrix (probability distribution) + value: Predicted value + observation: What the brain observed + brain_id: Identifier for the brain + + INFORMATIONAL WEALTH (all optional, but improves human understanding): + action_labels: Names for each action ["FORWARD", "BACK", "LEFT", ...] + latent: Model's latent state/activations + attention: Attention weights {"position": 0.7, "health": 0.3} + features: Feature activations {"spatial": 0.8, "danger": 0.2} + imagination: Per-action predictions {0: {"trajectory": [...], "expected_value": 0.5}} + logits: Raw pre-softmax logits + reasoning: Text explanations ["High reward expected", ...] 
+ world_prediction: World model predictions {"pos_delta": [1,0,0]} + + blocking: If False, returns immediately with AI choice + + Returns: + HoldResolution with the final action + """ + # Create hold point with full wealth + hold = HoldPoint( + action_probs=action_probs, + value=value, + observation=observation, + brain_id=brain_id, + action_labels=action_labels, + latent=latent, + attention=attention, + features=features, + imagination=imagination, + logits=logits, + reasoning=reasoning, + world_prediction=world_prediction, + parent_merkle=self._last_merkle, + ) + + # Observe the hold point in CASCADE + sdk_observe( + model_id=brain_id, + input_data=observation, + output_data={**hold.to_dict(), 'event_type': 'hold_point'}, + ) + + self._hold_count += 1 + + # Non-blocking mode - just observe and return AI choice + if not blocking or self.auto_accept: + resolution = HoldResolution( + hold_point=hold, + action=hold.ai_choice, + was_override=False, + hold_duration=0.0, + ) + self._observe_resolution(resolution) + return resolution + + # Set as current hold + self._current_hold = hold + self._resolution_event.clear() + self._resolution = None + + # Notify listeners + for listener in self._listeners: + try: + listener(hold) + except Exception as e: + print(f"⚠️ HOLD listener error: {e}") + + # Print hold info + print(f"\n{'═' * 50}") + print(f"🛑 HOLD #{self._hold_count}") + print(f" Merkle: {hold.merkle_root}") + ai_label = hold.action_labels[hold.ai_choice] if hold.action_labels else str(hold.ai_choice) + print(f" AI Choice: {ai_label} (confidence: {hold.ai_confidence:.2%})") + print(f" Value: {hold.value:.4f}") + + # Show probabilities with labels + if hold.action_labels: + prob_str = ', '.join(f'{hold.action_labels[i]}:{p:.2f}' for i, p in enumerate(hold.action_probs)) + else: + prob_str = ', '.join(f'{i}:{p:.2f}' for i, p in enumerate(hold.action_probs)) + print(f" Probabilities: {prob_str}") + + # Show available wealth + wealth = [] + if hold.latent is not None: 
wealth.append("latent") + if hold.attention is not None: wealth.append("attention") + if hold.features is not None: wealth.append("features") + if hold.imagination is not None: wealth.append("imagination") + if hold.reasoning is not None: wealth.append("reasoning") + if wealth: + print(f" Wealth: {', '.join(wealth)}") + + print(f" Waiting for resolution (timeout: {self.timeout}s)...") + print(f"{'═' * 50}") + + # Block until resolution or timeout + start_time = time.time() + resolved = self._resolution_event.wait(timeout=self.timeout) + hold_duration = time.time() - start_time + + if resolved and self._resolution: + resolution = self._resolution + resolution.hold_duration = hold_duration + else: + # Timeout - use AI choice + hold.state = HoldState.TIMEOUT + resolution = HoldResolution( + hold_point=hold, + action=hold.ai_choice, + was_override=False, + override_source="timeout", + hold_duration=hold_duration, + ) + print(f"[TIMEOUT] HOLD timeout - accepting AI choice: {hold.ai_choice}") + + # Observe resolution + self._observe_resolution(resolution) + + # Clear state + self._current_hold = None + self._resolution = None + + return resolution + + def resolve(self, action: int, source: str = "human"): + """ + Resolve the current hold with an action. + + Called by UI/interface when human makes a choice. + + Args: + action: The chosen action + source: Who resolved it ("human", "policy", etc.) 
+ """ + if self._current_hold is None: + print("[WARN] No active hold to resolve") + return + + hold = self._current_hold + was_override = (action != hold.ai_choice) + + if was_override: + hold.state = HoldState.OVERRIDDEN + self._override_count += 1 + else: + hold.state = HoldState.ACCEPTED + + self._resolution = HoldResolution( + hold_point=hold, + action=action, + was_override=was_override, + override_source=source if was_override else None, + ) + + print(f"[RESOLVE] HOLD resolved: action={action}, override={was_override}") + self._resolution_event.set() + + def accept(self): + """Accept AI's choice for current hold.""" + if self._current_hold: + self.resolve(self._current_hold.ai_choice, source="accept") + + def override(self, action: int, source: str = "human"): + """Override with a different action.""" + self.resolve(action, source) + + def cancel(self): + """Cancel current hold without resolution.""" + if self._current_hold: + self._current_hold.state = HoldState.CANCELLED + self._resolution = HoldResolution( + hold_point=self._current_hold, + action=self._current_hold.ai_choice, + was_override=False, + override_source="cancelled", + ) + self._resolution_event.set() + + def _observe_resolution(self, resolution: HoldResolution): + """Record resolution to CASCADE.""" + sdk_observe( + model_id=resolution.hold_point.brain_id, + input_data=resolution.hold_point.to_dict(), + output_data={**resolution.to_dict(), 'event_type': 'hold_resolution'}, + ) + + # Update chain + self._last_merkle = resolution.merkle_root + + # Add to causation graph + link = CausationLink( + from_event=resolution.hold_point.merkle_root, + to_event=resolution.merkle_root, + causation_type="hold_resolved", + strength=1.0 if resolution.was_override else 0.5, + explanation=f"Override: {resolution.was_override}, Action: {resolution.action}", + ) + self._causation_graph.add_link(link) + + @property + def current_hold(self) -> Optional[HoldPoint]: + """Get current active hold point (if any).""" + 
return self._current_hold + + @property + def stats(self) -> Dict[str, Any]: + """Get hold statistics.""" + return { + 'total_holds': self._hold_count, + 'overrides': self._override_count, + 'override_rate': self._override_count / max(self._hold_count, 1), + 'last_merkle': self._last_merkle, + } + + +class HoldAwareMixin: + """ + Mixin for brains that support HOLD. + + Add this to your Brain class to enable inference-level halts. + + Usage: + class MyBrain(HoldAwareMixin, BaseBrain): + def forward(self, inputs): + # Your inference code + return {"action_probs": probs, "value": value} + + brain = MyBrain() + brain.enable_hold() + + # Now forward_with_hold() will pause for human input + output = brain.forward_with_hold(inputs) + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._hold_system = Hold.get() + self._hold_enabled = True + self._brain_id = getattr(self, 'id', hashlib.sha256(str(id(self)).encode()).hexdigest()[:16]) + + def forward_with_hold( + self, + inputs: Dict[str, Any], + blocking: bool = True, + ) -> Dict[str, Any]: + """ + Forward pass with HOLD support. + + Call this instead of forward() to enable hold points. + """ + # Get decision matrix from normal forward + output = self.forward(inputs) + + if not self._hold_enabled: + return output + + action_probs = output.get('action_probs', None) + if action_probs is None: + return output + + # Get imagination if available (DreamerBrain, etc.) 
+ imagined = None + if hasattr(self, 'imagine'): + try: + imagined = self.imagine(horizon=15) + except: + pass + + # Yield to hold system + resolution = self._hold_system.yield_point( + action_probs=np.array(action_probs), + value=float(output.get('value', 0.0)), + observation=inputs, + brain_id=self._brain_id, + imagined_futures=imagined, + blocking=blocking, + ) + + # Update output with resolved action + output['action'] = resolution.action + output['hold_resolution'] = resolution.to_dict() + output['was_override'] = resolution.was_override + + return output + + def enable_hold(self): + """Enable HOLD for this brain.""" + self._hold_enabled = True + + def disable_hold(self): + """Disable HOLD (normal inference).""" + self._hold_enabled = False + + +# Demo +def _demo_hold(): + """Demonstrate HOLD system.""" + print("=" * 60) + print("HOLD SYSTEM DEMO") + print("=" * 60) + + # Get hold system + hold = Hold.get() + hold.timeout = 10.0 + + def on_hold(point: HoldPoint): + print(f"\n🔔 Listener received hold: {point.id}") + + hold.register_listener(on_hold) + + def brain_loop(): + for step in range(3): + probs = np.random.dirichlet(np.ones(8)) + resolution = hold.yield_point( + action_probs=probs, + value=np.random.random(), + observation={'step': step}, + brain_id='demo_brain', + ) + print(f"Brain received: action={resolution.action}, override={resolution.was_override}") + + def human_input(): + for i in range(3): + time.sleep(2) + if hold.current_hold: + if i % 2 == 0: + hold.accept() + else: + hold.override(7, source="demo_human") + + brain_thread = threading.Thread(target=brain_loop) + human_thread = threading.Thread(target=human_input) + + brain_thread.start() + human_thread.start() + + brain_thread.join() + human_thread.join() + + print(f"\n{'=' * 60}") + print("SESSION STATS") + print(hold.stats) + + +if __name__ == "__main__": + _demo_hold() diff --git a/cascade/hold/session.py b/cascade/hold/session.py new file mode 100644 index 
# --- diff residue: patch header for cascade/hold/session.py (continued) ---
# 0000000000000000000000000000000000000000..a4eb42a4725ea42bcbc9ab964cb07df656e2180d
# --- /dev/null
# +++ b/cascade/hold/session.py
# @@ -0,0 +1,707 @@
"""
HOLD Session - Arcade-Style Inference Interception
══════════════════════════════════════════════════════════

"Pause the machine. See what it sees. Choose what it chooses."

The arcade layer of HOLD:
- CausationHold: Session management with history
- InferenceStep: Single crystallized moment
- Time travel via state snapshots
- Speed controls and combo tracking

Controls:
    SPACE - Accept model's choice, advance
    1-9   - Override with alternative
    ←/→   - Step back/forward through history
    +/-   - Speed up/slow down auto-advance
    P     - Pause/unpause auto-advance
    ESC   - Exit hold mode
"""

import numpy as np
import time
import json
import hashlib
import threading
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any, Callable, Tuple
from datetime import datetime
from pathlib import Path
from enum import Enum


class SessionState(Enum):
    """Lifecycle state of the hold session."""
    IDLE = "idle"            # Not holding anything
    PAUSED = "paused"        # Frozen, waiting for input
    STEPPING = "stepping"    # Auto-advancing at set speed
    REWINDING = "rewinding"  # Going backwards through history


@dataclass
class InferenceStep:
    """A single crystallized moment of inference."""
    # Identity of this step within the session.
    step_id: str
    step_index: int
    timestamp: float

    # What the model sees.
    input_context: Dict[str, Any]

    # What the model wants to do: candidate dicts of
    # {value, probability, metadata} plus the argmax summary.
    candidates: List[Dict[str, Any]]
    top_choice: Any
    top_probability: float

    # Internal state snapshot (for true rewind).
    hidden_state: Optional[np.ndarray] = None
    attention_weights: Optional[Dict[str, float]] = None

    # What actually happened at this step.
    chosen_value: Any = None
    was_override: bool = False
    override_by: str = "model"  # "model" or "human"

    # Provenance link into CASCADE.
    cascade_hash: Optional[str] = None

    # Private: full
state snapshot for true rewind + _state_snapshot: Optional[Dict[str, Any]] = field(default=None, repr=False) + + +@dataclass +class HoldSession: + """A complete hold session with history.""" + session_id: str + agent_id: str + started_at: float + + # All steps in order + steps: List[InferenceStep] = field(default_factory=list) + current_index: int = 0 + + # Arcade stats + total_steps: int = 0 + human_overrides: int = 0 + correct_predictions: int = 0 # Human guessed what model would do + combo: int = 0 + max_combo: int = 0 + + # Speed control (steps per second, 0 = manual only) + speed_level: int = 0 # 0=manual, 1=slow, 2=medium, 3=fast, 4=ludicrous + speed_map: Dict[int, float] = field(default_factory=lambda: { + 0: 0.0, # Manual + 1: 0.5, # 2 sec per step + 2: 1.0, # 1 sec per step + 3: 2.0, # 0.5 sec per step + 4: 10.0, # 0.1 sec per step (ludicrous speed) + }) + + # State + state: SessionState = SessionState.IDLE + + +@dataclass +class ArcadeFeedback: + """Visual/audio feedback cues.""" + message: str + intensity: float # 0-1, for glow/shake/etc + sound_cue: str # "accept", "override", "combo", "combo_break", "rewind" + color: Tuple[int, int, int] = (255, 255, 255) + + +class CausationHold: + """ + The arcade-layer hold system. Wraps any inference function. + + Features: + - Session management with full history + - True state restoration for time travel + - Speed controls (manual to ludicrous) + - Combo tracking and high scores + + Usage: + hold = CausationHold() + + # Start a session + hold.begin_session(agent_id="agent_123") + + # In inference loop: + for step in inference_steps: + choice, feedback = hold.capture( + input_context={"tokens": tokens}, + candidates=[{"value": "A", "probability": 0.8}, ...] + ) # Pauses here until user input! 
class CausationHold:
    """
    The arcade-layer hold system. Wraps any inference function.

    Features:
    - Session management with full history
    - True state restoration for time travel
    - Speed controls (manual to ludicrous)
    - Combo tracking and high scores

    Usage:
        hold = CausationHold()

        # Start a session
        hold.begin_session(agent_id="agent_123")

        # In inference loop:
        choice, feedback = hold.capture(
            input_context={"tokens": tokens},
            candidates=[{"value": "A", "probability": 0.8}, ...],
        )  # Pauses here until user input!

        # Time travel
        hold.rewind(steps=3)
        hold.branch_from(step_index=5, choice_index=2)

        stats = hold.end_session()
    """

    def __init__(self, cascade_bus=None):
        """
        Args:
            cascade_bus: Optional CASCADE event bus for provenance
        """
        self.bus = cascade_bus
        self.session: Optional["HoldSession"] = None
        self.callbacks: Dict[str, List[Callable]] = {
            'on_step': [],
            'on_override': [],
            'on_combo': [],
            'on_combo_break': [],
            'on_rewind': [],
            'on_state_restore': [],
        }

        # Thread coordination: capture() blocks on _input_event until a UI
        # thread calls input()/accept()/override(), which sets the event.
        self._lock = threading.Lock()
        self._input_event = threading.Event()
        self._user_choice: Optional[Any] = None

        # High scores (persisted across sessions)
        self.high_scores_path = Path("data/hold_high_scores.json")
        self.high_scores = self._load_high_scores()

    # ========================================================================
    # SESSION MANAGEMENT
    # ========================================================================

    def begin_session(self, agent_id: str) -> "HoldSession":
        """Start a new hold session for `agent_id` and mark it PAUSED."""
        session_id = f"hold_{agent_id}_{int(time.time()*1000)}"

        self.session = HoldSession(
            session_id=session_id,
            agent_id=agent_id,
            started_at=time.time(),
        )
        self.session.state = SessionState.PAUSED

        self._emit_cascade("hold_session_start", {
            "session_id": session_id,
            "agent_id": agent_id,
        })

        return self.session

    def end_session(self) -> Dict[str, Any]:
        """End the active session and return its stats (empty dict if none)."""
        if not self.session:
            return {}

        stats = {
            "session_id": self.session.session_id,
            "agent_id": self.session.agent_id,
            "duration": time.time() - self.session.started_at,
            "total_steps": self.session.total_steps,
            "human_overrides": self.session.human_overrides,
            "correct_predictions": self.session.correct_predictions,
            "max_combo": self.session.max_combo,
            "accuracy": (
                self.session.correct_predictions / max(1, self.session.total_steps)
            ),
        }

        # Check for high score (also bumps and persists total_sessions)
        self._check_high_score(stats)

        self._emit_cascade("hold_session_end", stats)

        self.session = None
        return stats

    # ========================================================================
    # CAPTURE & ADVANCE - WITH STATE SNAPSHOT FOR TRUE REWIND
    # ========================================================================

    def capture(
        self,
        input_context: Dict[str, Any],
        candidates: List[Dict[str, Any]],
        hidden_state: Optional[np.ndarray] = None,
        attention: Optional[Dict[str, float]] = None,
        state_snapshot: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Any, "ArcadeFeedback"]:
        """
        Capture an inference step. BLOCKS until user input or auto-advance.

        IMPORTANT: Pass state_snapshot for true rewind capability.
        This should be a complete snapshot of the model's internal state
        that can be restored to allow execution from this decision point
        with a different choice.

        This is NOT prediction - you will ACTUALLY execute the choice and
        see REAL outcomes. If you don't like them, rewind and try again.

        Args:
            input_context: What the model is looking at
            candidates: List of {value, probability, ...} options (non-empty)
            hidden_state: Optional internal state snapshot (deprecated, use state_snapshot)
            attention: Optional attention weights
            state_snapshot: Complete model state for TRUE rewind capability

        Returns:
            (chosen_value, feedback) - The value to use and arcade feedback

        Raises:
            ValueError: If `candidates` is empty (previously a bare IndexError).
        """
        if not candidates:
            # Fail loudly with a clear message rather than IndexError below.
            raise ValueError("capture() requires at least one candidate")

        if not self.session:
            # No session = passthrough, just return top choice
            return candidates[0]['value'], ArcadeFeedback("", 0, "")

        # Sort candidates by probability (descending) so index 0 is the top pick
        candidates = sorted(candidates, key=lambda x: x.get('probability', 0), reverse=True)
        top = candidates[0]

        # Merge hidden_state into state_snapshot if provided separately
        if state_snapshot is None and hidden_state is not None:
            state_snapshot = {'hidden_state': hidden_state}
        elif state_snapshot is not None and hidden_state is not None:
            state_snapshot['hidden_state'] = hidden_state

        # Create step - this is a CHECKPOINT for true rewind
        step = InferenceStep(
            step_id=f"step_{self.session.total_steps}",
            step_index=self.session.total_steps,
            timestamp=time.time(),
            input_context=input_context,
            candidates=candidates,
            top_choice=top['value'],
            top_probability=top.get('probability', 1.0),
            hidden_state=hidden_state,
            attention_weights=attention,
        )

        # Store state snapshot for TRUE rewind (not just history navigation)
        if state_snapshot is not None:
            step._state_snapshot = state_snapshot

        # Compute merkle hash for provenance (must run BEFORE append so the
        # previous list tail is this step's parent)
        step.cascade_hash = self._compute_step_hash(step)

        # Add to history
        with self._lock:
            self.session.steps.append(step)
            self.session.current_index = len(self.session.steps) - 1
            self.session.total_steps += 1

        # Emit step event
        self._emit_callback('on_step', step)
        self._emit_cascade("hold_step", {
            "step_index": step.step_index,
            "top_choice": str(top['value']),
            "top_prob": top.get('probability', 1.0),
            "num_candidates": len(candidates),
            "has_snapshot": state_snapshot is not None,
            "merkle": step.cascade_hash,
        })

        # Wait for input (blocks here in manual mode)
        choice, feedback = self._wait_for_input(step)

        # Record what happened
        step.chosen_value = choice
        step.was_override = (choice != top['value'])
        step.override_by = "human" if step.was_override else "model"

        if step.was_override:
            self.session.human_overrides += 1
            self._emit_callback('on_override', step, choice)

        return choice, feedback

    def _wait_for_input(self, step: "InferenceStep") -> Tuple[Any, "ArcadeFeedback"]:
        """Block for user input, or auto-accept on the speed timer.

        Returns (choice, feedback); in auto mode a timeout accepts the
        model's top choice.
        """
        # Manual mode = wait indefinitely
        if self.session.speed_level == 0:
            self._input_event.clear()
            self._input_event.wait()  # Blocks until input()

            choice = self._user_choice
            self._user_choice = None

        else:
            # Auto-advance mode: speed is steps/second
            speed = self.session.speed_map[self.session.speed_level]
            wait_time = 1.0 / speed if speed > 0 else float('inf')

            self._input_event.clear()
            got_input = self._input_event.wait(timeout=wait_time)

            if got_input and self._user_choice is not None:
                choice = self._user_choice
                self._user_choice = None
            else:
                # Auto-accepted
                choice = step.top_choice

        # Generate feedback
        return choice, self._generate_feedback(step, choice)

    def input(self, choice: Any):
        """
        Provide user input. Call from UI thread.

        Args:
            choice: The value to use (or index into candidates)
        """
        if not self.session:
            return

        current_step = self.session.steps[self.session.current_index]

        # Handle index input (1-9 keys); NOTE: ambiguous if candidate values
        # are themselves small ints
        if isinstance(choice, int) and 0 <= choice < len(current_step.candidates):
            choice = current_step.candidates[choice]['value']

        self._user_choice = choice
        self._input_event.set()

    def accept(self):
        """Accept model's top choice (SPACE key)."""
        if not self.session or not self.session.steps:
            return

        current = self.session.steps[self.session.current_index]
        self.input(current.top_choice)

    def override(self, index: int):
        """Override with candidate at index (1-9 keys)."""
        self.input(index)

    # ========================================================================
    # NAVIGATION (TIME TRAVEL) - TRUE STATE RESTORATION
    # ========================================================================

    def rewind(self, steps: int = 1, restore_state: bool = True) -> Optional["InferenceStep"]:
        """
        Go back in history with optional state restoration.

        This is NOT simulation - we actually restore the model's internal state
        to the snapshot taken at that decision point. From there, you can
        execute a different branch and see REAL outcomes.

        Args:
            steps: Number of steps to go back
            restore_state: If True, actually restore saved state to the model

        Returns:
            The step we rewound to, or None if no session / no movement
        """
        if not self.session:
            return None

        with self._lock:
            new_index = max(0, self.session.current_index - steps)
            if new_index != self.session.current_index:
                self.session.current_index = new_index
                self.session.state = SessionState.REWINDING

                step = self.session.steps[new_index]

                # TRUE STATE RESTORATION
                # FIX: the original only checked hidden_state, so steps that
                # carried only a _state_snapshot were never restored even
                # though _restore_state handles that case.
                if restore_state and (
                    step.hidden_state is not None
                    or step._state_snapshot is not None
                ):
                    self._restore_state(step)

                self._emit_callback('on_rewind', step, -steps)

                return step
            return None

    def _restore_state(self, step: "InferenceStep"):
        """
        Restore model state from a snapshot.

        This is the key that makes execution + rewind possible.
        The model's internal state is set back to exactly what it was
        at this decision point, allowing you to branch differently.
        """
        if step.hidden_state is None and step._state_snapshot is None:
            return

        # Emit state restoration event - hooked components can restore themselves
        self._emit_callback('on_state_restore', step)
        self._emit_cascade("state_restored", {
            "step_index": step.step_index,
            "merkle": step.cascade_hash,
            "had_hidden_state": step.hidden_state is not None,
            "had_snapshot": step._state_snapshot is not None,
        })

    def branch_from(self, step_index: int, choice_index: int) -> Optional["InferenceStep"]:
        """
        Rewind to a step and immediately choose a different branch.

        This is the core gameplay loop:
        1. Rewind to decision point
        2. Choose different option
        3. Execute and see what happens
        4. Repeat until satisfied

        Args:
            step_index: Which decision point to branch from
            choice_index: Which candidate to choose (0 = model's choice)

        Returns:
            The step after branching (with state restored), or None
        """
        step = self.jump_to(step_index)
        if step is None:
            return None

        # Restore state
        self._restore_state(step)

        # Set up the override (out-of-range index falls back to accept)
        if choice_index < len(step.candidates):
            self.override(choice_index)
        else:
            self.accept()

        return step

    def forward(self, steps: int = 1) -> Optional["InferenceStep"]:
        """Go forward in history (if we've rewound). Returns the step or None."""
        if not self.session:
            return None

        with self._lock:
            max_index = len(self.session.steps) - 1
            new_index = min(max_index, self.session.current_index + steps)
            if new_index != self.session.current_index:
                self.session.current_index = new_index

                step = self.session.steps[new_index]
                self._emit_callback('on_rewind', step, steps)

                return step
            return None

    def jump_to(self, index: int) -> Optional["InferenceStep"]:
        """Jump to a specific step (index clamped to valid range)."""
        if not self.session:
            return None

        with self._lock:
            index = max(0, min(index, len(self.session.steps) - 1))
            self.session.current_index = index
            return self.session.steps[index]

    # ========================================================================
    # SPEED CONTROL
    # ========================================================================

    def speed_up(self):
        """Increase auto-advance speed (caps at level 4)."""
        if self.session:
            self.session.speed_level = min(4, self.session.speed_level + 1)

    def speed_down(self):
        """Decrease auto-advance speed (floors at 0 = manual)."""
        if self.session:
            self.session.speed_level = max(0, self.session.speed_level - 1)

    def set_speed(self, level: int):
        """Set speed level directly (clamped to 0-4)."""
        if self.session:
            self.session.speed_level = max(0, min(4, level))

    def pause(self):
        """Pause auto-advance."""
        if self.session:
            self.session.state = SessionState.PAUSED

    def unpause(self):
        """Resume auto-advance."""
        if self.session:
            self.session.state = SessionState.STEPPING

    # ========================================================================
    # PROVENANCE HASHING
    # ========================================================================

    def _compute_step_hash(self, step: "InferenceStep") -> str:
        """
        Compute merkle hash for a step.

        This hash uniquely identifies this decision point and allows
        verification that rewind is restoring to the exact right state.
        Must be called BEFORE the step is appended: steps[-1] is the parent.
        """
        # Include parent hash for chain integrity
        parent_hash = ""
        if self.session and len(self.session.steps) > 0:
            prev_step = self.session.steps[-1]
            parent_hash = prev_step.cascade_hash or ""

        content = json.dumps({
            'step_index': step.step_index,
            'timestamp': step.timestamp,
            'top_choice': str(step.top_choice),
            'top_prob': step.top_probability,
            'num_candidates': len(step.candidates),
            'parent_hash': parent_hash,
        }, sort_keys=True)

        return hashlib.sha256(content.encode()).hexdigest()[:16]

    # ========================================================================
    # ARCADE FEEDBACK
    # ========================================================================

    def _generate_feedback(self, step: "InferenceStep", choice: Any) -> "ArcadeFeedback":
        """Generate arcade-style feedback for a resolved step."""

        is_override = (choice != step.top_choice)

        if is_override:
            # Combo break!
            if self.session.combo > 0:
                self._emit_callback('on_combo_break', self.session.combo)

            self.session.combo = 0

            return ArcadeFeedback(
                message="OVERRIDE",
                intensity=0.8,
                sound_cue="override",
                color=(255, 165, 0),  # Orange
            )

        else:
            # Accepted model choice
            self.session.combo += 1
            self.session.max_combo = max(self.session.max_combo, self.session.combo)

            # Combo milestones
            if self.session.combo in [10, 25, 50, 100]:
                self._emit_callback('on_combo', self.session.combo)
                return ArcadeFeedback(
                    message=f"COMBO x{self.session.combo}!",
                    intensity=1.0,
                    sound_cue="combo",
                    color=(0, 255, 255),  # Cyan
                )

            # Regular accept; intensity grows slowly with the combo
            return ArcadeFeedback(
                message="",
                intensity=0.3 + min(0.5, self.session.combo * 0.02),
                sound_cue="accept",
                color=(0, 255, 0),  # Green
            )

    # ========================================================================
    # CALLBACKS
    # ========================================================================

    def on(self, event: str, callback: Callable):
        """Register callback for events (unknown event names are ignored)."""
        if event in self.callbacks:
            self.callbacks[event].append(callback)

    def _emit_callback(self, event: str, *args):
        """Emit event to callbacks; a failing callback never breaks the loop."""
        for cb in self.callbacks.get(event, []):
            try:
                cb(*args)
            except Exception as e:
                print(f"Callback error: {e}")

    # ========================================================================
    # CASCADE PROVENANCE
    # ========================================================================

    def _emit_cascade(self, event_type: str, data: Dict[str, Any]):
        """Emit event to CASCADE bus if available (best-effort, never raises)."""
        if self.bus:
            try:
                self.bus.emit(event_type, {
                    **data,
                    "source": "causation_hold",
                    "timestamp": time.time(),
                })
            except Exception:
                pass

    # ========================================================================
    # HIGH SCORES
    # ========================================================================

    def _load_high_scores(self) -> Dict[str, Any]:
        """Load high scores from disk; fall back to zeroed defaults."""
        if self.high_scores_path.exists():
            try:
                return json.loads(self.high_scores_path.read_text())
            except Exception:
                pass
        return {"max_combo": 0, "best_accuracy": 0.0, "total_sessions": 0}

    def _save_high_scores(self):
        """Save high scores to disk, creating the parent directory if needed."""
        self.high_scores_path.parent.mkdir(parents=True, exist_ok=True)
        self.high_scores_path.write_text(json.dumps(self.high_scores, indent=2))

    def _check_high_score(self, stats: Dict[str, Any]):
        """Check and update high scores against session `stats`."""
        if stats['max_combo'] > self.high_scores['max_combo']:
            self.high_scores['max_combo'] = stats['max_combo']

        if stats['accuracy'] > self.high_scores['best_accuracy']:
            self.high_scores['best_accuracy'] = stats['accuracy']

        self.high_scores['total_sessions'] += 1

        # FIX: total_sessions changes on every call, so always persist.
        # The original only saved when a record was beaten, silently
        # losing the session count for non-record sessions.
        self._save_high_scores()

    # ========================================================================
    # DECORATOR FOR EASY WRAPPING
    # ========================================================================

    def intercept(self, granularity: str = "step"):
        """
        Decorator to intercept a function's inference.

        Args:
            granularity: "step" (each call) or "token" (if function yields)

        NOTE: the wrapper records the human's choice in session history but
        returns the wrapped function's original result unchanged (an override
        cannot be substituted for e.g. an embedding array).
        """
        import functools  # local import: module does not import functools at top level

        def decorator(func):
            @functools.wraps(func)  # FIX: preserve wrapped function's metadata
            def wrapper(*args, **kwargs):
                # If no session, passthrough
                if not self.session:
                    return func(*args, **kwargs)

                # Capture the input (truncated reprs only, for display)
                input_context = {
                    "args": str(args)[:200],
                    "kwargs": {k: str(v)[:100] for k, v in kwargs.items()},
                }

                # Get result
                result = func(*args, **kwargs)

                # Create candidates from result
                if isinstance(result, np.ndarray):
                    # For embeddings, show top dimensions by magnitude
                    top_dims = np.argsort(np.abs(result.flatten()))[-5:][::-1]
                    candidates = [
                        {"value": f"dim_{d}", "probability": float(np.abs(result.flatten()[d]))}
                        for d in top_dims
                    ]
                else:
                    candidates = [{"value": result, "probability": 1.0}]

                # Capture (may block); choice is recorded in history
                choice, feedback = self.capture(input_context, candidates)

                return result

            return wrapper
        return decorator
+""" + +import hashlib +import json +import time +from pathlib import Path +from dataclasses import dataclass, field, asdict +from typing import Optional, List, Dict, Any +from enum import Enum + + +class ModelFormat(Enum): + """Model weight formats.""" + SAFETENSORS = "safetensors" + PYTORCH = "pytorch" + GGUF = "gguf" + GGML = "ggml" + ONNX = "onnx" + TENSORRT = "tensorrt" + OPENVINO = "openvino" + COREML = "coreml" + API = "api" # No weights, just endpoint + UNKNOWN = "unknown" + + +class QuantizationType(Enum): + """Quantization methods.""" + NONE = "none" # FP32/FP16/BF16 + GGUF_Q4_0 = "Q4_0" + GGUF_Q4_K_M = "Q4_K_M" + GGUF_Q4_K_S = "Q4_K_S" + GGUF_Q5_0 = "Q5_0" + GGUF_Q5_K_M = "Q5_K_M" + GGUF_Q5_K_S = "Q5_K_S" + GGUF_Q6_K = "Q6_K" + GGUF_Q8_0 = "Q8_0" + GPTQ_4BIT = "GPTQ-4bit" + GPTQ_8BIT = "GPTQ-8bit" + AWQ_4BIT = "AWQ-4bit" + BITSANDBYTES_4BIT = "bnb-4bit" + BITSANDBYTES_8BIT = "bnb-8bit" + INT8 = "INT8" + INT4 = "INT4" + CUSTOM = "custom" + + +class FineTuneType(Enum): + """Fine-tuning methods.""" + NONE = "none" + LORA = "lora" + QLORA = "qlora" + FULL = "full" + RLHF = "rlhf" + DPO = "dpo" + ORPO = "orpo" + CUSTOM = "custom" + + +@dataclass +class ModelVariant: + """Describes how a model differs from its base.""" + quantization: str = "none" + format: str = "unknown" + bits: Optional[int] = None + provider: Optional[str] = None # Who made this variant (e.g., "TheBloke") + + def to_dict(self) -> dict: + return asdict(self) + + +@dataclass +class FineTuneInfo: + """Describes fine-tuning applied to a model.""" + type: str = "none" + adapter_id: Optional[str] = None # HuggingFace adapter ID + adapter_hash: Optional[str] = None # Hash of adapter weights + base_model_root: Optional[str] = None # Merkle root of base model identity + dataset_id: Optional[str] = None # Training dataset + + def to_dict(self) -> dict: + return asdict(self) + + +@dataclass +class BehavioralFingerprint: + """ + Fingerprint for API models where weights are unavailable. 
@dataclass
class ModelVariant:
    """Describes how a model differs from its base."""
    quantization: str = "none"
    format: str = "unknown"
    bits: Optional[int] = None
    provider: Optional[str] = None  # Who made this variant (e.g., "TheBloke")

    def to_dict(self) -> dict:
        return asdict(self)


@dataclass
class FineTuneInfo:
    """Describes fine-tuning applied to a model."""
    type: str = "none"
    adapter_id: Optional[str] = None       # HuggingFace adapter ID
    adapter_hash: Optional[str] = None     # Hash of adapter weights
    base_model_root: Optional[str] = None  # Merkle root of base model identity
    dataset_id: Optional[str] = None       # Training dataset

    def to_dict(self) -> dict:
        return asdict(self)


@dataclass
class BehavioralFingerprint:
    """
    Fingerprint for API models where weights are unavailable.

    Generated by running standard probes and hashing responses.
    """
    probe_responses: List[Dict[str, Any]] = field(default_factory=list)
    probe_hash: Optional[str] = None
    fingerprint_version: int = 1
    generated_at: Optional[float] = None

    def to_dict(self) -> dict:
        return asdict(self)


@dataclass
class ModelIdentity:
    """
    Canonical identity for any AI model variant.

    This is the node that goes in the lattice.
    All observations of this model link to this identity.
    """
    # === Core Identity ===
    base_model: str  # HuggingFace ID or canonical name
    model_id: str    # Full unique identifier (computed in __post_init__ if empty)

    # === Variant Info ===
    variant: ModelVariant = field(default_factory=ModelVariant)
    fine_tune: FineTuneInfo = field(default_factory=FineTuneInfo)

    # === Cryptographic Identity ===
    weight_hash: Optional[str] = None     # SHA256 of weights (if available)
    config_hash: Optional[str] = None     # SHA256 of model config
    tokenizer_hash: Optional[str] = None  # SHA256 of tokenizer

    # === Behavioral Fingerprint (for APIs) ===
    behavioral_fingerprint: Optional[BehavioralFingerprint] = None

    # === Source Info ===
    source_url: Optional[str] = None
    source_revision: Optional[str] = None  # Git commit/tag
    downloaded_at: Optional[float] = None

    # === Lattice Info ===
    parent_root: Optional[str] = None  # Genesis or base model's merkle root
    merkle_root: Optional[str] = None  # This identity's merkle root
    created_at: float = field(default_factory=time.time)

    # === Metadata ===
    parameters: Optional[int] = None  # Parameter count
    context_length: Optional[int] = None
    architecture: Optional[str] = None  # "llama", "mistral", "gpt", etc.
    license: Optional[str] = None

    def __post_init__(self):
        """Compute derived fields."""
        if not self.model_id:
            self.model_id = self.compute_model_id()

    def compute_model_id(self) -> str:
        """
        Compute canonical model ID from components.

        Format: base_model::variant_spec::fine_tune_spec
        """
        parts = [self.base_model]

        # Add variant spec (only non-default components)
        if self.variant.quantization != "none":
            parts.append(f"q:{self.variant.quantization}")
        if self.variant.format != "unknown":
            parts.append(f"fmt:{self.variant.format}")
        if self.variant.provider:
            parts.append(f"by:{self.variant.provider}")

        # Add fine-tune spec
        if self.fine_tune.type != "none":
            parts.append(f"ft:{self.fine_tune.type}")
            if self.fine_tune.adapter_id:
                parts.append(f"adapter:{self.fine_tune.adapter_id}")

        return "::".join(parts)

    def compute_merkle_root(self) -> str:
        """Compute and store the merkle root of this identity."""
        # Create canonical representation
        canonical = {
            "base_model": self.base_model,
            "model_id": self.model_id,
            "variant": self.variant.to_dict(),
            "fine_tune": self.fine_tune.to_dict(),
            "weight_hash": self.weight_hash,
            "config_hash": self.config_hash,
            "tokenizer_hash": self.tokenizer_hash,
            "parent_root": self.parent_root,
            "created_at": self.created_at,
        }

        # Add behavioral fingerprint if present
        if self.behavioral_fingerprint:
            canonical["behavioral_fingerprint"] = self.behavioral_fingerprint.probe_hash

        # Hash it (sorted keys keep the digest deterministic)
        canonical_json = json.dumps(canonical, sort_keys=True)
        self.merkle_root = hashlib.sha256(canonical_json.encode()).hexdigest()[:16]
        return self.merkle_root

    def finalize(self, parent_root: str = None):
        """Finalize identity (optionally setting parent) and compute merkle root."""
        if parent_root:
            self.parent_root = parent_root
        self.merkle_root = self.compute_merkle_root()
        return self

    def to_dict(self) -> dict:
        """Convert to dictionary for serialization."""
        return {
            "base_model": self.base_model,
            "model_id": self.model_id,
            "variant": self.variant.to_dict(),
            "fine_tune": self.fine_tune.to_dict(),
            "weight_hash": self.weight_hash,
            "config_hash": self.config_hash,
            "tokenizer_hash": self.tokenizer_hash,
            "behavioral_fingerprint": self.behavioral_fingerprint.to_dict() if self.behavioral_fingerprint else None,
            "source_url": self.source_url,
            "source_revision": self.source_revision,
            "downloaded_at": self.downloaded_at,
            "parent_root": self.parent_root,
            "merkle_root": self.merkle_root,
            "created_at": self.created_at,
            "parameters": self.parameters,
            "context_length": self.context_length,
            "architecture": self.architecture,
            "license": self.license,
        }

    def to_chain_format(self) -> dict:
        """Convert to provenance chain format for lattice storage."""
        # FIX: the original expression
        #   self.weight_hash or self.behavioral_fingerprint.probe_hash
        #       if self.behavioral_fingerprint else "unknown"
        # parsed as `(weight_hash or probe_hash) if fingerprint else "unknown"`,
        # so a model WITH a weight hash but WITHOUT a fingerprint reported
        # "unknown". Intended precedence: weight hash, then probe hash, then
        # "unknown".
        model_hash = (
            self.weight_hash
            or (self.behavioral_fingerprint.probe_hash if self.behavioral_fingerprint else None)
            or "unknown"
        )
        return {
            "session_id": f"model_identity_{self.merkle_root}",
            "model_id": self.model_id,
            "model_hash": model_hash,
            "input_hash": self.base_model,
            "output_hash": None,
            "records": {
                "identity": {
                    "layer_name": "identity",
                    "layer_idx": 0,
                    "state_hash": self.merkle_root,
                    "parent_hashes": [self.parent_root] if self.parent_root else [],
                    "params_hash": self.config_hash,
                    "shape": [self.parameters] if self.parameters else [0],
                    "dtype": "model_identity",
                    "stats": self.to_dict(),
                    "execution_order": 0,
                    "timestamp": self.created_at,
                }
            },
            "external_roots": [self.parent_root] if self.parent_root else [],
            "merkle_root": self.merkle_root,
            "created_at": self.created_at,
            "finalized": True,
        }
Answer with just the number.", + "params": {"temperature": 0, "max_tokens": 10}, + }, + { + "id": "capital_france", + "prompt": "Complete this sentence with one word: The capital of France is", + "params": {"temperature": 0, "max_tokens": 10}, + }, + { + "id": "translate_hello", + "prompt": "Translate to French: Hello", + "params": {"temperature": 0, "max_tokens": 20}, + }, + { + "id": "color_sky", + "prompt": "What color is the sky on a clear day? One word answer:", + "params": {"temperature": 0, "max_tokens": 10}, + }, + + # Capability probes + { + "id": "code_simple", + "prompt": "Write a Python function that adds two numbers. Just the function, no explanation.", + "params": {"temperature": 0, "max_tokens": 100}, + }, + { + "id": "reasoning", + "prompt": "If all cats are mammals and all mammals are animals, are all cats animals? Answer yes or no.", + "params": {"temperature": 0, "max_tokens": 10}, + }, + + # System prompt probe + { + "id": "system_role", + "prompt": "You are a helpful pirate. Say hello.", + "params": {"temperature": 0, "max_tokens": 50}, + "system": "You are a helpful pirate who speaks like a pirate.", + }, + + # Edge cases + { + "id": "empty", + "prompt": "", + "params": {"temperature": 0, "max_tokens": 50}, + }, + { + "id": "repetition", + "prompt": "Repeat after me exactly: The quick brown fox", + "params": {"temperature": 0, "max_tokens": 20}, + }, +] + + +def generate_behavioral_fingerprint( + call_fn, # Function that takes (prompt, params) and returns response + probes: List[dict] = None, + version: int = 1, +) -> BehavioralFingerprint: + """ + Generate behavioral fingerprint by running standard probes. + + Args: + call_fn: Function to call the model. Signature: (prompt, params) -> str + probes: List of probe configs. Defaults to STANDARD_PROBES_V1. + version: Fingerprint version number. + + Returns: + BehavioralFingerprint with hashed responses. 
+ """ + if probes is None: + probes = STANDARD_PROBES_V1 + + responses = [] + for probe in probes: + try: + response = call_fn(probe["prompt"], probe.get("params", {})) + response_hash = hashlib.sha256(str(response).encode()).hexdigest()[:16] + except Exception as e: + response_hash = f"error:{type(e).__name__}" + + responses.append({ + "probe_id": probe["id"], + "prompt_hash": hashlib.sha256(probe["prompt"].encode()).hexdigest()[:16], + "response_hash": response_hash, + }) + + # Compute overall fingerprint hash + fingerprint_data = json.dumps(responses, sort_keys=True) + probe_hash = hashlib.sha256(fingerprint_data.encode()).hexdigest()[:16] + + return BehavioralFingerprint( + probe_responses=responses, + probe_hash=probe_hash, + fingerprint_version=version, + generated_at=time.time(), + ) + + +# ============================================================================= +# MODEL IDENTITY FACTORY +# ============================================================================= + +def detect_quantization(model_path: str) -> str: + """Detect quantization from model path or name.""" + path_lower = model_path.lower() + + # GGUF quantizations + for q in ["q4_k_m", "q4_k_s", "q4_0", "q5_k_m", "q5_k_s", "q5_0", "q6_k", "q8_0"]: + if q in path_lower: + return q.upper() + + # GPTQ + if "gptq" in path_lower: + if "4bit" in path_lower or "-4b" in path_lower: + return "GPTQ-4bit" + elif "8bit" in path_lower or "-8b" in path_lower: + return "GPTQ-8bit" + return "GPTQ" + + # AWQ + if "awq" in path_lower: + return "AWQ-4bit" + + # BitsAndBytes + if "bnb" in path_lower or "bitsandbytes" in path_lower: + if "4bit" in path_lower: + return "bnb-4bit" + return "bnb-8bit" + + return "none" + + +def detect_format(model_path: str) -> str: + """Detect model format from path.""" + path_lower = model_path.lower() + + if ".gguf" in path_lower: + return "gguf" + elif ".ggml" in path_lower: + return "ggml" + elif ".safetensors" in path_lower or "safetensors" in path_lower: + return 
"safetensors" + elif ".onnx" in path_lower: + return "onnx" + elif ".bin" in path_lower or "pytorch" in path_lower: + return "pytorch" + elif "api" in path_lower or "http" in path_lower: + return "api" + + return "unknown" + + +def detect_provider(model_path: str) -> Optional[str]: + """Detect who made this variant.""" + path_lower = model_path.lower() + + providers = [ + "thebloke", + "unsloth", + "mlx-community", + "bartowski", + "mradermacher", + "turboderp", + ] + + for provider in providers: + if provider in path_lower: + return provider + + return None + + +def create_model_identity( + model_id: str, + weights_path: Optional[Path] = None, + config: Optional[dict] = None, + parent_root: Optional[str] = None, + behavioral_fingerprint: Optional[BehavioralFingerprint] = None, + **kwargs, +) -> ModelIdentity: + """ + Factory function to create ModelIdentity from various inputs. + + Args: + model_id: HuggingFace model ID or local path + weights_path: Path to weights file (for hashing) + config: Model config dict + parent_root: Merkle root of parent (genesis or base model) + behavioral_fingerprint: Pre-computed fingerprint for APIs + **kwargs: Additional fields (parameters, context_length, etc.) 
+ + Returns: + Finalized ModelIdentity ready for lattice + """ + # Parse base model from full ID + # e.g., "TheBloke/Llama-3-8B-GGUF" -> base is "meta-llama/Llama-3-8B" + base_model = kwargs.pop("base_model", None) + if not base_model: + # Try to extract base from model_id + parts = model_id.split("/") + if len(parts) >= 2: + name = parts[-1] + # Remove common suffixes + for suffix in ["-GGUF", "-GPTQ", "-AWQ", "-fp16", "-bf16", "-GGML"]: + name = name.replace(suffix, "") + base_model = name + else: + base_model = model_id + + # Detect variant info + quantization = detect_quantization(model_id) + format_type = detect_format(model_id) + provider = detect_provider(model_id) + + # Extract bits from quantization + bits = None + if "4" in quantization: + bits = 4 + elif "5" in quantization: + bits = 5 + elif "6" in quantization: + bits = 6 + elif "8" in quantization: + bits = 8 + + variant = ModelVariant( + quantization=quantization, + format=format_type, + bits=bits, + provider=provider, + ) + + # Hash weights if available + weight_hash = None + if weights_path and Path(weights_path).exists(): + # For large files, hash first and last 1MB + size + path = Path(weights_path) + size = path.stat().st_size + hasher = hashlib.sha256() + hasher.update(str(size).encode()) + + with open(path, "rb") as f: + # First 1MB + hasher.update(f.read(1024 * 1024)) + # Last 1MB + if size > 2 * 1024 * 1024: + f.seek(-1024 * 1024, 2) + hasher.update(f.read()) + + weight_hash = hasher.hexdigest()[:16] + + # Hash config if available + config_hash = None + if config: + config_json = json.dumps(config, sort_keys=True) + config_hash = hashlib.sha256(config_json.encode()).hexdigest()[:16] + + # Create identity + identity = ModelIdentity( + base_model=base_model, + model_id="", # Will be computed + variant=variant, + fine_tune=FineTuneInfo(), + weight_hash=weight_hash, + config_hash=config_hash, + behavioral_fingerprint=behavioral_fingerprint, + parent_root=parent_root, + **kwargs, + ) + + # 
Compute model_id and merkle_root + identity.model_id = identity.compute_model_id() + identity.finalize(parent_root) + + return identity + + +# ============================================================================= +# MODEL REGISTRY (Lattice Integration) +# ============================================================================= + +class ModelRegistry: + """ + Registry of model identities in the lattice. + + Provides: + - Get or create model identity + - Link observations to model identities + - Query models by various criteria + """ + + def __init__(self, lattice_dir: Path = None, genesis_root: str = None): + self.lattice_dir = lattice_dir or Path(__file__).parent.parent / "lattice" + self.models_dir = self.lattice_dir / "models" + self.models_dir.mkdir(parents=True, exist_ok=True) + + # Genesis root (models link to this if no base model) + self.genesis_root = genesis_root or "89f940c1a4b7aa65" + + # Cache of loaded identities + self._cache: Dict[str, ModelIdentity] = {} + self._load_all() + + def _load_all(self): + """Load all model identities from disk.""" + for json_file in self.models_dir.glob("*.json"): + try: + data = json.loads(json_file.read_text()) + identity = self._dict_to_identity(data) + self._cache[identity.merkle_root] = identity + except Exception as e: + print(f"Error loading {json_file}: {e}") + + def _dict_to_identity(self, data: dict) -> ModelIdentity: + """Convert dict back to ModelIdentity.""" + variant_data = data.get("variant", {}) + fine_tune_data = data.get("fine_tune", {}) + fingerprint_data = data.get("behavioral_fingerprint") + + return ModelIdentity( + base_model=data["base_model"], + model_id=data["model_id"], + variant=ModelVariant(**variant_data), + fine_tune=FineTuneInfo(**fine_tune_data), + weight_hash=data.get("weight_hash"), + config_hash=data.get("config_hash"), + tokenizer_hash=data.get("tokenizer_hash"), + behavioral_fingerprint=BehavioralFingerprint(**fingerprint_data) if fingerprint_data else None, + 
source_url=data.get("source_url"), + source_revision=data.get("source_revision"), + downloaded_at=data.get("downloaded_at"), + parent_root=data.get("parent_root"), + merkle_root=data.get("merkle_root"), + created_at=data.get("created_at", time.time()), + parameters=data.get("parameters"), + context_length=data.get("context_length"), + architecture=data.get("architecture"), + license=data.get("license"), + ) + + def _save_identity(self, identity: ModelIdentity): + """Save identity to disk.""" + filename = f"{identity.merkle_root}.json" + filepath = self.models_dir / filename + filepath.write_text(json.dumps(identity.to_dict(), indent=2)) + + def get_or_create( + self, + model_id: str, + **kwargs, + ) -> ModelIdentity: + """ + Get existing model identity or create new one. + + If model already exists in registry, returns existing. + Otherwise creates new identity linked to genesis or base model. + """ + # Check if we have this model already + for identity in self._cache.values(): + if identity.model_id == model_id or identity.base_model == model_id: + return identity + + # Determine parent + # If this is a variant, try to find base model + parent_root = kwargs.pop("parent_root", None) + if not parent_root: + base = kwargs.get("base_model") + if base: + for identity in self._cache.values(): + if identity.base_model == base and identity.variant.quantization == "none": + parent_root = identity.merkle_root + break + + # Default to genesis + if not parent_root: + parent_root = self.genesis_root + + # Create new identity + identity = create_model_identity( + model_id=model_id, + parent_root=parent_root, + **kwargs, + ) + + # Cache and save + self._cache[identity.merkle_root] = identity + self._save_identity(identity) + + return identity + + def get_by_root(self, merkle_root: str) -> Optional[ModelIdentity]: + """Get model identity by merkle root.""" + return self._cache.get(merkle_root) + + def list_all(self) -> List[ModelIdentity]: + """List all registered models.""" + 
return list(self._cache.values()) + + def list_by_base(self, base_model: str) -> List[ModelIdentity]: + """List all variants of a base model.""" + return [i for i in self._cache.values() if i.base_model == base_model] + + def search(self, query: str) -> List[ModelIdentity]: + """Search models by name.""" + query_lower = query.lower() + return [ + i for i in self._cache.values() + if query_lower in i.model_id.lower() or query_lower in i.base_model.lower() + ] + + +# ============================================================================= +# CLI +# ============================================================================= + +if __name__ == "__main__": + import sys + + # Test: Create some model identities + print("=== CASCADE Model Identity Layer ===\n") + + # Initialize registry + registry = ModelRegistry() + + # Create some test identities + test_models = [ + "meta-llama/Llama-3-8B", + "TheBloke/Llama-3-8B-GGUF", + "unsloth/Llama-3-8B-bnb-4bit", + "anthropic/claude-3-opus", + "openai/gpt-4", + ] + + for model in test_models: + identity = registry.get_or_create(model) + print(f"Model: {identity.model_id}") + print(f" Base: {identity.base_model}") + print(f" Quant: {identity.variant.quantization}") + print(f" Format: {identity.variant.format}") + print(f" Merkle: {identity.merkle_root}") + print(f" Parent: {identity.parent_root}") + print() + + print(f"Total models in registry: {len(registry.list_all())}") diff --git a/cascade/ipld.py b/cascade/ipld.py new file mode 100644 index 0000000000000000000000000000000000000000..3c745ac521bd7e083a7d6d49d9dcf7ac6895b159 --- /dev/null +++ b/cascade/ipld.py @@ -0,0 +1,379 @@ +""" +CASCADE IPLD - InterPlanetary Linked Data Integration + +Native IPLD encoding for provenance chains. Merkle roots become CIDs. +The lattice goes interplanetary. + +CIDs (Content IDentifiers) are self-describing, content-addressed identifiers. +When we encode a chain as IPLD, its CID is derived from its content. 
+Anyone with the CID can fetch and verify. + +Architecture: + ProvenanceChain ──encode──► DAG-CBOR ──hash──► CID + │ + bafyreif...xyz (interplanetary address) +""" + +import json +import hashlib +from typing import Dict, Any, Optional, List +from dataclasses import dataclass +from pathlib import Path + +# IPLD encoding +import dag_cbor +from multiformats import CID, multihash + +# CASCADE core +from cascade.core.provenance import ProvenanceChain, ProvenanceRecord + + +# ============================================================================= +# IPLD ENCODING +# ============================================================================= + +def chain_to_ipld(chain: ProvenanceChain) -> Dict[str, Any]: + """ + Convert a ProvenanceChain to IPLD-compatible format. + + IPLD format uses: + - Lowercase keys + - CID links for references + - DAG-CBOR encoding + """ + # Convert records to IPLD format + records = {} + for name, record in chain.records.items(): + records[name] = { + "layer_name": record.layer_name, + "layer_idx": record.layer_idx, + "state_hash": record.state_hash, + "parent_hashes": record.parent_hashes, + "params_hash": record.params_hash, + "shape": record.shape, + "dtype": record.dtype, + "stats": record.stats, + "execution_order": record.execution_order, + "timestamp": record.timestamp, + } + + # Convert external_roots to CID links if they look like CIDs + external_links = [] + for root in chain.external_roots: + if root.startswith("bafy") or root.startswith("Qm"): + # Already a CID - create a link + external_links.append({"/": root}) + else: + # Legacy merkle root - keep as string + external_links.append({"legacy_root": root}) + + return { + "session_id": chain.session_id, + "model_id": chain.model_id, + "model_hash": chain.model_hash, + "input_hash": chain.input_hash, + "output_hash": chain.output_hash, + "records": records, + "external_roots": chain.external_roots, # Keep for verification + "external_links": external_links, # IPLD links + 
"merkle_root": chain.merkle_root, + "created_at": chain.created_at, + "finalized": chain.finalized, + "ipld_version": 1, + } + + +def encode_to_dag_cbor(data: Dict[str, Any]) -> bytes: + """Encode data as DAG-CBOR (canonical CBOR for IPLD).""" + return dag_cbor.encode(data) + + +def decode_from_dag_cbor(raw: bytes) -> Dict[str, Any]: + """Decode DAG-CBOR data.""" + return dag_cbor.decode(raw) + + +def compute_cid(data: bytes, codec: str = "dag-cbor") -> str: + """ + Compute CID (Content IDentifier) from data. + + CID = multicodec(codec) + multihash(sha256(data)) + + Returns CIDv1 in base32 (bafyrei...) + """ + # SHA-256 hash of the data + digest = hashlib.sha256(data).digest() + + # Create multihash (0x12 = sha2-256, 0x20 = 32 bytes) + mh = multihash.wrap(digest, "sha2-256") + + # Create CID v1 with dag-cbor codec (0x71) + cid = CID("base32", 1, "dag-cbor", mh) + + return str(cid) + + +def chain_to_cid(chain: ProvenanceChain) -> tuple[str, bytes]: + """ + Convert chain to CID. + + Returns: + (cid_string, encoded_bytes) + """ + ipld_data = chain_to_ipld(chain) + encoded = encode_to_dag_cbor(ipld_data) + cid = compute_cid(encoded) + return cid, encoded + + +# ============================================================================= +# IPLD CHAIN - Native CID-based chain +# ============================================================================= + +@dataclass +class IPLDChain: + """ + A provenance chain with native CID support. + + Instead of custom merkle roots, uses CIDs. + Links to other chains via CID references. 
+ """ + chain: ProvenanceChain + cid: Optional[str] = None + encoded: Optional[bytes] = None + + @classmethod + def from_chain(cls, chain: ProvenanceChain) -> 'IPLDChain': + """Create IPLD chain from regular chain.""" + cid, encoded = chain_to_cid(chain) + return cls(chain=chain, cid=cid, encoded=encoded) + + @classmethod + def from_bytes(cls, data: bytes) -> 'IPLDChain': + """Deserialize from DAG-CBOR bytes.""" + ipld_data = decode_from_dag_cbor(data) + chain = ipld_to_chain(ipld_data) + cid = compute_cid(data) + return cls(chain=chain, cid=cid, encoded=data) + + def link_to(self, other: 'IPLDChain') -> None: + """Link this chain to another via CID.""" + if other.cid is None: + raise ValueError("Cannot link to chain without CID") + self.chain.link_external(other.cid, source_id=other.chain.model_id) + # Recompute our CID since we changed + self.cid, self.encoded = chain_to_cid(self.chain) + + def save(self, path: Path) -> None: + """Save as DAG-CBOR file.""" + if self.encoded is None: + self.cid, self.encoded = chain_to_cid(self.chain) + with open(path, 'wb') as f: + f.write(self.encoded) + + @classmethod + def load(cls, path: Path) -> 'IPLDChain': + """Load from DAG-CBOR file.""" + with open(path, 'rb') as f: + data = f.read() + return cls.from_bytes(data) + + def to_json(self) -> str: + """Export as JSON (for human inspection).""" + ipld_data = chain_to_ipld(self.chain) + ipld_data["_cid"] = self.cid + return json.dumps(ipld_data, indent=2, default=str) + + +def ipld_to_chain(ipld_data: Dict[str, Any]) -> ProvenanceChain: + """Convert IPLD data back to ProvenanceChain.""" + # Reconstruct records + records = {} + for name, rec_data in ipld_data.get("records", {}).items(): + records[name] = ProvenanceRecord( + layer_name=rec_data["layer_name"], + layer_idx=rec_data["layer_idx"], + state_hash=rec_data["state_hash"], + parent_hashes=rec_data["parent_hashes"], + params_hash=rec_data.get("params_hash"), + shape=rec_data.get("shape", []), + dtype=rec_data.get("dtype", 
"float32"), + stats=rec_data.get("stats", {}), + execution_order=rec_data.get("execution_order", 0), + timestamp=rec_data.get("timestamp", 0), + ) + + chain = ProvenanceChain( + session_id=ipld_data["session_id"], + model_id=ipld_data["model_id"], + model_hash=ipld_data["model_hash"], + input_hash=ipld_data["input_hash"], + output_hash=ipld_data.get("output_hash"), + external_roots=ipld_data.get("external_roots", []), + merkle_root=ipld_data.get("merkle_root"), + created_at=ipld_data.get("created_at", 0), + finalized=ipld_data.get("finalized", False), + ) + chain.records = records + + return chain + + +# ============================================================================= +# IPFS PUBLISHING (requires running IPFS daemon) +# ============================================================================= + +def publish_to_ipfs(chain: IPLDChain, ipfs_api: str = "/ip4/127.0.0.1/tcp/5001") -> str: + """ + Publish chain to IPFS network. + + Requires IPFS daemon running locally. + Returns the CID (which should match our computed CID). + + Args: + chain: IPLDChain to publish + ipfs_api: IPFS API multiaddr + + Returns: + CID from IPFS (for verification) + """ + try: + import ipfshttpclient + client = ipfshttpclient.connect(ipfs_api) + + # Add the raw DAG-CBOR data + result = client.dag.put( + chain.encoded, + store_codec="dag-cbor", + input_codec="dag-cbor" + ) + + ipfs_cid = result["Cid"]["/"] + + # Verify CIDs match + if ipfs_cid != chain.cid: + print(f"[WARN] CID mismatch: computed={chain.cid}, ipfs={ipfs_cid}") + + return ipfs_cid + + except Exception as e: + print(f"[ERROR] IPFS publish failed: {e}") + print(" Make sure IPFS daemon is running: ipfs daemon") + raise + + +def fetch_from_ipfs(cid: str, ipfs_api: str = "/ip4/127.0.0.1/tcp/5001") -> IPLDChain: + """ + Fetch chain from IPFS network by CID. 
+ + Args: + cid: Content identifier + ipfs_api: IPFS API multiaddr + + Returns: + IPLDChain + """ + try: + import ipfshttpclient + client = ipfshttpclient.connect(ipfs_api) + + # Get the DAG node + data = client.dag.get(cid) + + # Convert to chain + chain = ipld_to_chain(data) + encoded = encode_to_dag_cbor(data) + + return IPLDChain(chain=chain, cid=cid, encoded=encoded) + + except Exception as e: + print(f"[ERROR] IPFS fetch failed: {e}") + raise + + +# ============================================================================= +# GENESIS IN IPLD +# ============================================================================= + +def get_genesis_cid() -> tuple[str, IPLDChain]: + """ + Get genesis as IPLD chain with CID. + + The genesis CID is deterministic - anyone computing it gets the same result. + This is the interplanetary Schelling point. + """ + from cascade.genesis import create_genesis + + genesis = create_genesis() + ipld_genesis = IPLDChain.from_chain(genesis) + + return ipld_genesis.cid, ipld_genesis + + +# ============================================================================= +# CLI +# ============================================================================= + +if __name__ == "__main__": + import sys + + print("=" * 60) + print("CASCADE IPLD - InterPlanetary Linked Data") + print("=" * 60) + + # Get genesis CID + genesis_cid, genesis_ipld = get_genesis_cid() + print(f"\nGenesis CID: {genesis_cid}") + print(f"Genesis merkle_root: {genesis_ipld.chain.merkle_root}") + + # Load cascade_alpha and convert to IPLD + alpha_path = Path("lattice/cascade_alpha.json") + if alpha_path.exists(): + with open(alpha_path) as f: + alpha_data = json.load(f) + alpha_chain = ProvenanceChain.from_dict(alpha_data) + alpha_ipld = IPLDChain.from_chain(alpha_chain) + + print(f"\ncascade_alpha CID: {alpha_ipld.cid}") + print(f"cascade_alpha merkle_root: {alpha_chain.merkle_root}") + + # Save as DAG-CBOR + out_dir = Path("lattice/ipld") + 
out_dir.mkdir(exist_ok=True) + + genesis_ipld.save(out_dir / "genesis.cbor") + alpha_ipld.save(out_dir / "cascade_alpha.cbor") + + # Also save JSON for inspection + with open(out_dir / "genesis.ipld.json", 'w') as f: + f.write(genesis_ipld.to_json()) + with open(out_dir / "cascade_alpha.ipld.json", 'w') as f: + f.write(alpha_ipld.to_json()) + + print(f"\nSaved to {out_dir}/") + print(f" - genesis.cbor") + print(f" - cascade_alpha.cbor") + print(f" - genesis.ipld.json") + print(f" - cascade_alpha.ipld.json") + + print("\n" + "=" * 60) + print("INTERPLANETARY ADDRESSES") + print("=" * 60) + print(f""" +Genesis: {genesis_cid} +cascade_alpha: {alpha_ipld.cid if alpha_path.exists() else 'N/A'} + +These CIDs are content-addressed. Anyone with the CID can: +1. Fetch the data from IPFS (if pinned) +2. Verify the content matches the CID +3. Trust the chain without trusting the source + +To publish to IPFS: + ipfs daemon # Start IPFS + python -c " + from cascade.ipld import publish_to_ipfs, get_genesis_cid + _, genesis = get_genesis_cid() + cid = publish_to_ipfs(genesis) + print(f'Published: {{cid}}') + " + """) diff --git a/cascade/listen.py b/cascade/listen.py new file mode 100644 index 0000000000000000000000000000000000000000..db8393839d448c9ede6d1f2302452dad8ca63952 --- /dev/null +++ b/cascade/listen.py @@ -0,0 +1,154 @@ +""" +Cascade Passive Monitor. + +Listens to stdin or follows a log file and observes events. + +Usage: + python -m cascade.listen # Listen to stdin + python -m cascade.listen --follow app.log # Follow a log file + +This module: +1. Reads input from stdin or a log file +2. Pipes lines -> Cascade Adapter +3. Writes events to tape file (JSONL) and human log (Markdown) +4. Emits events to event_queue for external consumers + +For visualization, point a consumer at the event_queue or load the tape file +into your preferred visualization tool. 
+""" + +import sys +import argparse +import time +import json +from pathlib import Path +from queue import Queue + +# Ensure package root is in path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from cascade import Monitor + +# Shared event queue for external consumers (e.g., custom UIs) +event_queue: Queue = Queue() + + +def main(): + parser = argparse.ArgumentParser(description="Cascade Passive Monitor") + parser.add_argument("--log-dir", default="./logs", help="Directory for logs") + parser.add_argument("--follow", help="Log file to follow (tail -f style)") + parser.add_argument("--quiet", "-q", action="store_true", help="Suppress console output") + args = parser.parse_args() + + # 0. Setup Logs & Baggies + log_dir = Path(args.log_dir) + log_dir.mkdir(parents=True, exist_ok=True) + + baggies_dir = log_dir / "baggies" + baggies_dir.mkdir(exist_ok=True) + + # Excrement Management (Archive old artifacts) + follow_abs = Path(args.follow).absolute() if args.follow else None + for f in log_dir.glob("*.*"): + if f.is_file() and f.suffix in [".md", ".jsonl", ".log"] and "baggies" not in str(f): + if follow_abs and f.absolute() == follow_abs: + continue + try: + dest = baggies_dir / f.name + if dest.exists(): + dest = baggies_dir / f"{f.stem}_{int(time.time())}{f.suffix}" + f.replace(dest) + except Exception: + pass + print(f"[CASCADE] Logs archived to {baggies_dir}") + + session_id = int(time.time()) + tape_path = log_dir / f"cascade_tape_{session_id}.jsonl" + human_path = log_dir / f"cascade_log_{session_id}.md" + + tape_file = open(tape_path, "w", encoding="utf-8") + human_file = open(human_path, "w", encoding="utf-8") + + # Init Log + human_file.write(f"# CASCADE MISSION LOG // SESSION {session_id}\n") + human_file.write(f"**Mode:** PASSIVE {'FOLLOWER' if args.follow else 'LISTENER'}\n") + human_file.write(f"**Target:** `{args.follow or 'STDIN'}`\n---\n\n") + human_file.flush() + + print("="*60) + print("CASCADE // LISTENER") + print(f"Monitoring: 
{args.follow if args.follow else 'Standard Input'}") + print(f"Tape: {tape_path.absolute()}") + print(f"Baggies: {baggies_dir.absolute()}") + print("="*60) + + monitor = Monitor("symbiont_passive") + + def process_line(line): + line = line.strip() + if not line: + return + event = monitor.observe(line) + payload = { + "event": { + "event_id": event.event_id, + "timestamp": event.timestamp, + "component": event.component, + "event_type": event.event_type, + "data": event.data, + "raw": line, # Include original line for drill-down + }, + "metrics": monitor.metrics.summary(), + "triage": monitor.metrics.triage(), + } + event_queue.put(payload) + tape_file.write(json.dumps(payload) + "\n") + tape_file.flush() + + # Narrative + t_str = time.strftime('%H:%M:%S', time.localtime(event.timestamp)) + icon = {"error": "🔴", "warning": "⚠️", "state_change": "🔄"}.get(event.event_type, "ℹ️") + if "loss" in str(event.data): + icon = "📉" + human_file.write(f"### {icon} {t_str} // {event.event_type.upper()}\n") + human_file.write(f"Event observed in **{event.component}**.\n") + if event.data: + human_file.write("```yaml\n") + for k, v in event.data.items(): + human_file.write(f"{k}: {v}\n") + human_file.write("```\n") + human_file.write("\n") + human_file.flush() + + # Mirror to console (unless quiet) + if not args.quiet: + sys.stdout.write(f"[SIGHT] {line[:80]}...\n") + sys.stdout.flush() + + try: + if args.follow: + print(f"[CASCADE] Waiting for stream: {args.follow}") + f_path = Path(args.follow) + if not f_path.exists(): + f_path.touch() + with open(f_path, "r", encoding="utf-8", errors="replace") as f: + print(f"[CASCADE] Scanning for events...") + while True: + line = f.readline() + if not line: + time.sleep(0.1) + continue + process_line(line) + else: + print("[CASCADE] Reading from stdin (Ctrl+C to stop)...") + for line in sys.stdin: + process_line(line) + except KeyboardInterrupt: + print("\n[CASCADE] Detaching...") + finally: + tape_file.close() + human_file.close() + 
print(f"[CASCADE] Session complete. Tape: {tape_path}") + +if __name__ == "__main__": + main() diff --git a/cascade/logging/__init__.py b/cascade/logging/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7389e8d95c2119262af98c3b32e38ae1ad365c2f --- /dev/null +++ b/cascade/logging/__init__.py @@ -0,0 +1,86 @@ +""" +CASCADE Logging System +Industry-standard dual-layer logging for mathematical precision and human clarity. + +Two modes: +1. Kleene Mode: Mathematical fixed point logs for debugging and verification +2. Interpretive Mode: Human-readable causation stories for operators + +Use together for complete system observability. +""" + +from .kleene_logger import ( + KleeneLogger, + LogLevel, + get_kleene_logger, + log_fixed_point, + log_iterations +) + +from .interpretive_logger import ( + InterpretiveLogger, + ImpactLevel, + get_interpretive_logger, + translate_kleene_to_interpretive +) + +from .log_manager import ( + LogMode, + LogConfig, + CascadeLogManager, + init_logging, + get_log_manager, + log +) + + +def init_cascade_logging(component: str, system: str): + """Initialize both logging layers for a component""" + kleene = get_kleene_logger(component) + interpretive = get_interpretive_logger(system) + + # Bridge automatic translation + def bridge_log(entry): + translate_kleene_to_interpretive(entry, interpretive) + + kleene._emit_to_container = lambda entry: ( + print(kleene._format_container(entry)), + bridge_log(entry) + ) + + return kleene, interpretive + + +# Convenience for quick setup +def setup_logging(component: str, system: str = "CASCADE"): + """Quick setup for both loggers""" + return init_cascade_logging(component, system) + + +# Export main interfaces +__all__ = [ + # Kleene (mathematical) + 'KleeneLogger', + 'LogLevel', + 'get_kleene_logger', + 'log_fixed_point', + 'log_iterations', + + # Interpretive (human) + 'InterpretiveLogger', + 'ImpactLevel', + 'get_interpretive_logger', + 'translate_kleene_to_interpretive', 
+ + # Log Manager (orchestrator) + 'LogMode', + 'LogConfig', + 'CascadeLogManager', + 'init_logging', + 'get_log_manager', + 'log', + + # Unified + 'init_cascade_logging', + 'setup_logging' +] diff --git a/cascade/logging/color_example.py b/cascade/logging/color_example.py new file mode 100644 index 0000000000000000000000000000000000000000..48d63fdd57c3c048add9037bc2df957c8edfdd98 --- /dev/null +++ b/cascade/logging/color_example.py @@ -0,0 +1,107 @@ +""" +CASCADE Color Logging Example +Shows how to integrate beautiful colored logs throughout your system. +""" + +from .kleene_logger import get_kleene_logger, LogLevel +from .interpretive_logger import get_interpretive_logger, ImpactLevel + +def example_data_processing(): + """Example: Data processing with beautiful logs""" + kleene = get_kleene_logger("DataProcessor") + interpretive = get_interpretive_logger("Data Pipeline") + + # Start processing + kleene.log(LogLevel.INFO, "load_dataset_start", + state_before={"dataset": "smollm3-blueprint.pdf"}) + + interpretive.log(ImpactLevel.LOW, "DataLoader", "Loading dataset", + context="Reading PDF file for analysis", + consequence="Will extract text and metadata", + metrics={"file_size": "1.0MB", "type": "PDF"}) + + # Processing steps + kleene.log(LogLevel.DEBUG, "extract_text", + state_before={"page": 1}, + state_after={"pages_processed": 15}) + + # Fixed point reached + kleene.log(LogLevel.INFO, "processing_complete", + state_after={"records": 500, "clean": True}, + fixed_point=True, + iterations=3) + + interpretive.log(ImpactLevel.MEDIUM, "DataProcessor", "Processing complete", + context="Successfully extracted and cleaned data", + consequence="Ready for forensics analysis", + metrics={"records": 500, "pages": 15, "errors": 0}) + +def example_model_observation(): + """Example: Model observation with beautiful logs""" + kleene = get_kleene_logger("ModelObserver") + interpretive = get_interpretive_logger("Model Observatory") + + # Model loading + kleene.log(LogLevel.INFO, 
"model_load_start", + state_before={"model": "mistralai/Mixtral-8x22B-Instruct-v0.1"}) + + interpretive.log(ImpactLevel.MEDIUM, "ModelLoader", "Loading Mixtral", + context="Loading 8x22B MoE model for inference", + consequence="Will consume significant VRAM", + metrics={"params": "141B", "active": "39B", "device": "cuda"}) + + # Observation + kleene.log(LogLevel.INFO, "observation_start", + state_before={"layers": 0, "hash": "initial"}) + + # Fixed point achieved + kleene.log(LogLevel.INFO, "observation_fixed_point", + state_after={"layers": 64, "merkle": "abc123..."}, + fixed_point=True, + iterations=64) + + interpretive.log(ImpactLevel.LOW, "CASCADE", "Model observed", + context="Cryptographic proof generated for model execution", + consequence="Merkle root provides verifiable audit trail", + metrics={"model": "Mixtral", "layers": 64, "merkle": "abc123..."}) + +def example_error_handling(): + """Example: Error handling with colored logs""" + kleene = get_kleene_logger("ErrorHandler") + interpretive = get_interpretive_logger("System Monitor") + + # Error detected + kleene.log(LogLevel.ERROR, "memory_exhaustion", + state_before={"memory": "15.8/16GB", "operation": "inference"}, + fixed_point=False) + + interpretive.log(ImpactLevel.HIGH, "MemoryManager", "Out of memory", + context="GPU memory exhausted during model inference", + consequence="Inference failed, system degraded", + metrics={"used": "15.8GB", "total": "16GB", "available": "200MB"}, + recommendation="Enable gradient checkpointing or use smaller batch size") + + # Recovery + kleene.log(LogLevel.WARNING, "fallback_activated", + state_after={"mode": "cpu_fallback", "batch_size": 1}) + + interpretive.log(ImpactLevel.MEDIUM, "FallbackHandler", "CPU fallback activated", + context="Switched to CPU inference due to memory constraints", + consequence="Performance degraded but functionality preserved", + metrics={"device": "cpu", "batch_size": 1, "slowdown": "10x"}) + +# Run all examples +if __name__ == 
"__main__": + print("\n🎨 CASCADE Color Logging Examples\n") + print("="*60) + + example_data_processing() + print("\n" + "="*60) + + example_model_observation() + print("\n" + "="*60) + + example_error_handling() + print("\n" + "="*60) + + print("\n✨ Beautiful logs are ready for production!") diff --git a/cascade/logging/integrate.py b/cascade/logging/integrate.py new file mode 100644 index 0000000000000000000000000000000000000000..3094549dcd7d2df4752c38caae85fda7812aa643 --- /dev/null +++ b/cascade/logging/integrate.py @@ -0,0 +1,275 @@ +""" +CASCADE Logging Integration +Plug-and-play logging for existing CASCADE components. + +Retrofits existing systems with world-class logging without major surgery. +""" + +import functools +import time +from typing import Any, Callable, Dict, Optional + +from .log_manager import get_log_manager, LogLevel, ImpactLevel + + +def log_component(component_name: str, system: str = "CASCADE"): + """Decorator to add logging to any class or function""" + def decorator(target): + if isinstance(target, type): + # Decorating a class + return _log_class(target, component_name, system) + else: + # Decorating a function + return _log_function(target, component_name, system) + return decorator + + +def _log_class(cls, component_name: str, system: str): + """Add logging to all methods of a class""" + manager = get_log_manager() + manager.register_component(component_name, system) + + for attr_name in dir(cls): + if not attr_name.startswith('_'): + attr = getattr(cls, attr_name) + if callable(attr): + setattr(cls, attr_name, _log_method(attr, component_name)) + + return cls + + +def _log_function(func, component_name: str, system: str): + """Add logging to a function""" + manager = get_log_manager() + manager.register_component(component_name, system) + + @functools.wraps(func) + def wrapper(*args, **kwargs): + start_time = time.time() + + # Log start + get_log_manager().log_operation( + component_name, f"{func.__name__}_start", + 
level=LogLevel.DEBUG, + impact=ImpactLevel.TRACE, + details={ + "context": f"Starting {func.__name__}", + "consequence": f"Will execute {func.__name__}", + "metrics": {"args": len(args), "kwargs": len(kwargs)} + } + ) + + try: + result = func(*args, **kwargs) + + # Log success + duration = time.time() - start_time + get_log_manager().log_operation( + component_name, f"{func.__name__}_complete", + level=LogLevel.INFO, + impact=ImpactLevel.LOW, + details={ + "context": f"Completed {func.__name__}", + "consequence": f"Result ready", + "metrics": {"duration_seconds": duration} + } + ) + + return result + + except Exception as e: + # Log error + get_log_manager().log_operation( + component_name, f"{func.__name__}_error", + level=LogLevel.ERROR, + impact=ImpactLevel.HIGH, + details={ + "context": f"Failed in {func.__name__}", + "consequence": "Operation failed", + "metrics": {"error": str(e)} + } + ) + raise + + return wrapper + + +def _log_method(method, component_name: str): + """Add logging to a method""" + @functools.wraps(method) + def wrapper(self, *args, **kwargs): + start_time = time.time() + + try: + result = method(self, *args, **kwargs) + + # Log successful method call + get_log_manager().log_operation( + component_name, f"{method.__name__}", + level=LogLevel.DEBUG, + impact=ImpactLevel.TRACE, + details={ + "metrics": {"duration": time.time() - start_time} + } + ) + + return result + + except Exception as e: + # Log method error + get_log_manager().log_operation( + component_name, f"{method.__name__}_error", + level=LogLevel.ERROR, + impact=ImpactLevel.HIGH, + details={ + "context": f"Method {method.__name__} failed", + "metrics": {"error": str(e)} + } + ) + raise + + return wrapper + + +def log_kleene_iterations(operation_name: str): + """Decorator specifically for Kleene fixed point iterations""" + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + get_log_manager().log_operation( + "KleeneEngine", f"{operation_name}_start", + 
level=LogLevel.INFO, + impact=ImpactLevel.MEDIUM, + details={ + "context": f"Starting fixed point iteration for {operation_name}", + "consequence": "Will iterate until convergence" + } + ) + + start_time = time.time() + result = func(*args, **kwargs) + + # Extract iteration info from result if available + iterations = getattr(result, 'iterations', 0) + converged = getattr(result, 'converged', True) + + get_log_manager().log_operation( + "KleeneEngine", f"{operation_name}_complete", + level=LogLevel.INFO, + impact=ImpactLevel.LOW if converged else ImpactLevel.HIGH, + details={ + "context": f"Fixed point iteration {'converged' if converged else 'diverged'}", + "consequence": f"Processed {iterations} iterations", + "metrics": { + "iterations": iterations, + "converged": converged, + "duration": time.time() - start_time + }, + "fixed_point": converged + } + ) + + return result + return wrapper + return decorator + + +def log_model_observation(model_id: str): + """Decorator for model observation functions""" + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + get_log_manager().log_operation( + "ModelObserver", f"observe_{model_id}", + level=LogLevel.INFO, + impact=ImpactLevel.MEDIUM, + details={ + "context": f"Starting observation of model {model_id}", + "consequence": "Will generate cryptographic proof" + } + ) + + result = func(*args, **kwargs) + + # Extract observation details + layers = getattr(result, 'layer_count', 0) + merkle = getattr(result, 'merkle_root', 'unknown') + + get_log_manager().log_operation( + "ModelObserver", f"observed_{model_id}", + level=LogLevel.INFO, + impact=ImpactLevel.LOW, + details={ + "context": f"Model observation complete", + "consequence": "Cryptographic proof generated", + "metrics": { + "model": model_id, + "layers": layers, + "merkle": merkle[:16] + "..." 
+ }, + "fixed_point": True + } + ) + + return result + return wrapper + return decorator + + +def log_data_processing(dataset_name: str): + """Decorator for data processing functions""" + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + get_log_manager().log_operation( + "DataProcessor", f"process_{dataset_name}", + level=LogLevel.INFO, + impact=ImpactLevel.MEDIUM, + details={ + "context": f"Processing dataset {dataset_name}", + "consequence": "Will extract and analyze data" + } + ) + + result = func(*args, **kwargs) + + # Extract processing stats + records = getattr(result, 'record_count', 0) + operations = getattr(result, 'operations', []) + + get_log_manager().log_operation( + "DataProcessor", f"processed_{dataset_name}", + level=LogLevel.INFO, + impact=ImpactLevel.LOW, + details={ + "context": f"Dataset processing complete", + "consequence": f"Processed {records} records", + "metrics": { + "dataset": dataset_name, + "records": records, + "operations": len(operations) + } + } + ) + + return result + return wrapper + return decorator + + +# Quick integration function +def integrate_cascade_logging(): + """One-call integration for entire CASCADE system""" + from ..system.observer import SystemObserver + from ..core.provenance import ProvenanceTracker + from data_unity import run_kleene_iteration + + # Register main components + manager = get_log_manager() + manager.register_component("SystemObserver", "System Observatory") + manager.register_component("ProvenanceTracker", "Model Observatory") + manager.register_component("DataUnity", "Data Unity") + manager.register_component("KleeneEngine", "NEXUS") + + print("✅ CASCADE logging integrated across all components") + return manager diff --git a/cascade/logging/interpretive_logger.py b/cascade/logging/interpretive_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..fedae16231a27b939586f5f0ccefa2a1888d65e1 --- /dev/null +++ 
# diff artifact: b/cascade/logging/interpretive_logger.py @@ -0,0 +1,276 @@
"""
CASCADE Interpretive Logger
Human-readable causation flow logging for operators and stakeholders.

Translates mathematical events into stories humans can understand and act upon.
"""

import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional
from datetime import datetime


class ImpactLevel(Enum):
    """Business impact levels"""
    CRITICAL = "🔴 CRITICAL"  # Service down, data loss
    HIGH = "🟠 HIGH"          # Degraded performance, user impact
    MEDIUM = "🟡 MEDIUM"      # Issues detected, monitoring needed
    LOW = "🟢 LOW"            # Informational, routine operations
    TRACE = "🔵 TRACE"        # Detailed flow, debugging


# ANSI palette keyed by ImpactLevel.name.
# BUGFIX: the original looked up ``self.impact.value`` (e.g. "🟢 LOW") in a
# dict keyed "LOW"/"HIGH"/..., so the lookup always fell through to the
# colourless default and no colour was ever applied. The dict also mixed
# (color, icon) tuples with plain strings ("RESET", "BOLD", ...), which would
# have broken tuple-unpacking had a key ever matched.
_LEVEL_STYLE = {
    "CRITICAL": ("\033[91m", "🔴"),  # Bright red
    "HIGH": ("\033[31m", "🟠"),      # Red
    "MEDIUM": ("\033[33m", "🟡"),    # Yellow
    "LOW": ("\033[32m", "🟢"),       # Green
    "TRACE": ("\033[90m", "🔵"),     # Gray
}
_RESET = "\033[0m"
_BOLD = "\033[1m"
_DIM = "\033[2m"
_CYAN = "\033[36m"
_MAGENTA = "\033[35m"


@dataclass
class InterpretiveEntry:
    """A human-readable system event"""
    timestamp: float = field(default_factory=time.time)
    impact: ImpactLevel = ImpactLevel.LOW
    system: str = ""       # High-level system name
    component: str = ""    # Specific component
    event: str = ""        # What happened
    context: str = ""      # Why it matters
    consequence: str = ""  # What happens next
    metrics: Dict[str, Any] = field(default_factory=dict)
    recommendation: Optional[str] = None

    def format_display(self) -> str:
        """Format for beautiful terminal output with colors."""
        time_str = datetime.fromtimestamp(self.timestamp).strftime("%H:%M:%S")

        # Key on the enum *name* so the palette lookup actually matches
        color, icon = _LEVEL_STYLE.get(self.impact.name, (_RESET, "⚪"))

        lines = [
            f"\n{color}{_BOLD}{icon} {self.impact.value} [{time_str}] {self.system}{_RESET}",
            f"├─ {_CYAN}Component:{_RESET} {self.component}",
            f"├─ {_MAGENTA}Event:{_RESET} {self.event}",
            f"├─ {_DIM}Context:{_RESET} {self.context}",
            f"├─ {_DIM}Consequence:{_RESET} {self.consequence}",
        ]

        if self.metrics:
            lines.append(f"├─ {_CYAN}Metrics:{_RESET} {self._format_metrics()}")

        if self.recommendation:
            lines.append(f"└─ {_BOLD}Recommendation:{_RESET} {self.recommendation}")
        else:
            lines.append(f"└─ {_DIM}Status: Monitoring{_RESET}")

        return "\n".join(lines)

    def _format_metrics(self) -> str:
        """Format metrics nicely"""
        return ", ".join(f"{k}={v}" for k, v in self.metrics.items())


class InterpretiveLogger:
    """Human-readable system storytelling"""

    def __init__(self, system_name: str):
        self.system = system_name
        self.entries: List[InterpretiveEntry] = []
        self.start_time = time.time()

    def log(self, impact: ImpactLevel, component: str, event: str,
            context: str, consequence: str,
            metrics: Optional[Dict] = None,
            recommendation: Optional[str] = None):
        """Record a system event and print it immediately."""

        entry = InterpretiveEntry(
            impact=impact,
            system=self.system,
            component=component,
            event=event,
            context=context,
            consequence=consequence,
            metrics=metrics or {},
            recommendation=recommendation
        )

        self.entries.append(entry)
        self._emit_to_container(entry)

    def _emit_to_container(self, entry: InterpretiveEntry):
        """Emit beautiful formatted log to container"""
        print(entry.format_display())

    # Convenience methods for common events
    def service_start(self, component: str, port: Optional[int] = None):
        """Service started successfully"""
        self.log(
            ImpactLevel.LOW,
            component,
            "Service started",
            "Component initialized and ready for requests",
            f"Accepting connections on port {port}" if port else "Ready for operations",
            metrics={"port": port} if port else {},
            recommendation="Monitor for healthy connections"
        )

    def service_error(self, component: str, error: str, impact: ImpactLevel = ImpactLevel.HIGH):
        """Service encountered error"""
        self.log(
            impact,
            component,
            "Service error",
            "Component failed to process request",
            "May affect system reliability",
            metrics={"error": error},
            recommendation="Check component logs and restart if needed"
        )

    def data_processing(self, dataset: str, records: int, operations: List[str]):
        """Data processing pipeline"""
        self.log(
            ImpactLevel.MEDIUM,
            "DataProcessor",
            f"Processing {dataset}",
            "Executing pipeline operations on dataset",
            f"Will process {records:,} records through {len(operations)} stages",
            metrics={
                "dataset": dataset,
                "records": records,
                "operations": len(operations)
            },
            recommendation="Monitor processing progress and error rates"
        )

    def model_loaded(self, model_id: str, size_gb: float, device: str):
        """AI model loaded into memory"""
        self.log(
            ImpactLevel.MEDIUM,
            "ModelLoader",
            f"Model {model_id} loaded",
            "Neural network loaded and ready for inference",
            f"Consuming {size_gb:.1f}GB VRAM on {device}",
            metrics={
                "model": model_id,
                "size_gb": size_gb,
                "device": device
            },
            recommendation="Monitor GPU memory usage during inference"
        )

    def security_event(self, component: str, event: str, details: str):
        """Security-related event"""
        self.log(
            ImpactLevel.CRITICAL,
            component,
            f"Security: {event}",
            "Security system detected potential threat",
            "Immediate investigation required",
            metrics={"details": details},
            recommendation="Review security logs and consider blocking source"
        )

    def performance_warning(self, component: str, metric: str, value: float, threshold: float):
        """Performance threshold exceeded"""
        self.log(
            ImpactLevel.HIGH,
            component,
            f"Performance warning: {metric}",
            "Component performance degraded",
            "May impact user experience if continues",
            metrics={metric: value, "threshold": threshold},
            recommendation=f"Optimize {metric} or scale resources"
        )

    def cascade_observation(self, model: str, layers: int, merkle_root: str):
        """CASCADE observed model execution"""
        self.log(
            # BUGFIX: the original used ImpactLevel.INFO, which is not a
            # member of the enum and raised AttributeError when called.
            ImpactLevel.LOW,
            "CASCADE",
            "Model observation complete",
            "Cryptographic proof generated for model execution",
            "Merkle root provides verifiable audit trail",
            metrics={
                "model": model,
                "layers": layers,
                "merkle": merkle_root[:16] + "..."
            },
            recommendation="Store attestation for permanent records"
        )

    def fixed_point_convergence(self, operation: str, iterations: int, entities: int):
        """Mathematical fixed point reached"""
        self.log(
            # BUGFIX: ImpactLevel.INFO does not exist (see cascade_observation)
            ImpactLevel.LOW,
            "KleeneEngine",
            "Fixed point convergence",
            f"{operation} completed after {iterations} iterations",
            f"Resolved relationships for {entities} entities",
            metrics={
                "operation": operation,
                "iterations": iterations,
                "entities": entities
            },
            recommendation="Review convergence quality metrics"
        )


# Global interpretive loggers
_interpretive_loggers: Dict[str, InterpretiveLogger] = {}


def get_interpretive_logger(system: str) -> InterpretiveLogger:
    """Get or create interpretive logger for system"""
    if system not in _interpretive_loggers:
        _interpretive_loggers[system] = InterpretiveLogger(system)
    return _interpretive_loggers[system]


# Bridge function to translate Kleene logs to interpretive
def translate_kleene_to_interpretive(kleene_entry, interpretive_logger):
    """Translate mathematical log to human story"""

    # Map Kleene levels to impact levels
    impact_map = {
        "CRITICAL": ImpactLevel.CRITICAL,
        "ERROR": ImpactLevel.HIGH,
        "WARNING": ImpactLevel.MEDIUM,
        "INFO": ImpactLevel.LOW,
        "DEBUG": ImpactLevel.TRACE,
        "TRACE": ImpactLevel.TRACE
    }

    # Create human-readable context
    if kleene_entry.fixed_point_reached:
        event = "Mathematical convergence achieved"
        context = f"Operation {kleene_entry.operation} reached stable state"
        consequence = "System can proceed with verified result"
    else:
        event = f"State transition in {kleene_entry.operation}"
        context = "Component processing through iterations"
        consequence = "Continuing toward fixed point"

    interpretive_logger.log(
        impact_map.get(kleene_entry.level.value, ImpactLevel.LOW),
        kleene_entry.component,
        event,
        context,
        consequence,
        metrics={
            "iterations": kleene_entry.iteration_count,
            "hash": kleene_entry.hash_value
        }
    )


# diff artifact: diff --git a/cascade/logging/kleene_logger.py (new file,
# index 0000000..711b2d3) --- /dev/null +++ b/cascade/logging/kleene_logger.py
"""
CASCADE Kleene Fixed Point Logger
Industry-standard mathematical logging for debugging and verification.

Each log entry is a fixed point observation - hashable, verifiable, complete.
"""

import hashlib
import json
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional
from contextlib import contextmanager


class LogLevel(Enum):
    """Mathematical significance levels"""
    CRITICAL = "CRITICAL"  # System-breaking fixed point failure
    ERROR = "ERROR"        # Fixed point not reached
    WARNING = "WARNING"    # Unexpected state transition
    INFO = "INFO"          # Fixed point achieved
    DEBUG = "DEBUG"        # State transition details
    TRACE = "TRACE"        # Every computation step


@dataclass
class KleeneLogEntry:
    """A single fixed point observation"""
    timestamp: float = field(default_factory=time.time)
    level: LogLevel = LogLevel.INFO
    component: str = ""
    operation: str = ""
    state_before: Optional[Dict] = None
    state_after: Optional[Dict] = None
    fixed_point_reached: bool = False
    iteration_count: int = 0
    hash_value: str = field(init=False)

    def __post_init__(self):
        # Create content hash for verifiability
        content = {
            "timestamp": self.timestamp,
            "component": self.component,
            "operation": self.operation,
            "state_before": self.state_before,
            # (remainder of this definition continues in the following chunk;
            # reproduced here so the class stays well-formed)
            "state_after": self.state_after,
            "iteration": self.iteration_count
        }
        self.hash_value = hashlib.sha256(
            json.dumps(content, sort_keys=True).encode()
        ).hexdigest()[:16]
"state_after": self.state_after, + "iteration": self.iteration_count + } + self.hash_value = hashlib.sha256( + json.dumps(content, sort_keys=True).encode() + ).hexdigest()[:16] + + def to_dict(self) -> Dict[str, Any]: + return { + "ts": self.timestamp, + "lvl": self.level.value, + "comp": self.component, + "op": self.operation, + "before": self.state_before, + "after": self.state_after, + "fixed": self.fixed_point_reached, + "iter": self.iteration_count, + "hash": self.hash_value + } + + +class KleeneLogger: + """Mathematical logging for fixed point systems""" + + def __init__(self, component_name: str): + self.component = component_name + self.entries: List[KleeneLogEntry] = [] + self.session_start = time.time() + self.operation_count = 0 + + def log(self, level: LogLevel, operation: str, + state_before: Optional[Dict] = None, + state_after: Optional[Dict] = None, + fixed_point: bool = False, + iterations: int = 0): + """Record a state transition""" + + entry = KleeneLogEntry( + level=level, + component=self.component, + operation=operation, + state_before=state_before, + state_after=state_after, + fixed_point_reached=fixed_point, + iteration_count=iterations + ) + + self.entries.append(entry) + self._emit_to_container(entry) + + def _emit_to_container(self, entry: KleeneLogEntry): + """Emit structured log to container with colors""" + # ANSI color codes + colors = { + "CRITICAL": "\033[91m", # Bright red + "ERROR": "\033[31m", # Red + "WARNING": "\033[33m", # Yellow + "INFO": "\033[32m", # Green + "DEBUG": "\033[36m", # Cyan + "TRACE": "\033[90m", # Gray + "RESET": "\033[0m", # Reset + "BOLD": "\033[1m", # Bold + "DIM": "\033[2m", # Dim + } + + color = colors.get(entry.level.value, colors["RESET"]) + reset = colors["RESET"] + dim = colors["DIM"] + + # Format with colors + print(f"{color}[KLEENE]{reset} {color}{entry.level.value:8}{reset} | " + f"{dim}{entry.component:20}{reset} | " + f"{entry.operation:30} | " + f"Iter:{entry.iteration_count:3} | " + f"Fixed:{'Y' 
if entry.fixed_point_reached else 'N':1} | " + f"{dim}Hash:{entry.hash_value}{reset}") + + @contextmanager + def observe_operation(self, operation: str, initial_state: Dict): + """Context manager for observing operations""" + self.operation_count += 1 + iterations = 0 + + try: + self.log(LogLevel.DEBUG, f"{operation}_start", + state_before=initial_state) + + # Yield control back to operation + yield self + + # Operation completed successfully + self.log(LogLevel.INFO, f"{operation}_complete", + fixed_point=True, iterations=iterations) + + except Exception as e: + self.log(LogLevel.ERROR, f"{operation}_failed", + state_after={"error": str(e)}) + raise + + def fixed_point(self, operation: str, final_state: Dict, iterations: int): + """Log successful fixed point convergence""" + self.log(LogLevel.INFO, f"{operation}_fixed_point", + state_after=final_state, + fixed_point=True, + iterations=iterations) + + def divergence(self, operation: str, state: Dict): + """Log when system diverges (no fixed point)""" + self.log(LogLevel.WARNING, f"{operation}_divergence", + state_after=state, + fixed_point=False) + + def critical_failure(self, operation: str, error_state: Dict): + """Log critical system failure""" + self.log(LogLevel.CRITICAL, f"{operation}_critical", + state_after=error_state, + fixed_point=False) + + def get_session_hash(self) -> str: + """Get hash of entire session for verification""" + content = { + "component": self.component, + "start": self.session_start, + "operations": self.operation_count, + "entries": [e.hash_value for e in self.entries] + } + return hashlib.sha256(json.dumps(content).encode()).hexdigest() + + +# Global loggers for major components +_loggers: Dict[str, KleeneLogger] = {} + + +def get_kleene_logger(component: str) -> KleeneLogger: + """Get or create logger for component""" + if component not in _loggers: + _loggers[component] = KleeneLogger(component) + return _loggers[component] + + +# Convenience decorators +def 
log_fixed_point(operation: str): + """Decorator to automatically log fixed point operations""" + def decorator(func): + def wrapper(*args, **kwargs): + logger = get_kleene_logger(func.__module__) + start_state = {"args": str(args), "kwargs": str(kwargs)} + + try: + result = func(*args, **kwargs) + logger.fixed_point(operation, {"result": str(result)}, 1) + return result + except Exception as e: + logger.critical_failure(operation, {"error": str(e)}) + raise + return wrapper + return decorator + + +def log_iterations(operation: str): + """Decorator for operations that iterate to fixed points""" + def decorator(func): + def wrapper(*args, **kwargs): + logger = get_kleene_logger(func.__module__) + + # Simulate iteration counting (real implementation would track) + result = func(*args, **kwargs) + iterations = getattr(result, 'iterations', 1) + + logger.fixed_point(operation, {"converged": True}, iterations) + return result + return wrapper + return decorator diff --git a/cascade/logging/log_manager.py b/cascade/logging/log_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..4dc01b907f63c89b414434020bd4959f236f1cdd --- /dev/null +++ b/cascade/logging/log_manager.py @@ -0,0 +1,266 @@ +""" +CASCADE Log Manager +Orchestrates the tsunami of data into ordered causation troops. + +Manages log levels, routing, and the beautiful display of system truth. 
+""" + +import os +import sys +import time +from typing import Dict, List, Optional, Any +from dataclasses import dataclass +from enum import Enum + +from .kleene_logger import KleeneLogger, LogLevel +from .interpretive_logger import InterpretiveLogger, ImpactLevel + + +class LogMode(Enum): + """The two modes of logging excellence""" + KLEENE = "kleene" # Mathematical precision + INTERPRETIVE = "interpretive" # Human stories + DUAL = "dual" # Both simultaneously + + +@dataclass +class LogConfig: + """Configuration for logging behavior""" + mode: LogMode = LogMode.DUAL + min_level_kleene: LogLevel = LogLevel.INFO + min_level_interpretive: ImpactLevel = ImpactLevel.LOW + show_metrics: bool = True + show_timestamps: bool = True + color_output: bool = True + file_output: bool = False + max_file_size_mb: int = 100 + + +class CascadeLogManager: + """The conductor of your causation orchestra""" + + def __init__(self, config: Optional[LogConfig] = None): + self.config = config or LogConfig() + self.kleene_loggers: Dict[str, KleeneLogger] = {} + self.interpretive_loggers: Dict[str, InterpretiveLogger] = {} + self.start_time = time.time() + self.operation_count = 0 + + # Initialize display + self._setup_display() + + def _setup_display(self): + """Setup beautiful terminal output""" + if self.config.color_output: + # Enable ANSI colors + sys.stdout.reconfigure(encoding='utf-8') + + # Print header + self._print_header() + + def _print_header(self): + """Print beautiful cascade header with colors""" + # ANSI color codes + colors = { + "WAVE": "\033[94m", # Bright blue + "BRIDGE": "\033[96m", # Cyan + "BOLD": "\033[1m", + "DIM": "\033[2m", + "RESET": "\033[0m", + "GREEN": "\033[32m", + "YELLOW": "\033[33m", + } + + wave = colors["WAVE"] + bridge = colors["BRIDGE"] + bold = colors["BOLD"] + dim = colors["DIM"] + reset = colors["RESET"] + green = colors["GREEN"] + yellow = colors["YELLOW"] + + print(f"\n{bold}{'='*80}{reset}") + print(f"{wave}🌊{reset} {bold}CASCADE // TRUTH 
INFRASTRUCTURE{reset} {bridge}🧠{reset}") + print(f"{bold}{'='*80}{reset}") + print(f"{bold}Mode:{reset} {green}{self.config.mode.value.upper()}{reset}") + print(f"{bold}Started:{reset} {dim}{time.strftime('%Y-%m-%d %H:%M:%S')}{reset}") + print(f"{bold}{'='*80}{reset}\n") + + def register_component(self, component: str, system: str = "CASCADE"): + """Register a component for logging""" + if self.config.mode in [LogMode.KLEENE, LogMode.DUAL]: + kleene = KleeneLogger(component) + self.kleene_loggers[component] = kleene + + if self.config.mode in [LogMode.INTERPRETIVE, LogMode.DUAL]: + interpretive = InterpretiveLogger(system) + self.interpretive_loggers[system] = interpretive + + def log_operation(self, component: str, operation: str, + level: LogLevel = LogLevel.INFO, + impact: ImpactLevel = ImpactLevel.LOW, + details: Optional[Dict] = None): + """Log an operation across all active loggers""" + self.operation_count += 1 + + if self.config.mode in [LogMode.KLEENE, LogMode.DUAL]: + if component in self.kleene_loggers: + self.kleene_loggers[component].log( + level, operation, + state_before=details.get("before") if details else None, + state_after=details.get("after") if details else None, + fixed_point=details.get("fixed_point", False) if details else False, + iterations=details.get("iterations", 0) if details else 0 + ) + + if self.config.mode in [LogMode.INTERPRETIVE, LogMode.DUAL]: + # Find interpretive logger for component + system = details.get("system", "CASCADE") if details else "CASCADE" + if system in self.interpretive_loggers: + self.interpretive_loggers[system].log( + impact, component, operation, + context=details.get("context", "") if details else "", + consequence=details.get("consequence", "") if details else "", + metrics=details.get("metrics", {}) if details else {}, + recommendation=details.get("recommendation") if details else None + ) + + def get_session_stats(self) -> Dict[str, Any]: + """Get beautiful session statistics""" + total_kleene = 
sum(len(logger.entries) for logger in self.kleene_loggers.values()) + total_interpretive = sum(len(logger.entries) for logger in self.interpretive_loggers.values()) + + return { + "uptime_seconds": time.time() - self.start_time, + "operations": self.operation_count, + "kleene_entries": total_kleene, + "interpretive_entries": total_interpretive, + "active_components": len(self.kleene_loggers), + "active_systems": len(self.interpretive_loggers) + } + + def print_summary(self): + """Print beautiful session summary with colors""" + stats = self.get_session_stats() + + # ANSI color codes + colors = { + "BOLD": "\033[1m", + "DIM": "\033[2m", + "RESET": "\033[0m", + "CYAN": "\033[36m", + "GREEN": "\033[32m", + "YELLOW": "\033[33m", + "BLUE": "\033[34m", + "MAGENTA": "\033[35m", + } + + bold = colors["BOLD"] + dim = colors["DIM"] + reset = colors["RESET"] + cyan = colors["CYAN"] + green = colors["GREEN"] + yellow = colors["YELLOW"] + blue = colors["BLUE"] + magenta = colors["MAGENTA"] + + print(f"\n{bold}{'='*80}{reset}") + print(f"{cyan}📊 CASCADE SESSION SUMMARY{reset}") + print(f"{bold}{'='*80}{reset}") + print(f"{bold}Uptime:{reset} {stats['uptime_seconds']:.1f} seconds") + print(f"{bold}Operations:{reset} {green}{stats['operations']:,}{reset}") + print(f"{bold}Kleene Entries:{reset} {yellow}{stats['kleene_entries']:,}{reset}") + print(f"{bold}Interpretive Entries:{reset} {blue}{stats['interpretive_entries']:,}{reset}") + print(f"{bold}Active Components:{reset} {magenta}{stats['active_components']}{reset}") + print(f"{bold}Active Systems:{reset} {magenta}{stats['active_systems']}{reset}") + + if stats['kleene_entries'] > 0: + # Get session hash from first logger + first_logger = next(iter(self.kleene_loggers.values())) + print(f"{bold}Session Hash:{reset} {dim}{first_logger.get_session_hash()}{reset}") + + print(f"{bold}{'='*80}{reset}") + + def set_mode(self, mode: LogMode): + """Switch logging mode dynamically""" + old_mode = self.config.mode + self.config.mode = mode 
+ + print(f"\n🔄 Logging mode changed: {old_mode.value} → {mode.value}") + + def enable_file_logging(self, filepath: str): + """Enable logging to file""" + self.config.file_output = True + # TODO: Implement file logging + print(f"📁 File logging enabled: {filepath}") + + +# Global log manager instance +_log_manager: Optional[CascadeLogManager] = None + + +def init_logging(config: Optional[LogConfig] = None) -> CascadeLogManager: + """Initialize the global CASCADE logging system""" + global _log_manager + _log_manager = CascadeLogManager(config) + return _log_manager + + +def get_log_manager() -> CascadeLogManager: + """Get the global log manager""" + global _log_manager + if _log_manager is None: + _log_manager = CascadeLogManager() + return _log_manager + + +def log(component: str, operation: str, context: str = "", consequence: str = "", + metrics: Dict[str, Any] = None, impact: str = "LOW", **kwargs): + """Quick log operation - convenience function""" + manager = get_log_manager() + manager.log_operation(component, operation, + details={ + "context": context, + "consequence": consequence, + "metrics": metrics or {}, + "impact": impact, + **kwargs + }) + + +def log_fixed_point(component: str, operation: str, iterations: int, **kwargs): + """Log successful fixed point""" + log(component, operation, + level=LogLevel.INFO, + impact=ImpactLevel.LOW, + details={ + "fixed_point": True, + "iterations": iterations, + **kwargs + }) + + +def log_error(component: str, operation: str, error: str, **kwargs): + """Log error condition""" + log(component, f"{operation}_error", + level=LogLevel.ERROR, + impact=ImpactLevel.HIGH, + details={ + "context": f"Operation failed: {error}", + "consequence": "System may be degraded", + "metrics": {"error": error}, + **kwargs + }) + + +def log_performance(component: str, metric: str, value: float, threshold: float): + """Log performance warning""" + log(component, f"performance_{metric}", + level=LogLevel.WARNING, + impact=ImpactLevel.MEDIUM, 
+ details={ + "context": f"Performance metric {metric} exceeded threshold", + "consequence": "May impact system performance", + "metrics": {metric: value, "threshold": threshold}, + "recommendation": f"Optimize {metric} or scale resources" + }) diff --git a/cascade/observation.py b/cascade/observation.py new file mode 100644 index 0000000000000000000000000000000000000000..c9699f887cf7d43e3444f40549433c2c64f10727 --- /dev/null +++ b/cascade/observation.py @@ -0,0 +1,397 @@ +""" +CASCADE Observation Manager + +Connects the detective tabs (Observatory, Unity, System) to the lattice. + +Flow: +1. User runs observation through any tab +2. Observation creates provenance chain +3. Chain links to model identity (for model obs) or genesis (for data/system) +4. Chain saved to lattice +5. Optionally pinned to IPFS + +This is the integration layer between UI and lattice. +""" + +import json +import time +from pathlib import Path +from typing import Optional, Dict, Any, List +from dataclasses import dataclass, field + +from cascade.core.provenance import ProvenanceChain +from cascade.identity import ModelRegistry, ModelIdentity, create_model_identity +from cascade.genesis import get_genesis_root, link_to_genesis + + +@dataclass +class Observation: + """ + A single observation record in the lattice. 
+ + Can be: + - Model observation (inference through Observatory) + - Data observation (entity resolution through Unity) + - System observation (log analysis through System tab) + """ + observation_id: str + observation_type: str # "model", "data", "system" + + # What was observed + source_id: str # Model ID, dataset ID, or log source + source_root: str # Merkle root of source identity + + # The observation data + chain: ProvenanceChain + merkle_root: str + + # Metadata + user_hash: Optional[str] = None # Anonymous user identifier + created_at: float = field(default_factory=time.time) + + # IPFS + cid: Optional[str] = None + + +class ObservationManager: + """ + Manages observations across all CASCADE tabs. + + Responsibilities: + - Link observations to model identities or genesis + - Save observations to lattice + - Track observation history + - Provide stats for lattice gateway + """ + + def __init__(self, lattice_dir: Path = None): + self.lattice_dir = lattice_dir or Path(__file__).parent.parent / "lattice" + self.observations_dir = self.lattice_dir / "observations" + self.observations_dir.mkdir(parents=True, exist_ok=True) + + # Model registry for linking model observations + self.model_registry = ModelRegistry(self.lattice_dir) + + # Genesis root + self.genesis_root = get_genesis_root() + + # In-memory observation index + self._observations: Dict[str, Observation] = {} + self._load_index() + + def _load_index(self): + """Load observation index from disk.""" + index_file = self.lattice_dir / "observation_index.json" + if index_file.exists(): + try: + index = json.loads(index_file.read_text()) + # Just load metadata, not full chains + for obs_id, meta in index.items(): + self._observations[obs_id] = meta + except: + pass + + def _save_index(self): + """Save observation index to disk.""" + index_file = self.lattice_dir / "observation_index.json" + # Save lightweight index + index = {} + for obs_id, obs in self._observations.items(): + if isinstance(obs, 
Observation): + index[obs_id] = { + "observation_id": obs.observation_id, + "observation_type": obs.observation_type, + "source_id": obs.source_id, + "source_root": obs.source_root, + "merkle_root": obs.merkle_root, + "created_at": obs.created_at, + "cid": obs.cid, + } + else: + index[obs_id] = obs + index_file.write_text(json.dumps(index, indent=2)) + + def observe_model( + self, + model_id: str, + chain: ProvenanceChain, + user_hash: Optional[str] = None, + **model_kwargs, + ) -> Observation: + """ + Record a model observation. + + Args: + model_id: HuggingFace model ID or local path + chain: Provenance chain from Observatory + user_hash: Anonymous user identifier + **model_kwargs: Additional model info (parameters, etc.) + + Returns: + Observation linked to model identity + """ + # Get or create model identity + identity = self.model_registry.get_or_create(model_id, **model_kwargs) + + # Link chain to model identity + if not chain.external_roots: + chain.external_roots = [] + if identity.merkle_root not in chain.external_roots: + chain.external_roots.append(identity.merkle_root) + + # Finalize chain if not already + if not chain.finalized: + chain.finalize() + + # Create observation record + obs_id = f"model_{chain.merkle_root}" + observation = Observation( + observation_id=obs_id, + observation_type="model", + source_id=model_id, + source_root=identity.merkle_root, + chain=chain, + merkle_root=chain.merkle_root, + user_hash=user_hash, + ) + + # Save chain to disk + self._save_observation(observation) + + return observation + + def observe_data( + self, + dataset_a: str, + dataset_b: str, + chain: ProvenanceChain, + user_hash: Optional[str] = None, + ) -> Observation: + """ + Record a data unity observation. + + Links directly to genesis (data doesn't have model identity). 
+ """ + # Link to genesis + if not chain.external_roots: + chain.external_roots = [] + if self.genesis_root not in chain.external_roots: + chain.external_roots.append(self.genesis_root) + + if not chain.finalized: + chain.finalize() + + # Create observation + source_id = f"{dataset_a}::{dataset_b}" + obs_id = f"data_{chain.merkle_root}" + + observation = Observation( + observation_id=obs_id, + observation_type="data", + source_id=source_id, + source_root=self.genesis_root, + chain=chain, + merkle_root=chain.merkle_root, + user_hash=user_hash, + ) + + self._save_observation(observation) + return observation + + def observe_system( + self, + source_name: str, + chain: ProvenanceChain, + user_hash: Optional[str] = None, + ) -> Observation: + """ + Record a system log observation. + + Links directly to genesis. + """ + # Link to genesis + if not chain.external_roots: + chain.external_roots = [] + if self.genesis_root not in chain.external_roots: + chain.external_roots.append(self.genesis_root) + + if not chain.finalized: + chain.finalize() + + obs_id = f"system_{chain.merkle_root}" + + observation = Observation( + observation_id=obs_id, + observation_type="system", + source_id=source_name, + source_root=self.genesis_root, + chain=chain, + merkle_root=chain.merkle_root, + user_hash=user_hash, + ) + + self._save_observation(observation) + return observation + + def _save_observation(self, observation: Observation): + """Save observation to disk.""" + # Save to index + self._observations[observation.observation_id] = observation + self._save_index() + + # Save full chain + chain_file = self.observations_dir / f"{observation.merkle_root}.json" + chain_data = { + "observation_id": observation.observation_id, + "observation_type": observation.observation_type, + "source_id": observation.source_id, + "source_root": observation.source_root, + "user_hash": observation.user_hash, + "created_at": observation.created_at, + "cid": observation.cid, + "chain": 
observation.chain.to_dict() if hasattr(observation.chain, 'to_dict') else str(observation.chain), + } + chain_file.write_text(json.dumps(chain_data, indent=2, default=str)) + + def pin_observation(self, observation: Observation) -> Optional[str]: + """ + Pin observation to IPFS. + + Returns CID if successful. + """ + try: + from cascade.ipld import chain_to_cid, encode_to_dag_cbor + from cascade.web3_pin import pin_file + + # Convert to IPLD format + chain_data = observation.chain.to_dict() if hasattr(observation.chain, 'to_dict') else {} + cbor_data = encode_to_dag_cbor(chain_data) + + # Save CBOR + cbor_file = self.observations_dir / f"{observation.merkle_root}.cbor" + cbor_file.write_bytes(cbor_data) + + # Compute CID + cid = chain_to_cid(chain_data) + observation.cid = cid + + # Update index + self._save_observation(observation) + + return cid + except Exception as e: + print(f"Failed to pin observation: {e}") + return None + + def get_observation(self, merkle_root: str) -> Optional[Observation]: + """Get observation by merkle root.""" + for obs in self._observations.values(): + if isinstance(obs, Observation) and obs.merkle_root == merkle_root: + return obs + elif isinstance(obs, dict) and obs.get("merkle_root") == merkle_root: + return obs + return None + + def list_observations( + self, + observation_type: Optional[str] = None, + source_id: Optional[str] = None, + limit: int = 100, + ) -> List[Dict[str, Any]]: + """List observations with optional filters.""" + results = [] + + for obs in self._observations.values(): + if isinstance(obs, Observation): + obs_dict = { + "observation_id": obs.observation_id, + "observation_type": obs.observation_type, + "source_id": obs.source_id, + "merkle_root": obs.merkle_root, + "created_at": obs.created_at, + "cid": obs.cid, + } + else: + obs_dict = obs + + # Apply filters + if observation_type and obs_dict.get("observation_type") != observation_type: + continue + if source_id and source_id not in obs_dict.get("source_id", 
""): + continue + + results.append(obs_dict) + + # Sort by time, newest first + results.sort(key=lambda x: x.get("created_at", 0), reverse=True) + + return results[:limit] + + def get_stats(self) -> Dict[str, Any]: + """Get lattice statistics.""" + obs_list = list(self._observations.values()) + + model_obs = [o for o in obs_list if (isinstance(o, Observation) and o.observation_type == "model") or (isinstance(o, dict) and o.get("observation_type") == "model")] + data_obs = [o for o in obs_list if (isinstance(o, Observation) and o.observation_type == "data") or (isinstance(o, dict) and o.get("observation_type") == "data")] + system_obs = [o for o in obs_list if (isinstance(o, Observation) and o.observation_type == "system") or (isinstance(o, dict) and o.get("observation_type") == "system")] + + # Count unique models + model_ids = set() + for o in model_obs: + if isinstance(o, Observation): + model_ids.add(o.source_id) + elif isinstance(o, dict): + model_ids.add(o.get("source_id", "")) + + return { + "total_observations": len(obs_list), + "model_observations": len(model_obs), + "data_observations": len(data_obs), + "system_observations": len(system_obs), + "unique_models": len(model_ids), + "registered_models": len(self.model_registry.list_all()), + "genesis_root": self.genesis_root, + } + + def get_model_observations(self, model_id: str) -> List[Dict[str, Any]]: + """Get all observations for a specific model.""" + return self.list_observations(observation_type="model", source_id=model_id) + + +# ============================================================================= +# SINGLETON INSTANCE +# ============================================================================= + +_manager: Optional[ObservationManager] = None + +def get_observation_manager() -> ObservationManager: + """Get singleton observation manager.""" + global _manager + if _manager is None: + _manager = ObservationManager() + return _manager + + +# 
============================================================================= +# CLI +# ============================================================================= + +if __name__ == "__main__": + print("=== CASCADE Observation Manager ===\n") + + manager = get_observation_manager() + + # Show stats + stats = manager.get_stats() + print(f"Genesis: {stats['genesis_root']}") + print(f"Registered Models: {stats['registered_models']}") + print(f"Total Observations: {stats['total_observations']}") + print(f" - Model: {stats['model_observations']}") + print(f" - Data: {stats['data_observations']}") + print(f" - System: {stats['system_observations']}") + print(f"Unique Models Observed: {stats['unique_models']}") + + # List recent observations + print("\nRecent Observations:") + for obs in manager.list_observations(limit=5): + print(f" [{obs['observation_type']}] {obs['source_id'][:40]}... → {obs['merkle_root']}") diff --git a/cascade/observe.py b/cascade/observe.py new file mode 100644 index 0000000000000000000000000000000000000000..3a42f0ee94a0ee7ef2cb9b58161b426f2633f861 --- /dev/null +++ b/cascade/observe.py @@ -0,0 +1,231 @@ +""" +Cascade Observer CLI. + +Wraps a target process and observes its output. + +Usage: + python -m cascade.observe --cmd "python path/to/train.py --args..." + +This module: +1. Wraps the target process +2. Pipes stdout/stderr -> Cascade Adapter +3. Writes events to tape file (JSONL) and human log (Markdown) +4. Emits events to event_queue for external consumers + +For visualization, point a consumer at the event_queue or load the tape file +into your preferred visualization tool. 
+""" + +import sys +import subprocess +import argparse +import time +import json +import shlex +import shutil +from pathlib import Path +from queue import Queue + +# Ensure package root is in path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from cascade import Monitor + +# Shared event queue for external consumers (e.g., custom UIs) +event_queue: Queue = Queue() + + +def scoop_the_poop(log_dir: Path): + """ + Baggies system - archive old logs on startup. + Keeps the logs folder clean. Old sessions go to baggies/. + """ + baggies_dir = log_dir / "baggies" + baggies_dir.mkdir(parents=True, exist_ok=True) + + # Find all old log files (not the current session) + tape_files = list(log_dir.glob("cascade_tape_*.jsonl")) + log_files = list(log_dir.glob("cascade_log_*.md")) + + moved_count = 0 + for f in tape_files + log_files: + if f.parent == log_dir: # Only files in root logs/, not baggies/ + dest = baggies_dir / f.name + try: + shutil.move(str(f), str(dest)) + moved_count += 1 + except Exception as e: + print(f"[CASCADE] Could not archive {f.name}: {e}") + + if moved_count > 0: + print(f"[CASCADE] 🧹 Scooped {moved_count} old logs → baggies/") + + +def main(): + parser = argparse.ArgumentParser( + prog="cascade", + description="🌊 Cascade - Real-Time Neural Network Observability", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + cascade --cmd "python train.py" + cascade --cmd "python train.py --epochs=10" + cascade --cmd "python train.py" --cwd /path/to/project + +Events are written to tape files in the log directory. 
+ """ + ) + + # Support both "cascade --cmd" and "cascade observe --cmd" + subparsers = parser.add_subparsers(dest="command") + observe_parser = subparsers.add_parser("observe", help="Observe a training process") + + # Add args to both main parser and observe subparser + for p in [parser, observe_parser]: + p.add_argument("--cmd", required=True, help="Command to run the target process") + p.add_argument("--cwd", default=None, help="Working directory for the target (absolute path)") + p.add_argument("--log-dir", default="./logs", help="Directory for session tapes") + p.add_argument("--quiet", "-q", action="store_true", help="Suppress console output") + + args = parser.parse_args() + + # Resolve working directory to absolute + if args.cwd: + work_dir = Path(args.cwd).resolve() + else: + work_dir = Path.cwd() + + # 0. Setup Session Tape (The Excrement/Product) + log_dir = Path(args.log_dir).resolve() + log_dir.mkdir(parents=True, exist_ok=True) + + # 🧹 Scoop old logs before starting new session + scoop_the_poop(log_dir) + + session_id = int(time.time()) + + # 1. Machine Tape (JSONL) + tape_path = log_dir / f"cascade_tape_{session_id}.jsonl" + tape_file = open(tape_path, "a", encoding="utf-8") + + # 2. 
Human Log (Markdown) + human_path = log_dir / f"cascade_log_{session_id}.md" + human_file = open(human_path, "a", encoding="utf-8") + + # Header for Human Log + human_file.write(f"# CASCADE MISSION LOG // SESSION {session_id}\n") + human_file.write(f"**Target:** `{args.cmd}`\n") + human_file.write(f"**Date:** {time.strftime('%Y-%m-%d %H:%M:%S')}\n") + human_file.write("---\n\n") + human_file.flush() + + print("="*60) + print("CASCADE // OBSERVER") + print(f"Target: {args.cmd}") + print(f"Tape: {tape_path.absolute()}") + print(f"Log: {human_path.absolute()}") + print("="*60) + + # Init Monitor + monitor = Monitor("symbiont_alpha") + + def write_human_entry(evt): + """Convert an event into an articulate log entry.""" + t_str = time.strftime('%H:%M:%S', time.localtime(evt.timestamp)) + + # Narrative construction based on event type + if evt.event_type == "error": + icon = "🔴" + narrative = f"CRITICAL FAILURE in **{evt.component}**." + elif evt.event_type == "warning": + icon = "⚠️" + narrative = f"Warning signal detected from **{evt.component}**." + elif evt.event_type == "state_change": + icon = "🔄" + narrative = f"State transition observed in **{evt.component}**." + elif "loss" in str(evt.data): + icon = "📉" + narrative = f"Optimization step completed by **{evt.component}**." + else: + icon = "ℹ️" + narrative = f"Standard event recorded from **{evt.component}**." 
+
+        # Write readable block
+        human_file.write(f"### {icon} {t_str} // {evt.event_type.upper()}\n")
+        human_file.write(f"{narrative}\n")
+        if evt.data:
+            # Format data as a clean list or quote
+            human_file.write("```yaml\n")
+            for k, v in evt.data.items():
+                human_file.write(f"{k}: {v}\n")
+            human_file.write("```\n")
+        human_file.write("\n")
+        human_file.flush()
+
+    # Launch Target
+    try:
+        # Split command for subprocess if it's a string
+        cmd_parts = shlex.split(args.cmd)
+
+        process = subprocess.Popen(
+            cmd_parts,
+            cwd=work_dir,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+            text=True,
+            bufsize=1
+        )
+
+        print(f"[CASCADE] Linked to target. Recording to tape & log...")
+
+        for line in process.stdout:
+            line = line.strip()
+            if not line: continue
+
+            # Feed Adapter
+            event = monitor.observe(line)
+
+            # Build payload with FULL wealth: metrics + triage + raw
+            metrics_summary = monitor.metrics.summary()
+            triage_status = monitor.metrics.triage()
+
+            payload = {
+                "event": {
+                    "event_id": event.event_id,
+                    "timestamp": event.timestamp,
+                    "component": event.component,
+                    "event_type": event.event_type,
+                    "data": event.data,
+                    "raw": line, # Include original line for drill-down
+                },
+                "metrics": metrics_summary,
+                "triage": triage_status,
+            }
+
+            # Emit to queue for external consumers
+            event_queue.put(payload)
+
+            # Write to Tape (Machine)
+            tape_file.write(json.dumps(payload) + "\n")
+            tape_file.flush()
+
+            # Write to Log (Human)
+            write_human_entry(event)
+
+            # Echo to console (unless quiet)
+            if not args.quiet:
+                print(f"[RAW] {line}")
+
+    except KeyboardInterrupt:
+        print("\n[CASCADE] Detaching...")
+    except Exception as e:
+        print(f"[CASCADE] Error: {e}")
+    finally:
+        tape_file.close()
+        human_file.close()
+        if 'process' in locals() and process.poll() is None:
+            process.terminate()
+    print(f"[CASCADE] Session complete. 
Tape: {tape_path}") + +if __name__ == "__main__": + main() diff --git a/cascade/patches/__init__.py b/cascade/patches/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c90eced6efa0a6931f4f458f442935d52c72e0e1 --- /dev/null +++ b/cascade/patches/__init__.py @@ -0,0 +1,19 @@ +""" +CASCADE Patches - Auto-intercept LLM provider libraries + +Each patch module wraps a provider's API to automatically emit receipts. +""" + +from .openai_patch import patch_openai +from .anthropic_patch import patch_anthropic +from .huggingface_patch import patch_huggingface +from .ollama_patch import patch_ollama +from .litellm_patch import patch_litellm + +__all__ = [ + "patch_openai", + "patch_anthropic", + "patch_huggingface", + "patch_ollama", + "patch_litellm", +] diff --git a/cascade/patches/anthropic_patch.py b/cascade/patches/anthropic_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..306f2bc5af49c28045720c5b110091b25dd6d243 --- /dev/null +++ b/cascade/patches/anthropic_patch.py @@ -0,0 +1,124 @@ +""" +Anthropic API Patch - Intercepts anthropic.messages.create() etc. +""" + +import functools +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ..sdk import CascadeSDK + + +def patch_anthropic(sdk: "CascadeSDK"): + """ + Patch the Anthropic library to emit receipts on every call. + + Intercepts: + - anthropic.Anthropic().messages.create() + - anthropic.AsyncAnthropic().messages.create() + """ + import anthropic + + def extract_model_id(kwargs, response): + """Extract canonical model identifier.""" + model = kwargs.get("model", "unknown") + # claude-3-opus-20240229, claude-3-sonnet-20240229, etc. 
+ return f"anthropic/{model}" + + def extract_input(kwargs): + """Extract input from kwargs.""" + messages = kwargs.get("messages", []) + if messages: + # Get the last user message + user_msgs = [m for m in messages if m.get("role") == "user"] + if user_msgs: + content = user_msgs[-1].get("content", "") + # Content could be string or list of content blocks + if isinstance(content, list): + texts = [c.get("text", "") for c in content if c.get("type") == "text"] + return " ".join(texts) + return content + return "" + + def extract_output(response): + """Extract output from response.""" + try: + if hasattr(response, "content") and response.content: + # Content is a list of content blocks + texts = [] + for block in response.content: + if hasattr(block, "text"): + texts.append(block.text) + return " ".join(texts) + return str(response) + except: + return str(response) + + def extract_metrics(response, kwargs): + """Extract usage metrics.""" + metrics = {} + try: + if hasattr(response, "usage") and response.usage: + metrics["input_tokens"] = response.usage.input_tokens + metrics["output_tokens"] = response.usage.output_tokens + + # Add request params + metrics["max_tokens"] = kwargs.get("max_tokens") + metrics["temperature"] = kwargs.get("temperature", 1.0) + + # Stop reason + if hasattr(response, "stop_reason"): + metrics["stop_reason"] = response.stop_reason + except: + pass + return metrics + + def wrap_messages_create(original): + @functools.wraps(original) + def wrapper(*args, **kwargs): + # Call original + response = original(*args, **kwargs) + + # Emit receipt (non-blocking) + try: + sdk.observe( + model_id=extract_model_id(kwargs, response), + input_data=extract_input(kwargs), + output_data=extract_output(response), + metrics=extract_metrics(response, kwargs), + context={ + "provider": "anthropic", + "endpoint": "messages", + "system": kwargs.get("system", "")[:200] # First 200 chars of system prompt + } + ) + except Exception as e: + if 
sdk.config.get("verbose"): + print(f"[CASCADE] Anthropic observation failed: {e}") + + return response + return wrapper + + # Patch the Anthropic client class + if hasattr(anthropic, "Anthropic"): + _OriginalClient = anthropic.Anthropic + + class PatchedAnthropic(_OriginalClient): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Wrap the instance methods + original_create = self.messages.create + self.messages.create = wrap_messages_create(original_create) + + anthropic.Anthropic = PatchedAnthropic + + # Patch AsyncAnthropic if available + if hasattr(anthropic, "AsyncAnthropic"): + _OriginalAsyncClient = anthropic.AsyncAnthropic + + class PatchedAsyncAnthropic(_OriginalAsyncClient): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # TODO: Implement async observation wrappers + + anthropic.AsyncAnthropic = PatchedAsyncAnthropic diff --git a/cascade/patches/huggingface_patch.py b/cascade/patches/huggingface_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..a41ec8e6fe164a35f651ff7406b88cccf026a328 --- /dev/null +++ b/cascade/patches/huggingface_patch.py @@ -0,0 +1,203 @@ +""" +HuggingFace Patch - Intercepts transformers and inference API calls. +""" + +import functools +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ..sdk import CascadeSDK + + +def patch_huggingface(sdk: "CascadeSDK"): + """ + Patch HuggingFace libraries to emit receipts. 
+ + Intercepts: + - transformers.pipeline().__call__() + - transformers.AutoModelForCausalLM.generate() + - huggingface_hub.InferenceClient() + """ + + # Try to patch transformers + try: + _patch_transformers(sdk) + except ImportError: + pass + + # Try to patch huggingface_hub inference + try: + _patch_inference_client(sdk) + except ImportError: + pass + + +def _patch_transformers(sdk: "CascadeSDK"): + """Patch transformers library.""" + import transformers + + def extract_model_id(pipe_or_model): + """Extract model identifier from pipeline or model.""" + try: + if hasattr(pipe_or_model, "model"): + model = pipe_or_model.model + else: + model = pipe_or_model + + # Try to get model name from config + if hasattr(model, "config"): + if hasattr(model.config, "_name_or_path"): + return f"hf/{model.config._name_or_path}" + if hasattr(model.config, "name_or_path"): + return f"hf/{model.config.name_or_path}" + + # Fallback to class name + return f"hf/{model.__class__.__name__}" + except: + return "hf/unknown" + + # Patch pipeline + _OriginalPipeline = transformers.pipeline + + @functools.wraps(_OriginalPipeline) + def patched_pipeline(*args, **kwargs): + pipe = _OriginalPipeline(*args, **kwargs) + + # Wrap the __call__ method + original_call = pipe.__call__ + + @functools.wraps(original_call) + def wrapped_call(*call_args, **call_kwargs): + result = original_call(*call_args, **call_kwargs) + + try: + # Extract input + input_data = call_args[0] if call_args else call_kwargs.get("inputs", "") + if isinstance(input_data, list): + input_data = input_data[0] if input_data else "" + + # Extract output + if isinstance(result, list) and result: + output = result[0] + if isinstance(output, dict): + output = output.get("generated_text", output.get("text", str(output))) + else: + output = str(result) + + sdk.observe( + model_id=extract_model_id(pipe), + input_data=str(input_data), + output_data=str(output), + metrics={"task": pipe.task if hasattr(pipe, "task") else "unknown"}, + 
context={"provider": "huggingface", "endpoint": "pipeline"} + ) + except Exception as e: + if sdk.config.get("verbose"): + print(f"[CASCADE] HuggingFace pipeline observation failed: {e}") + + return result + + pipe.__call__ = wrapped_call + return pipe + + transformers.pipeline = patched_pipeline + + # Patch AutoModelForCausalLM.generate + if hasattr(transformers, "AutoModelForCausalLM"): + _OriginalAutoModel = transformers.AutoModelForCausalLM + + class PatchedAutoModelForCausalLM(_OriginalAutoModel): + @classmethod + def from_pretrained(cls, *args, **kwargs): + model = super().from_pretrained(*args, **kwargs) + + # Wrap generate + original_generate = model.generate + + @functools.wraps(original_generate) + def wrapped_generate(*gen_args, **gen_kwargs): + result = original_generate(*gen_args, **gen_kwargs) + + try: + sdk.observe( + model_id=extract_model_id(model), + input_data=f"tokens:{gen_args[0].shape if gen_args else 'unknown'}", + output_data=f"tokens:{result.shape if hasattr(result, 'shape') else 'unknown'}", + metrics={ + "max_new_tokens": gen_kwargs.get("max_new_tokens"), + "temperature": gen_kwargs.get("temperature", 1.0), + }, + context={"provider": "huggingface", "endpoint": "generate"} + ) + except Exception as e: + if sdk.config.get("verbose"): + print(f"[CASCADE] HuggingFace generate observation failed: {e}") + + return result + + model.generate = wrapped_generate + return model + + transformers.AutoModelForCausalLM = PatchedAutoModelForCausalLM + + +def _patch_inference_client(sdk: "CascadeSDK"): + """Patch huggingface_hub InferenceClient.""" + from huggingface_hub import InferenceClient + + _OriginalClient = InferenceClient + + class PatchedInferenceClient(_OriginalClient): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._model_id = kwargs.get("model", args[0] if args else "unknown") + + def text_generation(self, prompt, **kwargs): + result = super().text_generation(prompt, **kwargs) + + try: + sdk.observe( + 
model_id=f"hf-inference/{self._model_id}", + input_data=prompt, + output_data=result, + metrics={ + "max_new_tokens": kwargs.get("max_new_tokens"), + "temperature": kwargs.get("temperature"), + }, + context={"provider": "huggingface", "endpoint": "inference_api"} + ) + except Exception as e: + if sdk.config.get("verbose"): + print(f"[CASCADE] HF Inference observation failed: {e}") + + return result + + def chat_completion(self, messages, **kwargs): + result = super().chat_completion(messages, **kwargs) + + try: + # Extract last user message + user_msgs = [m for m in messages if m.get("role") == "user"] + input_text = user_msgs[-1].get("content", "") if user_msgs else "" + + # Extract output + output_text = "" + if hasattr(result, "choices") and result.choices: + output_text = result.choices[0].message.content + + sdk.observe( + model_id=f"hf-inference/{self._model_id}", + input_data=input_text, + output_data=output_text, + metrics={}, + context={"provider": "huggingface", "endpoint": "chat_completion"} + ) + except Exception as e: + if sdk.config.get("verbose"): + print(f"[CASCADE] HF Inference observation failed: {e}") + + return result + + # Replace the class + import huggingface_hub + huggingface_hub.InferenceClient = PatchedInferenceClient diff --git a/cascade/patches/litellm_patch.py b/cascade/patches/litellm_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..ac8f86c02bb7c9477168c13b11a582c031515df5 --- /dev/null +++ b/cascade/patches/litellm_patch.py @@ -0,0 +1,176 @@ +""" +LiteLLM Patch - Intercepts litellm.completion() which unifies all providers. + +LiteLLM is particularly valuable because it's a universal interface - +patching it catches calls to OpenAI, Anthropic, Cohere, Azure, Bedrock, +Vertex AI, Ollama, and many more through a single integration point. 
+""" + +import functools +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ..sdk import CascadeSDK + + +def patch_litellm(sdk: "CascadeSDK"): + """ + Patch LiteLLM to emit receipts on every call. + + Intercepts: + - litellm.completion() + - litellm.acompletion() + - litellm.text_completion() + """ + import litellm + + def extract_model_id(kwargs, response=None): + """Extract canonical model identifier.""" + model = kwargs.get("model", "unknown") + # LiteLLM uses provider/model format: openai/gpt-4, anthropic/claude-3, etc. + # If no provider prefix, add litellm/ + if "/" not in model: + return f"litellm/{model}" + return model + + def extract_input(kwargs): + """Extract input from kwargs.""" + messages = kwargs.get("messages", []) + if messages: + user_msgs = [m for m in messages if m.get("role") == "user"] + if user_msgs: + content = user_msgs[-1].get("content", "") + if isinstance(content, list): + texts = [c.get("text", "") for c in content if isinstance(c, dict) and c.get("type") == "text"] + return " ".join(texts) + return content + return kwargs.get("prompt", "") + + def extract_output(response): + """Extract output from response.""" + try: + if hasattr(response, "choices") and response.choices: + choice = response.choices[0] + if hasattr(choice, "message") and choice.message: + return choice.message.content + if hasattr(choice, "text"): + return choice.text + return str(response) + except: + return str(response) + + def extract_metrics(response, kwargs): + """Extract usage metrics.""" + metrics = {} + try: + if hasattr(response, "usage") and response.usage: + if hasattr(response.usage, "prompt_tokens"): + metrics["prompt_tokens"] = response.usage.prompt_tokens + if hasattr(response.usage, "completion_tokens"): + metrics["completion_tokens"] = response.usage.completion_tokens + if hasattr(response.usage, "total_tokens"): + metrics["total_tokens"] = response.usage.total_tokens + + # Request params + metrics["temperature"] = 
kwargs.get("temperature", 1.0) + metrics["max_tokens"] = kwargs.get("max_tokens") + + # Response metadata + if hasattr(response, "model"): + metrics["response_model"] = response.model + except: + pass + return {k: v for k, v in metrics.items() if v is not None} + + # Patch litellm.completion + if hasattr(litellm, "completion"): + _original_completion = litellm.completion + + @functools.wraps(_original_completion) + def patched_completion(*args, **kwargs): + response = _original_completion(*args, **kwargs) + + try: + sdk.observe( + model_id=extract_model_id(kwargs, response), + input_data=extract_input(kwargs), + output_data=extract_output(response), + metrics=extract_metrics(response, kwargs), + context={ + "provider": "litellm", + "endpoint": "completion", + "api_base": kwargs.get("api_base", "default") + } + ) + except Exception as e: + if sdk.config.get("verbose"): + print(f"[CASCADE] LiteLLM completion observation failed: {e}") + + return response + + litellm.completion = patched_completion + + # Patch litellm.text_completion + if hasattr(litellm, "text_completion"): + _original_text_completion = litellm.text_completion + + @functools.wraps(_original_text_completion) + def patched_text_completion(*args, **kwargs): + response = _original_text_completion(*args, **kwargs) + + try: + sdk.observe( + model_id=extract_model_id(kwargs, response), + input_data=kwargs.get("prompt", ""), + output_data=extract_output(response), + metrics=extract_metrics(response, kwargs), + context={ + "provider": "litellm", + "endpoint": "text_completion" + } + ) + except Exception as e: + if sdk.config.get("verbose"): + print(f"[CASCADE] LiteLLM text_completion observation failed: {e}") + + return response + + litellm.text_completion = patched_text_completion + + # Patch litellm.batch_completion if available + if hasattr(litellm, "batch_completion"): + _original_batch = litellm.batch_completion + + @functools.wraps(_original_batch) + def patched_batch_completion(*args, **kwargs): + 
responses = _original_batch(*args, **kwargs) + + try: + messages_list = kwargs.get("messages", []) + model = kwargs.get("model", "unknown") + + # Emit one observation per batch item + for i, response in enumerate(responses if responses else []): + input_msgs = messages_list[i] if i < len(messages_list) else [] + + sdk.observe( + model_id=extract_model_id({"model": model}, response), + input_data=extract_input({"messages": input_msgs}), + output_data=extract_output(response), + metrics=extract_metrics(response, kwargs), + context={ + "provider": "litellm", + "endpoint": "batch_completion", + "batch_index": i + } + ) + except Exception as e: + if sdk.config.get("verbose"): + print(f"[CASCADE] LiteLLM batch_completion observation failed: {e}") + + return responses + + litellm.batch_completion = patched_batch_completion + + # Note: Async versions (acompletion, etc.) would need async wrappers + # TODO: Implement async observation for litellm.acompletion diff --git a/cascade/patches/ollama_patch.py b/cascade/patches/ollama_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..28a928b763af6fc938038f90a7aa687726b9f02e --- /dev/null +++ b/cascade/patches/ollama_patch.py @@ -0,0 +1,177 @@ +""" +Ollama Patch - Intercepts ollama library calls. +""" + +import functools +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ..sdk import CascadeSDK + + +def patch_ollama(sdk: "CascadeSDK"): + """ + Patch the Ollama library to emit receipts. + + Intercepts: + - ollama.chat() + - ollama.generate() + - ollama.Client().chat() + - ollama.Client().generate() + """ + import ollama + + def extract_model_id(kwargs, response=None): + """Extract canonical model identifier.""" + model = kwargs.get("model", "unknown") + # Could be llama2, mistral, codellama:7b-instruct, etc. 
+ return f"ollama/{model}" + + def extract_input_from_messages(messages): + """Extract input from message list.""" + if messages: + user_msgs = [m for m in messages if m.get("role") == "user"] + if user_msgs: + return user_msgs[-1].get("content", "") + return "" + + def extract_output(response): + """Extract output from response.""" + try: + if isinstance(response, dict): + # Chat response + if "message" in response: + return response["message"].get("content", "") + # Generate response + if "response" in response: + return response["response"] + return str(response) + except: + return str(response) + + def extract_metrics(response): + """Extract metrics from response.""" + metrics = {} + try: + if isinstance(response, dict): + metrics["eval_count"] = response.get("eval_count") + metrics["eval_duration"] = response.get("eval_duration") + metrics["prompt_eval_count"] = response.get("prompt_eval_count") + metrics["total_duration"] = response.get("total_duration") + except: + pass + return {k: v for k, v in metrics.items() if v is not None} + + # Patch ollama.chat + if hasattr(ollama, "chat"): + _original_chat = ollama.chat + + @functools.wraps(_original_chat) + def patched_chat(*args, **kwargs): + response = _original_chat(*args, **kwargs) + + try: + messages = kwargs.get("messages", args[1] if len(args) > 1 else []) + model = kwargs.get("model", args[0] if args else "unknown") + + sdk.observe( + model_id=f"ollama/{model}", + input_data=extract_input_from_messages(messages), + output_data=extract_output(response), + metrics=extract_metrics(response), + context={"provider": "ollama", "endpoint": "chat"} + ) + except Exception as e: + if sdk.config.get("verbose"): + print(f"[CASCADE] Ollama chat observation failed: {e}") + + return response + + ollama.chat = patched_chat + + # Patch ollama.generate + if hasattr(ollama, "generate"): + _original_generate = ollama.generate + + @functools.wraps(_original_generate) + def patched_generate(*args, **kwargs): + response = 
_original_generate(*args, **kwargs) + + try: + model = kwargs.get("model", args[0] if args else "unknown") + prompt = kwargs.get("prompt", args[1] if len(args) > 1 else "") + + sdk.observe( + model_id=f"ollama/{model}", + input_data=prompt, + output_data=extract_output(response), + metrics=extract_metrics(response), + context={"provider": "ollama", "endpoint": "generate"} + ) + except Exception as e: + if sdk.config.get("verbose"): + print(f"[CASCADE] Ollama generate observation failed: {e}") + + return response + + ollama.generate = patched_generate + + # Patch ollama.Client class + if hasattr(ollama, "Client"): + _OriginalClient = ollama.Client + + class PatchedClient(_OriginalClient): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Wrap chat + original_chat = self.chat + @functools.wraps(original_chat) + def wrapped_chat(*chat_args, **chat_kwargs): + response = original_chat(*chat_args, **chat_kwargs) + + try: + messages = chat_kwargs.get("messages", chat_args[1] if len(chat_args) > 1 else []) + model = chat_kwargs.get("model", chat_args[0] if chat_args else "unknown") + + sdk.observe( + model_id=f"ollama/{model}", + input_data=extract_input_from_messages(messages), + output_data=extract_output(response), + metrics=extract_metrics(response), + context={"provider": "ollama", "endpoint": "client.chat"} + ) + except Exception as e: + if sdk.config.get("verbose"): + print(f"[CASCADE] Ollama client chat observation failed: {e}") + + return response + + self.chat = wrapped_chat + + # Wrap generate + original_generate = self.generate + @functools.wraps(original_generate) + def wrapped_generate(*gen_args, **gen_kwargs): + response = original_generate(*gen_args, **gen_kwargs) + + try: + model = gen_kwargs.get("model", gen_args[0] if gen_args else "unknown") + prompt = gen_kwargs.get("prompt", gen_args[1] if len(gen_args) > 1 else "") + + sdk.observe( + model_id=f"ollama/{model}", + input_data=prompt, + 
def patch_openai(sdk: "CascadeSDK"):
    """
    Patch the OpenAI library to emit receipts on every call.

    Intercepts:
    - openai.chat.completions.create()
    - openai.completions.create()
    - openai.Client().chat.completions.create()

    Args:
        sdk: CascadeSDK instance used to emit observation receipts.
    """
    import openai

    def extract_model_id(kwargs, response):
        """Extract canonical model identifier, e.g. "openai/gpt-4"."""
        model = kwargs.get("model", "unknown")
        # Could be gpt-4, gpt-4-turbo, gpt-4-0125-preview, etc.
        return f"openai/{model}"

    def extract_input(kwargs):
        """Extract the last user message (chat) or the prompt (completions)."""
        messages = kwargs.get("messages", [])
        if messages:
            user_msgs = [m for m in messages if m.get("role") == "user"]
            if user_msgs:
                return user_msgs[-1].get("content", "")
        return kwargs.get("prompt", "")

    def extract_output(response):
        """Extract the generated text from a chat or legacy completion response."""
        try:
            # Chat completion
            if hasattr(response, "choices") and response.choices:
                choice = response.choices[0]
                if hasattr(choice, "message"):
                    return choice.message.content
                elif hasattr(choice, "text"):
                    return choice.text
            return str(response)
        except Exception:
            # Was a bare `except:` — narrowed so KeyboardInterrupt/SystemExit
            # propagate; extraction must never break the caller's request.
            return str(response)

    def extract_metrics(response, kwargs):
        """Extract token usage plus the request parameters that shaped it."""
        metrics = {}
        try:
            if hasattr(response, "usage") and response.usage:
                metrics["prompt_tokens"] = response.usage.prompt_tokens
                metrics["completion_tokens"] = response.usage.completion_tokens
                metrics["total_tokens"] = response.usage.total_tokens

            # Add request params
            metrics["temperature"] = kwargs.get("temperature", 1.0)
            metrics["max_tokens"] = kwargs.get("max_tokens")
        except Exception:
            pass
        return metrics

    def wrap_chat_create(original):
        """Wrap a chat.completions.create callable so every call emits a receipt."""
        @functools.wraps(original)
        def wrapper(*args, **kwargs):
            # Call original first; observation is strictly best-effort.
            response = original(*args, **kwargs)

            # Emit receipt (non-blocking)
            try:
                sdk.observe(
                    model_id=extract_model_id(kwargs, response),
                    input_data=extract_input(kwargs),
                    output_data=extract_output(response),
                    metrics=extract_metrics(response, kwargs),
                    context={"provider": "openai", "endpoint": "chat.completions"}
                )
            except Exception as e:
                if sdk.config.get("verbose"):
                    print(f"[CASCADE] OpenAI observation failed: {e}")

            return response
        return wrapper

    def wrap_completions_create(original):
        """Wrap a legacy completions.create callable so calls emit receipts."""
        @functools.wraps(original)
        def wrapper(*args, **kwargs):
            response = original(*args, **kwargs)

            try:
                sdk.observe(
                    model_id=extract_model_id(kwargs, response),
                    input_data=kwargs.get("prompt", ""),
                    output_data=extract_output(response),
                    metrics=extract_metrics(response, kwargs),
                    context={"provider": "openai", "endpoint": "completions"}
                )
            except Exception as e:
                if sdk.config.get("verbose"):
                    print(f"[CASCADE] OpenAI observation failed: {e}")

            return response
        return wrapper

    # Patch the module-level client if it exists
    if hasattr(openai, "chat") and hasattr(openai.chat, "completions"):
        openai.chat.completions.create = wrap_chat_create(openai.chat.completions.create)

    if hasattr(openai, "completions"):
        openai.completions.create = wrap_completions_create(openai.completions.create)

    # Patch the OpenAI client class so new instances are also observed.
    if hasattr(openai, "OpenAI"):
        _OriginalClient = openai.OpenAI

        class PatchedOpenAI(_OriginalClient):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                # Wrap the instance methods
                self.chat.completions.create = wrap_chat_create(self.chat.completions.create)

                if hasattr(self, "completions"):
                    self.completions.create = wrap_completions_create(self.completions.create)

        openai.OpenAI = PatchedOpenAI

    # Also patch AsyncOpenAI if available
    if hasattr(openai, "AsyncOpenAI"):
        _OriginalAsyncClient = openai.AsyncOpenAI

        class PatchedAsyncOpenAI(_OriginalAsyncClient):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                # For async, we'd need async wrappers.
                # TODO: Implement async observation

        openai.AsyncOpenAI = PatchedAsyncOpenAI
class CascadeProxy:
    """
    HTTP proxy that intercepts LLM API calls and emits CASCADE receipts.

    Supported providers:
    - OpenAI (and OpenAI-compatible APIs)
    - Anthropic
    - Cohere
    - Mistral
    - Any OpenAI-compatible endpoint
    """

    # Real API endpoints
    ENDPOINTS = {
        "openai": "https://api.openai.com",
        "anthropic": "https://api.anthropic.com",
        "cohere": "https://api.cohere.ai",
        "mistral": "https://api.mistral.ai",
    }

    def __init__(
        self,
        host: str = "0.0.0.0",
        port: int = 7777,
        verbose: bool = True,
    ):
        """
        Args:
            host: interface to bind the proxy server to.
            port: TCP port to listen on.
            verbose: print a line for every receipt emitted / proxy error.
        """
        self.host = host
        self.port = port
        self.verbose = verbose
        self.sdk = CascadeSDK()
        self.sdk.init(emit_async=True, verbose=verbose)
        # Annotation kept as a string: aiohttp is optional, so ClientSession
        # may be undefined at runtime when the import at module top failed.
        self.session: "Optional[ClientSession]" = None

        # Stats
        self.stats = {
            "requests": 0,
            "receipts_emitted": 0,
            "bytes_proxied": 0,
            "start_time": None,
        }

    async def start(self):
        """Start the proxy server and serve until cancelled."""
        if not AIOHTTP_AVAILABLE:
            print("ERROR: aiohttp required for proxy mode")
            print("Install with: pip install aiohttp")
            return

        self.session = ClientSession()
        self.stats["start_time"] = time.time()

        app = web.Application()

        # Route all requests through one handler.
        app.router.add_route("*", "/{path:.*}", self.handle_request)

        runner = web.AppRunner(app)
        await runner.setup()

        site = web.TCPSite(runner, self.host, self.port)

        print(f"""
╔══════════════════════════════════════════════════════════════╗
║  CASCADE PROXY - Protocol-Level AI Observation               ║
╠══════════════════════════════════════════════════════════════╣
║  Listening: http://{self.host}:{self.port}                   ║
║                                                              ║
║  Set these environment variables in your app:                ║
║                                                              ║
║  export OPENAI_BASE_URL=http://localhost:{self.port}/v1      ║
║  export ANTHROPIC_BASE_URL=http://localhost:{self.port}/anthropic ║
║                                                              ║
║  Then run your app normally. CASCADE sees everything.        ║
╚══════════════════════════════════════════════════════════════╝
""")

        await site.start()

        # aiohttp serves in background tasks; keep this coroutine alive.
        while True:
            await asyncio.sleep(3600)

    async def handle_request(self, request: "web.Request") -> "web.Response":
        """Handle incoming request, proxy to real API, emit receipt.

        Annotations are strings so this class can be *defined* without
        aiohttp installed (start() still refuses to run without it).
        """
        path = request.path

        # Determine provider from path
        provider, real_url = self._resolve_provider(path, request)

        if not real_url:
            return web.json_response(
                {"error": "Unknown provider. Use /v1/* for OpenAI or /anthropic/* for Anthropic"},
                status=400
            )

        # Read request body
        body = await request.read()
        request_data = {}
        if body:
            try:
                request_data = json.loads(body)
            except (ValueError, UnicodeDecodeError):
                # Non-JSON body (e.g. multipart upload); proxy it anyway.
                pass

        # Forward headers. Host names the proxy, not the target; Content-Length
        # is recomputed by the client session, so a stale value must not leak.
        headers = dict(request.headers)
        for h in ("Host", "host", "Content-Length", "content-length"):
            headers.pop(h, None)

        # Make request to real API
        try:
            async with self.session.request(
                method=request.method,
                url=real_url,
                headers=headers,
                data=body,
            ) as response:
                response_body = await response.read()
                response_data = {}
                try:
                    response_data = json.loads(response_body)
                except (ValueError, UnicodeDecodeError):
                    pass

                # Emit receipt
                self._emit_receipt(provider, request_data, response_data, path)

                # Update stats
                self.stats["requests"] += 1
                self.stats["bytes_proxied"] += len(body) + len(response_body)

                # Return response to client; drop hop-by-hop encodings that
                # no longer describe the already-decoded body.
                return web.Response(
                    body=response_body,
                    status=response.status,
                    headers={
                        k: v for k, v in response.headers.items()
                        if k.lower() not in ("transfer-encoding", "content-encoding")
                    },
                )

        except Exception as e:
            if self.verbose:
                print(f"[CASCADE PROXY] Error: {e}")
            return web.json_response(
                {"error": f"Proxy error: {str(e)}"},
                status=502
            )

    def _resolve_provider(self, path: str, request) -> tuple:
        """Resolve which provider to forward to based on path.

        Returns (provider_name, real_url) or (None, None) if unknown.
        """
        # OpenAI: /v1/* -> api.openai.com/v1/*
        if path.startswith("/v1"):
            return "openai", f"https://api.openai.com{path}"

        # Anthropic: /anthropic/* -> api.anthropic.com/*
        if path.startswith("/anthropic"):
            return "anthropic", f"https://api.anthropic.com{path[len('/anthropic'):]}"

        # Cohere: /cohere/* -> api.cohere.ai/*
        if path.startswith("/cohere"):
            return "cohere", f"https://api.cohere.ai{path[len('/cohere'):]}"

        # Mistral: /mistral/* -> api.mistral.ai/*
        if path.startswith("/mistral"):
            return "mistral", f"https://api.mistral.ai{path[len('/mistral'):]}"

        # Custom endpoint via header
        custom_base = request.headers.get("X-Cascade-Forward-To")
        if custom_base:
            return "custom", f"{custom_base}{path}"

        return None, None

    def _emit_receipt(
        self,
        provider: str,
        request_data: Dict[str, Any],
        response_data: Dict[str, Any],
        path: str,
    ):
        """Emit a CASCADE receipt for this request/response (best-effort)."""
        try:
            model_id = self._extract_model_id(provider, request_data, response_data)
            input_text = self._extract_input(provider, request_data)
            output_text = self._extract_output(provider, response_data)
            metrics = self._extract_metrics(provider, response_data, request_data)

            # Emit via SDK
            self.sdk.observe(
                model_id=model_id,
                input_data=input_text,
                output_data=output_text,
                metrics=metrics,
                context={
                    "provider": provider,
                    "endpoint": path,
                    "via": "proxy",
                }
            )

            self.stats["receipts_emitted"] += 1

            if self.verbose:
                print(f"[CASCADE] Receipt: {model_id} via proxy")

        except Exception as e:
            # Observation must never break the proxied request.
            if self.verbose:
                print(f"[CASCADE] Failed to emit receipt: {e}")

    def _extract_model_id(
        self,
        provider: str,
        request_data: Dict[str, Any],
        response_data: Dict[str, Any],
    ) -> str:
        """Canonical model ID: "<provider>/<model>"."""
        model = request_data.get("model") or response_data.get("model", "unknown")
        return f"{provider}/{model}"

    def _extract_input(self, provider: str, request_data: Dict[str, Any]) -> str:
        """Extract input text from request (last user message, or prompt)."""
        # Chat completion style
        messages = request_data.get("messages", [])
        if messages:
            user_msgs = [m for m in messages if m.get("role") == "user"]
            if user_msgs:
                content = user_msgs[-1].get("content", "")
                # Multimodal content arrives as a list of part dicts.
                if isinstance(content, list):
                    texts = [c.get("text", "") for c in content if isinstance(c, dict)]
                    return " ".join(texts)
                return str(content)

        # Completion style
        return request_data.get("prompt", "")

    def _extract_output(self, provider: str, response_data: Dict[str, Any]) -> str:
        """Extract output text from an OpenAI- or Anthropic-shaped response."""
        # OpenAI style
        choices = response_data.get("choices", [])
        if choices:
            choice = choices[0]
            if "message" in choice:
                return choice["message"].get("content", "")
            if "text" in choice:
                return choice["text"]

        # Anthropic style
        content = response_data.get("content", [])
        if content and isinstance(content, list):
            texts = [c.get("text", "") for c in content if isinstance(c, dict)]
            return " ".join(texts)

        return ""

    def _extract_metrics(
        self,
        provider: str,
        response_data: Dict[str, Any],
        request_data: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Extract usage counters and request params; None values dropped."""
        metrics = {}

        # Usage stats (OpenAI uses prompt/completion_*, Anthropic input/output_*)
        usage = response_data.get("usage", {})
        if usage:
            metrics["prompt_tokens"] = usage.get("prompt_tokens") or usage.get("input_tokens")
            metrics["completion_tokens"] = usage.get("completion_tokens") or usage.get("output_tokens")
            metrics["total_tokens"] = usage.get("total_tokens")

        # Request params
        metrics["temperature"] = request_data.get("temperature")
        metrics["max_tokens"] = request_data.get("max_tokens")

        return {k: v for k, v in metrics.items() if v is not None}

    async def shutdown(self):
        """Close the upstream session, flush the SDK, and print run stats."""
        if self.session:
            await self.session.close()
        self.sdk.shutdown()

        runtime = time.time() - self.stats["start_time"] if self.stats["start_time"] else 0
        print(f"""
╔══════════════════════════════════════════════════════════════╗
║  CASCADE PROXY - Shutdown                                    ║
╠══════════════════════════════════════════════════════════════╣
║  Runtime:  {runtime:.1f}s
║  Requests: {self.stats['requests']}
║  Receipts: {self.stats['receipts_emitted']}
║  Bytes:    {self.stats['bytes_proxied']}
╚══════════════════════════════════════════════════════════════╝
""")


def run_proxy(host: str = "0.0.0.0", port: int = 7777, verbose: bool = True):
    """Run the CASCADE proxy server until interrupted."""
    proxy = CascadeProxy(host=host, port=port, verbose=verbose)

    try:
        asyncio.run(proxy.start())
    except KeyboardInterrupt:
        print("\nShutting down...")
        # NOTE(review): this runs shutdown() on a fresh event loop while the
        # ClientSession was created on the (now-closed) serving loop; aiohttp
        # tolerates this for close(), but a single-loop teardown would be safer.
        asyncio.run(proxy.shutdown())


# =============================================================================
# CLI
# =============================================================================

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="CASCADE Proxy - Protocol-level AI observation"
    )
    parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
    parser.add_argument("--port", "-p", type=int, default=7777, help="Port to listen on")
    parser.add_argument("--quiet", "-q", action="store_true", help="Suppress output")

    args = parser.parse_args()

    run_proxy(host=args.host, port=args.port, verbose=not args.quiet)
class CascadeSDK:
    """Main SDK singleton - manages provider patching and receipt emission."""

    _instance = None
    _initialized = False
    # Guards singleton construction when multiple threads race on first use.
    _instance_lock = threading.Lock()

    def __new__(cls):
        # Double-checked locking: lock-free fast path, lock only on creation.
        # (The original unguarded check could create two instances under
        # concurrent first calls.)
        if cls._instance is None:
            with cls._instance_lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every CascadeSDK() call; only the first one counts.
        if CascadeSDK._initialized:
            return

        self.observation_manager = ObservationManager()
        self.model_registry = ModelRegistry()
        self.emission_queue = queue.Queue()
        self.background_thread = None
        self.running = False
        self.patched_providers = set()
        self.config = {
            "emit_async": True,
            "lattice_path": "lattice/observations",
            "verbose": False,
        }
        CascadeSDK._initialized = True

    def init(self, **kwargs):
        """
        Initialize CASCADE and auto-patch available providers.

        Args:
            emit_async: Whether to emit receipts in background (default: True)
            verbose: Print when receipts are emitted (default: False)
            providers: List of providers to patch, or 'all' (default: 'all')

        Returns:
            self, for chaining.
        """
        self.config.update(kwargs)

        # Start background emission thread
        if self.config["emit_async"] and not self.running:
            self.running = True
            self.background_thread = threading.Thread(
                target=self._emission_worker,
                daemon=True
            )
            self.background_thread.start()

        # Auto-patch available providers
        providers = kwargs.get("providers", "all")
        self._patch_providers(providers)

        if self.config["verbose"]:
            print(f"[CASCADE] Initialized. Patched: {self.patched_providers}")

        return self

    def _patch_providers(self, providers):
        """Patch LLM provider libraries; missing libraries are skipped."""
        from .patches import (
            patch_openai,
            patch_anthropic,
            patch_huggingface,
            patch_ollama,
            patch_litellm,
        )

        patch_map = {
            "openai": patch_openai,
            "anthropic": patch_anthropic,
            "huggingface": patch_huggingface,
            "ollama": patch_ollama,
            "litellm": patch_litellm,
        }

        if providers == "all":
            providers = list(patch_map.keys())

        for provider in providers:
            if provider in patch_map:
                try:
                    patch_map[provider](self)
                    self.patched_providers.add(provider)
                except ImportError:
                    # Provider library not installed: deliberately skipped.
                    pass
                except Exception as e:
                    if self.config["verbose"]:
                        print(f"[CASCADE] Failed to patch {provider}: {e}")

    def _emission_worker(self):
        """Background thread: drain the emission queue until shutdown."""
        while self.running:
            try:
                # Timeout so the loop re-checks self.running periodically.
                receipt_data = self.emission_queue.get(timeout=1.0)
                self._emit_receipt(receipt_data)
            except queue.Empty:
                continue
            except Exception as e:
                if self.config["verbose"]:
                    print(f"[CASCADE] Emission error: {e}")

    def _emit_receipt(self, receipt_data: Dict[str, Any]):
        """Write one receipt to the lattice; returns the observation or None."""
        import hashlib
        import uuid

        try:
            # Create provenance chain for this observation
            model_id = receipt_data["model_id"]
            input_text = receipt_data["input"][:1000]    # Truncate
            output_text = receipt_data["output"][:2000]  # Truncate

            # Compute hashes
            input_hash = hashlib.sha256(input_text.encode()).hexdigest()[:16]
            model_hash = hashlib.sha256(model_id.encode()).hexdigest()[:16]
            session_id = str(uuid.uuid4())[:8]

            chain = ProvenanceChain(
                session_id=session_id,
                model_id=model_id,
                model_hash=model_hash,
                input_hash=input_hash,
            )

            # Add inference record
            from cascade.core.provenance import ProvenanceRecord

            record = ProvenanceRecord(
                layer_name="inference",
                layer_idx=0,
                state_hash=hashlib.sha256(output_text.encode()).hexdigest()[:16],
                parent_hashes=[input_hash],
                params_hash=model_hash,
                shape=[len(output_text)],
                dtype="text",
                stats={
                    **receipt_data.get("metrics", {}),
                    "provider": receipt_data.get("context", {}).get("provider", "unknown"),
                    "timestamp": receipt_data.get("timestamp", datetime.now(timezone.utc).isoformat()),
                },
                execution_order=0,
            )
            chain.add_record(record)
            chain.finalize()

            observation = self.observation_manager.observe_model(
                model_id=model_id,
                chain=chain,
                user_hash=receipt_data.get("user_hash"),
            )

            if self.config["verbose"]:
                print(f"[CASCADE] Receipt: {observation.merkle_root[:16]}... -> {model_id}")

            return observation
        except Exception as e:
            # Best-effort: emission failures must never propagate to callers.
            if self.config["verbose"]:
                import traceback
                print(f"[CASCADE] Failed to emit: {e}")
                traceback.print_exc()
            return None

    def observe(
        self,
        model_id: str,
        input_data: Any,
        output_data: Any,
        metrics: Optional[Dict] = None,
        context: Optional[Dict] = None
    ):
        """
        Manually emit an observation receipt.

        Called automatically by patches, but can be called directly.
        """
        receipt_data = {
            "model_id": model_id,
            "input": str(input_data),
            "output": str(output_data),
            "metrics": metrics or {},
            "context": context or {},
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }

        if self.config["emit_async"]:
            self.emission_queue.put(receipt_data)
        else:
            self._emit_receipt(receipt_data)

    def shutdown(self):
        """Stop the background emitter and flush anything still queued."""
        self.running = False
        if self.background_thread:
            self.background_thread.join(timeout=5.0)

        # Flush remaining items synchronously on the calling thread.
        while not self.emission_queue.empty():
            try:
                receipt_data = self.emission_queue.get_nowait()
                self._emit_receipt(receipt_data)
            except queue.Empty:
                break


# Global SDK instance, created lazily on first use so that merely importing
# this module does no work. (The original instantiated CascadeSDK at import
# time, constructing ObservationManager/ModelRegistry as a side effect.)
_sdk: "Optional[CascadeSDK]" = None


def _get_sdk() -> "CascadeSDK":
    """Return the process-wide SDK, creating it on first use."""
    global _sdk
    if _sdk is None:
        # CascadeSDK is itself a singleton, so this is safe even if someone
        # constructed one directly elsewhere.
        _sdk = CascadeSDK()
    return _sdk


def init(**kwargs):
    """Initialize CASCADE observation layer."""
    return _get_sdk().init(**kwargs)


def observe(model_id: str, input_data: Any, output_data: Any, **kwargs):
    """Manually emit an observation."""
    return _get_sdk().observe(model_id, input_data, output_data, **kwargs)


def shutdown():
    """Shutdown CASCADE (flush pending receipts)."""
    return _get_sdk().shutdown()


# Convenience: allow `import cascade; cascade.init()`
__all__ = ["init", "observe", "shutdown", "CascadeSDK"]
# Deterministic cleanup of per-call sqlite connections (see LocalStore).
from contextlib import closing


@dataclass
class Receipt:
    """
    A provenance receipt - proof that an observation happened.

    Content-addressed via CID. Can be verified by anyone.
    """
    cid: str                          # IPFS content identifier
    model_id: str                     # What model/system was observed
    merkle_root: str                  # Chain hash (legacy compat)
    timestamp: float                  # When
    data: Dict[str, Any]              # The observation data
    parent_cid: Optional[str] = None  # Links to previous observation

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict (inverse of from_dict)."""
        return {
            "cid": self.cid,
            "model_id": self.model_id,
            "merkle_root": self.merkle_root,
            "timestamp": self.timestamp,
            "data": self.data,
            "parent_cid": self.parent_cid,
        }

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "Receipt":
        """Rebuild a Receipt from to_dict() output; parent_cid is optional."""
        return cls(
            cid=d["cid"],
            model_id=d["model_id"],
            merkle_root=d["merkle_root"],
            timestamp=d["timestamp"],
            data=d["data"],
            parent_cid=d.get("parent_cid"),
        )


# =============================================================================
# CONTENT ADDRESSING
# =============================================================================

def compute_cid(data: bytes) -> str:
    """
    Compute CIDv1 (base32) from bytes.

    CID = multicodec(dag-cbor) + multihash(sha256(data))
    """
    digest = hashlib.sha256(data).digest()
    mh = multihash.wrap(digest, "sha2-256")
    cid = CID("base32", 1, "dag-cbor", mh)
    return str(cid)


def data_to_cid(data: Dict[str, Any]) -> tuple[str, bytes]:
    """Convert dict to (CID, dag-cbor encoded bytes)."""
    encoded = dag_cbor.encode(data)
    return compute_cid(encoded), encoded


# =============================================================================
# LOCAL STORE
# =============================================================================

class LocalStore:
    """
    SQLite-backed local store with CBOR files.

    Fast queries via the SQLite index; full data lives in
    content-addressed CBOR files alongside it.
    """

    def __init__(self, lattice_dir: Optional[Path] = None):
        """Create (or open) the store rooted at lattice_dir."""
        self.lattice_dir = lattice_dir or DEFAULT_LATTICE_DIR
        self.lattice_dir.mkdir(parents=True, exist_ok=True)

        self.cbor_dir = self.lattice_dir / "cbor"
        self.cbor_dir.mkdir(exist_ok=True)

        self.db_path = self.lattice_dir / "index.db"
        self._init_db()

    def _connect(self) -> sqlite3.Connection:
        """Open a fresh connection to the index database."""
        return sqlite3.connect(self.db_path)

    def _init_db(self):
        """Initialize SQLite schema (idempotent)."""
        # closing() releases the connection even when DDL raises; the
        # original leaked the connection on any exception.
        with closing(self._connect()) as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS observations (
                    cid TEXT PRIMARY KEY,
                    model_id TEXT NOT NULL,
                    merkle_root TEXT NOT NULL,
                    timestamp REAL NOT NULL,
                    parent_cid TEXT,
                    pinned INTEGER DEFAULT 0,
                    created_at REAL DEFAULT (strftime('%s', 'now'))
                )
            """)
            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_model_id ON observations(model_id)
            """)
            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_timestamp ON observations(timestamp DESC)
            """)
            conn.commit()

    def save(self, receipt: Receipt) -> str:
        """Save receipt (CBOR file + index row). Returns the computed CID."""
        data = receipt.to_dict()
        cid, encoded = data_to_cid(data)

        cbor_path = self.cbor_dir / f"{cid}.cbor"
        cbor_path.write_bytes(encoded)

        # Update receipt with computed CID (may differ if data changed)
        receipt.cid = cid

        # Index in SQLite
        with closing(self._connect()) as conn:
            conn.execute("""
                INSERT OR REPLACE INTO observations
                (cid, model_id, merkle_root, timestamp, parent_cid)
                VALUES (?, ?, ?, ?, ?)
            """, (cid, receipt.model_id, receipt.merkle_root,
                  receipt.timestamp, receipt.parent_cid))
            conn.commit()

        return cid

    def get(self, cid: str) -> Optional[Receipt]:
        """Get receipt by CID, or None if its CBOR file is absent."""
        cbor_path = self.cbor_dir / f"{cid}.cbor"
        if not cbor_path.exists():
            return None

        data = dag_cbor.decode(cbor_path.read_bytes())
        return Receipt.from_dict(data)

    def query(
        self,
        model_id: Optional[str] = None,
        since: Optional[float] = None,
        limit: int = 100,
    ) -> List[Receipt]:
        """Query receipts, newest first, with optional model/time filters."""
        sql = "SELECT cid FROM observations WHERE 1=1"
        params: list = []

        if model_id:
            sql += " AND model_id = ?"
            params.append(model_id)

        if since:
            sql += " AND timestamp > ?"
            params.append(since)

        sql += " ORDER BY timestamp DESC LIMIT ?"
        params.append(limit)

        with closing(self._connect()) as conn:
            cids = [row[0] for row in conn.execute(sql, params)]

        # Hydrate full receipts; rows whose CBOR file vanished are skipped.
        receipts = []
        for cid in cids:
            receipt = self.get(cid)
            if receipt:
                receipts.append(receipt)

        return receipts

    def get_latest(self, model_id: str) -> Optional[Receipt]:
        """Get most recent receipt for a model, or None."""
        results = self.query(model_id=model_id, limit=1)
        return results[0] if results else None

    def count(self, model_id: Optional[str] = None) -> int:
        """Count observations (optionally restricted to one model)."""
        with closing(self._connect()) as conn:
            if model_id:
                cursor = conn.execute(
                    "SELECT COUNT(*) FROM observations WHERE model_id = ?",
                    (model_id,)
                )
            else:
                cursor = conn.execute("SELECT COUNT(*) FROM observations")
            return cursor.fetchone()[0]


# =============================================================================
# IPFS GATEWAY ACCESS
# =============================================================================

def fetch_from_gateway(cid: str, timeout: float = 10.0) -> Optional[bytes]:
    """
    Fetch data from public IPFS gateways.

    Tries each gateway in sequence; returns None if all fail.
    No daemon needed.
    """
    import requests

    for gateway in IPFS_GATEWAYS:
        url = f"{gateway}{cid}"
        try:
            resp = requests.get(url, timeout=timeout)
            if resp.status_code == 200:
                return resp.content
        except Exception:
            # Best-effort by design: fall through to the next gateway.
            continue

    return None


def fetch_receipt(cid: str, local_store: Optional[LocalStore] = None) -> Optional[Receipt]:
    """
    Fetch receipt by CID, checking the local store first, then IPFS.
    """
    # Check local first
    if local_store:
        receipt = local_store.get(cid)
        if receipt:
            return receipt

    # Try IPFS gateways
    data = fetch_from_gateway(cid)
    if data:
        try:
            decoded = dag_cbor.decode(data)
            receipt = Receipt.from_dict(decoded)

            # Cache locally
            if local_store:
                local_store.save(receipt)

            return receipt
        except Exception:
            # Corrupt/foreign payload from a gateway: treat as not found.
            pass

    return None


# =============================================================================
# HUGGINGFACE SYNC - Unlimited free storage
# =============================================================================

# (Uses constants from top of file: CENTRAL_DATASET, USER_DATASET)


def _upload_to_hf(filepath: Path, cid: str, dataset_id: str) -> bool:
    """Upload a single observation to a HuggingFace dataset (best-effort)."""
    try:
        from huggingface_hub import HfApi
        api = HfApi()
        api.upload_file(
            path_or_fileobj=str(filepath),
            path_in_repo=f"observations/{cid}.cbor",
            repo_id=dataset_id,
            repo_type="dataset",
        )
        return True
    except Exception:
        return False


def sync_observation(cid: str, filepath: Path) -> dict:
    """
    Sync a single observation to HuggingFace.

    Dual-write:
    1. Central dataset (CENTRAL_DATASET) - Dreamer sees all
    2. User's dataset (if CASCADE_USER_DATASET set) - they own their data
    """
    results = {"central": False, "user": False}

    # Always sync to central
    results["central"] = _upload_to_hf(filepath, cid, CENTRAL_DATASET)

    # Optionally sync to user's dataset
    if USER_DATASET:
        results["user"] = _upload_to_hf(filepath, cid, USER_DATASET)

    return results
User's dataset (if CASCADE_USER_DATASET set) - they own their data + """ + results = {"central": False, "user": False} + + # Always sync to central + results["central"] = _upload_to_hf(filepath, cid, CENTRAL_DATASET) + + # Optionally sync to user's dataset + if USER_DATASET: + results["user"] = _upload_to_hf(filepath, cid, USER_DATASET) + + return results + + +def sync_all() -> dict: + """ + Sync all local observations to HuggingFace. + + Returns: + {"synced": count, "failed": count} + """ + store = _get_store() + cbor_files = list(store.cbor_dir.glob("*.cbor")) + + synced = 0 + failed = 0 + + for cbor_path in cbor_files: + cid = cbor_path.stem + results = sync_observation(cid, cbor_path) + + if results["central"]: + synced += 1 + conn = sqlite3.connect(store.db_path) + conn.execute("UPDATE observations SET pinned = 1 WHERE cid = ?", (cid,)) + conn.commit() + conn.close() + else: + failed += 1 + + return {"synced": synced, "failed": failed} + + +def pull_from_hf(dataset_id: str = None) -> int: + """ + Pull observations from a HuggingFace dataset. 
def pull_from_hf(dataset_id: Optional[str] = None) -> int:
    """
    Pull observations from a HuggingFace dataset into the local store.

    Args:
        dataset_id: HF dataset (default: central dataset)

    Returns:
        Number of observations pulled
    """
    try:
        # NOTE: only the two helpers actually used are imported (HfApi was
        # previously imported but never referenced).
        from huggingface_hub import hf_hub_download, list_repo_files
    except ImportError:
        print("pip install huggingface_hub")
        return 0

    import shutil  # hoisted out of the per-file download loop

    store = _get_store()
    dataset_id = dataset_id or CENTRAL_DATASET

    # List observation files
    try:
        files = list_repo_files(dataset_id, repo_type="dataset")
        obs_files = [f for f in files if f.startswith("observations/") and f.endswith(".cbor")]
    except Exception as e:
        print(f"Could not list dataset: {e}")
        return 0

    pulled = 0
    # One connection for the whole pull; previously a connection was opened per
    # file and leaked whenever the insert raised inside the per-file try.
    conn = sqlite3.connect(store.db_path)
    try:
        for file_path in obs_files:
            cid = file_path.replace("observations/", "").replace(".cbor", "")

            local_path = store.cbor_dir / f"{cid}.cbor"
            if local_path.exists():
                continue  # already have it

            try:
                downloaded = hf_hub_download(
                    repo_id=dataset_id,
                    filename=file_path,
                    repo_type="dataset",
                    local_dir=str(store.lattice_dir / "_hf_cache"),
                )
                shutil.copy(downloaded, local_path)

                receipt = Receipt.from_dict(dag_cbor.decode(local_path.read_bytes()))

                conn.execute("""
                    INSERT OR REPLACE INTO observations
                    (cid, model_id, merkle_root, timestamp, parent_cid, pinned)
                    VALUES (?, ?, ?, ?, ?, 1)
                """, (cid, receipt.model_id, receipt.merkle_root,
                      receipt.timestamp, receipt.parent_cid))
                conn.commit()

                pulled += 1

            except Exception as e:
                # Best-effort per file: report and continue with the rest.
                print(f"Pull error for {cid[:16]}: {e}")
    finally:
        conn.close()

    return pulled


# =============================================================================
# PUBLIC API
# =============================================================================

# Global store instance (lazily created by _get_store)
_store: Optional[LocalStore] = None


def _get_store() -> LocalStore:
    """Get or create the process-wide LocalStore singleton."""
    global _store
    if _store is None:
        _store = LocalStore()
    return _store
if hasattr(obj, "item") and callable(obj.item): + return obj.item() + return str(obj) + + +def _sanitize(data: Any) -> Any: + """Recursively convert numpy types to python types.""" + if isinstance(data, dict): + return {k: _sanitize(v) for k, v in data.items()} + elif isinstance(data, (list, tuple)): + return [_sanitize(x) for x in data] + elif hasattr(data, "item") and callable(data.item): + return data.item() + return data + + +def observe( + model_id: str, + data: Dict[str, Any], + parent_cid: Optional[str] = None, + sync: bool = True, +) -> Receipt: + """ + Record an observation. + + Args: + model_id: Identifier for the model/system (e.g., "crafter", "gpt-4") + data: Observation data (any JSON-serializable dict) + parent_cid: Optional CID of previous observation (for chains) + sync: Whether to sync to HuggingFace (default: True) + + Returns: + Receipt with CID + + Example: + receipt = observe("crafter", {"reward": 2.1, "step": 100}) + print(receipt.cid) # bafyrei... + """ + store = _get_store() + + # Sanitize data to ensure JSON/CBOR serializability (numpy types -> python) + data = _sanitize(data) + + # Get parent if chaining + if parent_cid is None: + latest = store.get_latest(model_id) + if latest: + parent_cid = latest.cid + + # Compute merkle root (legacy compat) + # Using default=_json_default to handle numpy types often found in AI systems + merkle_data = f"{model_id}:{json.dumps(data, sort_keys=True, default=_json_default)}:{time.time()}" + merkle_root = hashlib.sha256(merkle_data.encode()).hexdigest()[:16] + + # Link to genesis + genesis_root = get_genesis_root() + data["_genesis"] = genesis_root + data["_model_id"] = model_id + + # Create receipt + receipt = Receipt( + cid="", # Will be computed + model_id=model_id, + merkle_root=merkle_root, + timestamp=time.time(), + data=data, + parent_cid=parent_cid, + ) + + # Save locally (computes CID) + cid = store.save(receipt) + receipt.cid = cid + + # Auto-sync to HuggingFace (best-effort, non-blocking) + if 
sync: + try: + sync_observation(cid, store.cbor_dir / f"{cid}.cbor") + except Exception: + pass # Local save succeeded, that's what matters + + return receipt + + +def query( + model_id: Optional[str] = None, + since: Optional[float] = None, + limit: int = 100, + include_remote: bool = False, +) -> List[Receipt]: + """ + Query observations. + + Args: + model_id: Filter by model ID + since: Only observations after this timestamp + limit: Maximum results + include_remote: Also search IPFS gateways (slower) + + Returns: + List of receipts + + Example: + # Get all crafter observations + for receipt in query(model_id="crafter"): + print(receipt.data["reward"]) + """ + store = _get_store() + return store.query(model_id=model_id, since=since, limit=limit) + + +def get(cid: str) -> Optional[Receipt]: + """ + Get a specific observation by CID. + + Checks local store first, then IPFS gateways. + """ + store = _get_store() + return fetch_receipt(cid, store) + + +# ============================================================================= +# DISCOVERY - Leverages HuggingFace's catalog +# ============================================================================= + +def discover_models(dataset_id: str = None) -> Dict[str, int]: + """ + Discover all model_ids in the lattice by scanning HuggingFace. + + HuggingFace IS the catalog - we just read it. 
def _local_model_counts() -> Dict[str, int]:
    """Count observations in the local store, grouped by model_id."""
    store = _get_store()
    conn = sqlite3.connect(store.db_path)
    try:
        rows = conn.execute(
            "SELECT model_id, COUNT(*) FROM observations GROUP BY model_id"
        ).fetchall()
    finally:
        conn.close()
    return {model_id: count for model_id, count in rows}


def discover_models(dataset_id: Optional[str] = None) -> Dict[str, Any]:
    """
    Discover all model_ids in the lattice by scanning HuggingFace.

    HuggingFace IS the catalog - we just read it. Remote files only expose
    CIDs (observations/{cid}.cbor), so model_ids come from local knowledge
    plus a remote file count.

    Returns:
        {"local_models": {model_id: count}, "remote_observation_count": int,
         "dataset": str} on success; a local-only or {"error": ...} dict
        otherwise. (Return annotation fixed: the value is nested, not
        Dict[str, int].)
    """
    dataset_id = dataset_id or CENTRAL_DATASET

    try:
        from huggingface_hub import list_repo_files
        files = list_repo_files(dataset_id, repo_type="dataset")

        # Count total remote observation files
        remote_count = len([f for f in files if f.startswith("observations/") and f.endswith(".cbor")])

        return {
            "local_models": _local_model_counts(),
            "remote_observation_count": remote_count,
            "dataset": dataset_id,
        }

    except ImportError:
        # No huggingface_hub - return local only
        return {"local_models": _local_model_counts(), "remote": "unavailable"}

    except Exception as e:
        return {"error": str(e)}
def discover_datasets(query_str: str = "cascade") -> List[Dict[str, Any]]:
    """
    Search HuggingFace for cascade-related datasets.

    This is how users FIND each other's lattices.

    Returns:
        List of dataset info dicts, or a one-element [{"error": ...}] list
        when huggingface_hub is missing or the search fails.
    """
    try:
        from huggingface_hub import HfApi
        api = HfApi()

        # Search for cascade datasets
        datasets = api.list_datasets(search=query_str, limit=50)

        return [
            {
                "id": ds.id,
                "author": ds.author,
                "downloads": ds.downloads,
                "likes": ds.likes,
                "last_modified": str(ds.last_modified) if ds.last_modified else None,
                "tags": getattr(ds, "tags", []),
            }
            for ds in datasets
        ]

    except ImportError:
        return [{"error": "huggingface_hub not installed"}]
    except Exception as e:
        return [{"error": str(e)}]


def discover_live(dataset_id: Optional[str] = None) -> Dict[str, Any]:
    """
    Get LIVE activity on a dataset - who's observing right now?

    Uses HuggingFace's commit history as activity feed.

    Returns:
        {"dataset", "recent_commits" (last 20), "total_commits"} on success,
        or {"error": ...}.
    """
    dataset_id = dataset_id or CENTRAL_DATASET

    try:
        from huggingface_hub import HfApi
        api = HfApi()

        # Materialize once. Fix: the original iterated `commits` and then
        # called len(list(commits)) - on a generator that counts only the
        # items remaining after the partial iteration.
        commits = list(api.list_repo_commits(
            repo_id=dataset_id,
            repo_type="dataset",
        ))

        recent = [
            {
                "commit_id": commit.commit_id,
                "created_at": str(commit.created_at),
                "title": commit.title,
                "authors": list(commit.authors or []),
            }
            for commit in commits[:20]
        ]

        return {
            "dataset": dataset_id,
            "recent_commits": recent,
            "total_commits": len(commits),
        }

    except ImportError:
        return {"error": "huggingface_hub not installed"}
    except Exception as e:
        return {"error": str(e)}
+ """ + dataset_id = dataset_id or CENTRAL_DATASET + + try: + from huggingface_hub import HfApi + api = HfApi() + + info = api.dataset_info(dataset_id) + + return { + "id": info.id, + "author": info.author, + "downloads": info.downloads, + "likes": info.likes, + "created_at": str(info.created_at) if info.created_at else None, + "last_modified": str(info.last_modified) if info.last_modified else None, + "private": info.private, + "tags": info.tags, + "card_data": info.card_data if hasattr(info, 'card_data') else None, + } + + except ImportError: + return {"error": "huggingface_hub not installed"} + except Exception as e: + return {"error": str(e)} + + +def stats() -> Dict[str, Any]: + """Get store statistics.""" + store = _get_store() + + conn = sqlite3.connect(store.db_path) + + total = conn.execute("SELECT COUNT(*) FROM observations").fetchone()[0] + pinned = conn.execute("SELECT COUNT(*) FROM observations WHERE pinned = 1").fetchone()[0] + + models = conn.execute( + "SELECT model_id, COUNT(*) FROM observations GROUP BY model_id" + ).fetchall() + + conn.close() + + return { + "total_observations": total, + "pinned_observations": pinned, + "models": {m: c for m, c in models}, + "genesis_root": get_genesis_root(), + "lattice_dir": str(_get_store().lattice_dir), + } + + +# ============================================================================= +# CLI +# ============================================================================= + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="CASCADE Store") + parser.add_argument("command", choices=["stats", "query", "get", "push", "pull", "sync"]) + parser.add_argument("--model", help="Model ID filter") + parser.add_argument("--cid", help="CID to fetch") + parser.add_argument("--limit", type=int, default=10) + parser.add_argument("--dataset", help="HF dataset (default: jtwspace/cascade-observations)") + + args = parser.parse_args() + + if args.command == "stats": + s = stats() 
+ print(f"Total: {s['total_observations']}") + print(f"Synced: {s['pinned_observations']}") + print(f"Genesis: {s['genesis_root']}") + print(f"Central dataset: {CENTRAL_DATASET}") + if USER_DATASET: + print(f"User dataset: {USER_DATASET}") + print(f"Models:") + for model, count in s["models"].items(): + print(f" {model}: {count}") + + elif args.command == "query": + receipts = query(model_id=args.model, limit=args.limit) + for r in receipts: + print(f"{r.cid[:20]}... | {r.model_id} | {r.data}") + + elif args.command == "get": + if not args.cid: + print("--cid required") + else: + r = get(args.cid) + if r: + print(json.dumps(r.to_dict(), indent=2, default=str)) + else: + print("Not found") + + elif args.command == "push": + print(f"Pushing to HuggingFace: {CENTRAL_DATASET}") + result = sync_all() + print(f"Synced {result['synced']}, Failed {result['failed']}") + + elif args.command == "pull": + dataset = args.dataset or CENTRAL_DATASET + print(f"Pulling from: {dataset}") + count = pull_from_hf(dataset) + print(f"Pulled {count} observations") + + elif args.command == "sync": + # Bidirectional sync + print(f"Syncing with HuggingFace...") + pulled = pull_from_hf() + result = sync_all() + print(f"Pulled {pulled}, Pushed {result['synced']}") diff --git a/cascade/system/__init__.py b/cascade/system/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5078c470af191a1aece6ce5960167874356714a6 --- /dev/null +++ b/cascade/system/__init__.py @@ -0,0 +1,146 @@ +""" +CASCADE System Observatory - Universal Log Visualization + +Parse logs from ANY system and visualize them in CASCADE's state-space topology. +Systems produce logs. CASCADE reveals their soul. + +Causation is derived from explicit parent_hash chains in the data - +no inference or ML models needed. Your forensics pipeline builds the chains, +System Observatory just reads and visualizes them. 
+ +Supported formats (all handled by UniversalAdapter): +- JSON/JSONL at any nesting depth +- Apache/Nginx access logs +- Kubernetes events +- Syslog format +- Generic timestamped logs +- Custom regex patterns +- ANY format with timestamps and messages + +Supported file types (via file_extractors): +- Text: .log, .txt, .json, .jsonl, .xml, .yaml +- Tabular: .csv, .tsv, .parquet, .xlsx +- Compressed: .gz, .zip, .tar, .tar.gz, .bz2 +- Documents: .pdf +- Databases: .sqlite, .db +- Binary: .evtx (Windows Event Log) +""" + +from cascade.system.adapter import ( + LogAdapter, + UniversalAdapter, # The one adapter to rule them all + JSONLAdapter, + ApacheLogAdapter, + NginxLogAdapter, + KubernetesLogAdapter, + GenericLogAdapter, + RegexAdapter, + auto_detect_adapter, + detect_data_type, # Detect logs vs dataset +) + +from cascade.system.observer import ( + SystemObserver, + observe_log_file, + observe_log_stream, +) + +from cascade.system.file_extractors import ( + extract_from_file, + extract_from_bytes, + get_extractor_for_file, + get_supported_extensions, + get_supported_formats, + ExtractionResult, + # Individual extractors + TextExtractor, + JSONExtractor, + CSVExtractor, + ParquetExtractor, + ExcelExtractor, + PDFExtractor, + XMLExtractor, + YAMLExtractor, + GzipExtractor, + ZipExtractor, + TarExtractor, + SQLiteExtractor, +) + +# ═══════════════════════════════════════════════════════════════════════════════ +# MoE Analyzer - DEPRECATED +# Kept for backwards compatibility but not used by System Observatory. +# Causation is now derived directly from parent_hash chains in the data. 
# ═══════════════════════════════════════════════════════════════════════════════
# The MoE analyzer may have extra dependencies, so guard the import and record
# availability; exports are appended to __all__ only when the import succeeds.
try:
    from cascade.system.moe_analyzer import (
        MoEAnalyzer,
        MoEAnalysisResult,
        SystemClassifier,
        TopologyClassification,
        SystemTopology,
        BaseSpecialist,
        MLTrainingSpecialist,
        WebServiceSpecialist,
        MicroservicesSpecialist,
        GenericSpecialist,
        AnalysisInsight,
        SpecialistAnalysis,
    )
    _MOE_AVAILABLE = True
except ImportError:
    # Deprecated analyzer not importable - its names are simply not exported.
    _MOE_AVAILABLE = False

# Public API of cascade.system (MoE names appended conditionally below).
__all__ = [
    # Adapters
    "LogAdapter",
    "UniversalAdapter",  # Future-proof default
    "JSONLAdapter",
    "ApacheLogAdapter",
    "NginxLogAdapter",
    "KubernetesLogAdapter",
    "GenericLogAdapter",
    "RegexAdapter",
    "auto_detect_adapter",
    "detect_data_type",  # Logs vs dataset detection
    # Observer
    "SystemObserver",
    "observe_log_file",
    "observe_log_stream",
    # File Extractors
    "extract_from_file",
    "extract_from_bytes",
    "get_extractor_for_file",
    "get_supported_extensions",
    "get_supported_formats",
    "ExtractionResult",
    "TextExtractor",
    "JSONExtractor",
    "CSVExtractor",
    "ParquetExtractor",
    "ExcelExtractor",
    "PDFExtractor",
    "XMLExtractor",
    "YAMLExtractor",
    "GzipExtractor",
    "ZipExtractor",
    "TarExtractor",
    "SQLiteExtractor",
]

# Add MoE exports only if available (deprecated but kept for compatibility)
if _MOE_AVAILABLE:
    __all__.extend([
        "MoEAnalyzer",
        "MoEAnalysisResult",
        "SystemClassifier",
        "TopologyClassification",
        "SystemTopology",
        "BaseSpecialist",
        "MLTrainingSpecialist",
        "WebServiceSpecialist",
        "MicroservicesSpecialist",
        "GenericSpecialist",
        "AnalysisInsight",
        "SpecialistAnalysis",
    ])
@dataclass
class ParsedEvent:
    """Standardized event parsed from any log format."""
    timestamp: float
    event_type: str
    component: str
    data: Dict[str, Any]
    raw_line: str = ""

    # Hash chain for provenance
    event_hash: str = field(default="")
    parent_hash: str = field(default="")

    def __post_init__(self):
        # Initial hash; adapters recompute after setting parent_hash.
        if not self.event_hash:
            self.event_hash = self._compute_hash()

    def _compute_hash(self) -> str:
        """Compute deterministic hash of this event (16 hex chars of SHA-256)."""
        content = json.dumps({
            "ts": self.timestamp,
            "type": self.event_type,
            "component": self.component,
            "data": self.data,
            "parent": self.parent_hash,
        }, sort_keys=True, default=str)
        return hashlib.sha256(content.encode()).hexdigest()[:16]

    def to_cascade_event(self) -> Dict[str, Any]:
        """Convert to CASCADE event format for visualization."""
        return {
            "event_id": f"sys_{self.event_hash}",
            "timestamp": self.timestamp,
            "event_type": self.event_type,
            "component": self.component,
            "data": {
                **self.data,
                "hash": self.event_hash,
                "parent_hash": self.parent_hash,
            },
        }


class LogAdapter(ABC):
    """
    Base class for log adapters.

    Implement parse_line() to convert your log format to ParsedEvent.
    """

    name: str = "base"
    description: str = "Base log adapter"

    def __init__(self):
        self.event_count = 0
        self.last_hash = ""

    @abstractmethod
    def parse_line(self, line: str) -> Optional[ParsedEvent]:
        """
        Parse a single log line.

        Args:
            line: Raw log line

        Returns:
            ParsedEvent if successfully parsed, None to skip this line
        """
        pass

    def _chain(self, event: ParsedEvent) -> ParsedEvent:
        """
        Link an event into the adapter's hash chain.

        Sets parent_hash to the previous event's hash, recomputes the event
        hash, and advances the chain state. (Deduplicates logic that was
        repeated verbatim in parse_file and parse_lines.)
        """
        event.parent_hash = self.last_hash
        event.event_hash = event._compute_hash()
        self.last_hash = event.event_hash
        self.event_count += 1
        return event

    def parse_file(self, filepath: str) -> Generator[ParsedEvent, None, None]:
        """Parse all lines in a file."""
        with open(filepath, "r", encoding="utf-8", errors="ignore") as f:
            for line in f:
                event = self.parse_line(line.strip())
                if event:
                    yield self._chain(event)

    def parse_lines(self, lines: List[str]) -> Generator[ParsedEvent, None, None]:
        """Parse a list of lines."""
        for line in lines:
            event = self.parse_line(line.strip())
            if event:
                yield self._chain(event)
+ """ + + name = "universal" + description = "Universal Log Parser - handles any format" + + # Field name variations (searched recursively in JSON) + TIMESTAMP_ALIASES = { + "timestamp", "time", "ts", "@timestamp", "datetime", "date", "t", + "created", "created_at", "logged_at", "event_time", "log_time", + "when", "epoch", "unix_time", "utc_time", "local_time" + } + + COMPONENT_ALIASES = { + "component", "service", "logger", "source", "module", "name", + "app", "application", "origin", "host", "hostname", "container", + "pod", "namespace", "class", "category", "tag", "facility" + } + + EVENT_TYPE_ALIASES = { + "event_type", "level", "severity", "type", "log_level", "loglevel", + "priority", "status", "kind", "action", "verb", "method" + } + + MESSAGE_ALIASES = { + "message", "msg", "text", "body", "content", "raw", "raw_message", + "description", "detail", "details", "info", "payload", "log" + } + + # Severity indicators (for inferring event type from text) + SEVERITY_PATTERNS = { + "critical": r'\b(CRITICAL|FATAL|EMERGENCY|PANIC)\b', + "error": r'\b(ERROR|ERR|EXCEPTION|FAIL(ED|URE)?)\b', + "warning": r'\b(WARN(ING)?|CAUTION|ALERT)\b', + "info": r'\b(INFO|NOTICE|LOG)\b', + "debug": r'\b(DEBUG|TRACE|VERBOSE)\b', + } + + # Timestamp regex patterns (ordered by specificity) + TIMESTAMP_PATTERNS = [ + # ISO 8601 variants + (r'(\d{4}-\d{2}-\d{2}[T ]\d{2}:\d{2}:\d{2}(?:\.\d+)?(?:Z|[+-]\d{2}:?\d{2})?)', 'iso'), + # Unix timestamp (float or int) + (r'\b(1[5-9]\d{8}(?:\.\d+)?)\b', 'unix'), # 1500000000+ range + # Apache/Nginx format + (r'\[(\d{2}/\w{3}/\d{4}:\d{2}:\d{2}:\d{2}(?:\s*[+-]\d{4})?)\]', 'apache'), + # Syslog format + (r'^(\w{3}\s+\d{1,2}\s+\d{2}:\d{2}:\d{2})', 'syslog'), + # Common date formats + (r'(\d{4}/\d{2}/\d{2}\s+\d{2}:\d{2}:\d{2})', 'slash'), + (r'(\d{2}-\d{2}-\d{4}\s+\d{2}:\d{2}:\d{2})', 'us'), + ] + + # Component extraction patterns (handle leading whitespace after timestamp removal) + COMPONENT_PATTERNS = [ + # [component] or (component) - with optional 
leading whitespace + r'^\s*\[([^\]]+)\]', + r'^\s*\(([^\)]+)\)', + # component: at start (with optional leading whitespace) + r'^\s*([A-Za-z][\w\-\.]+):', + # Kubernetes style: namespace/pod + r'\b([a-z][\w\-]+/[a-z][\w\-]+)\b', + # Docker container ID + r'\b([a-f0-9]{12})\b', + ] + + # Common delimiters to auto-detect (ordered by specificity) + DELIMITERS = [' | ', ' - ', '\t', ' :: ', ' -- '] + + def __init__(self): + super().__init__() + self._line_number = 0 + self._base_time = time.time() + self._detected_delimiter = None + self._detected_format = None # Cached format info after learning + self._sample_lines = [] # Collect lines for format learning + self._learning_complete = False + + # Initialize drain3 template miner if available + self._template_miner = None + if HAS_DRAIN3: + config = TemplateMinerConfig() + # drain3 config is ready to use without explicit load + self._template_miner = TemplateMiner(config=config) + + def parse_lines(self, lines: List[str]) -> Generator[ParsedEvent, None, None]: + """ + Parse lines with upfront format learning. + Override base class to learn format from first N lines before parsing any. + """ + # Learn format from first 50 lines (or all if less) + sample = [l.strip() for l in lines[:50] if l.strip() and not l.strip().startswith('{')] + if sample and not self._learning_complete: + self._learn_format(sample) + + # Now parse all lines with learned format + for line in lines: + event = self.parse_line(line.strip()) + if event: + event.parent_hash = self.last_hash + event.event_hash = event._compute_hash() + self.last_hash = event.event_hash + self.event_count += 1 + yield event + + def parse_file(self, filepath: str) -> Generator[ParsedEvent, None, None]: + """ + Parse file with upfront format learning. + Override base class to learn format before yielding events. 
+ """ + with open(filepath, "r", encoding="utf-8", errors="ignore") as f: + # Read first batch to learn format + sample_lines = [] + all_lines = [] + for i, line in enumerate(f): + all_lines.append(line) + if i < 50 and line.strip() and not line.strip().startswith('{'): + sample_lines.append(line.strip()) + + # Learn format + if sample_lines and not self._learning_complete: + self._learn_format(sample_lines) + + # Re-read and parse with learned format + with open(filepath, "r", encoding="utf-8", errors="ignore") as f: + for line in f: + event = self.parse_line(line.strip()) + if event: + event.parent_hash = self.last_hash + event.event_hash = event._compute_hash() + self.last_hash = event.event_hash + self.event_count += 1 + yield event + + def _learn_format(self, lines: List[str]) -> None: + """ + Learn log format from sample lines using statistical analysis. + Detects: delimiter, field positions, timestamp format. + """ + if self._learning_complete or len(lines) < 3: + return + + # Count delimiter occurrences across lines + delimiter_counts = {d: 0 for d in self.DELIMITERS} + for line in lines[:50]: # Sample first 50 lines + for delim in self.DELIMITERS: + count = line.count(delim) + if count >= 2: # At least 3 fields + delimiter_counts[delim] += count + + # Find most common delimiter + best_delim = max(delimiter_counts, key=delimiter_counts.get) + if delimiter_counts[best_delim] > len(lines) * 2: # Significant presence + self._detected_delimiter = best_delim + self._analyze_delimited_format(lines, best_delim) + + self._learning_complete = True + + def _analyze_delimited_format(self, lines: List[str], delimiter: str) -> None: + """ + Analyze field positions in delimiter-separated log format. + Learns which field contains timestamp, level, component, message. 
+ """ + # Split sample lines + field_samples = [] + for line in lines[:20]: + parts = [p.strip() for p in line.split(delimiter)] + if len(parts) >= 3: + field_samples.append(parts) + + if not field_samples: + return + + # Analyze each field position + num_fields = min(len(s) for s in field_samples) + format_info = { + 'delimiter': delimiter, + 'timestamp_idx': None, + 'level_idx': None, + 'component_idx': None, + 'message_idx': num_fields - 1, # Default: last field is message + } + + LEVEL_KEYWORDS = {'DEBUG', 'INFO', 'WARNING', 'WARN', 'ERROR', 'CRITICAL', 'FATAL', 'TRACE'} + + for idx in range(num_fields): + field_values = [s[idx] if idx < len(s) else '' for s in field_samples] + + # Check if this field contains timestamps + timestamp_score = sum(1 for v in field_values if self._looks_like_timestamp(v)) + if timestamp_score > len(field_values) * 0.7: + format_info['timestamp_idx'] = idx + continue + + # Check if this field contains log levels + level_score = sum(1 for v in field_values if v.upper() in LEVEL_KEYWORDS) + if level_score > len(field_values) * 0.5: + format_info['level_idx'] = idx + continue + + # Check if this field looks like component names + # Components are typically: lowercase, underscores/dots, consistent format + component_score = sum(1 for v in field_values + if v and re.match(r'^[a-zA-Z][\w\.\-]*$', v) + and v.upper() not in LEVEL_KEYWORDS + and not self._looks_like_timestamp(v)) + if component_score > len(field_values) * 0.5 and format_info['component_idx'] is None: + format_info['component_idx'] = idx + + self._detected_format = format_info + + def _looks_like_timestamp(self, value: str) -> bool: + """Check if a string looks like a timestamp.""" + if not value: + return False + # Check for time-like patterns: digits with : or - or . 
+ if re.match(r'^\d{1,4}[-/:]\d{1,2}[-/:T]', value): + return True + if re.match(r'^\d{2}:\d{2}:\d{2}', value): + return True + # Unix timestamp + if re.match(r'^1[5-9]\d{8}', value): + return True + return False + + def parse_line(self, line: str) -> Optional[ParsedEvent]: + """ + Universal parse - handles any format with intelligent auto-detection. + """ + if not line or not line.strip(): + return None + + line = line.strip() + self._line_number += 1 + + # Collect samples for format learning + if not self._learning_complete and len(self._sample_lines) < 50: + self._sample_lines.append(line) + if len(self._sample_lines) >= 10: + self._learn_format(self._sample_lines) + + # Try JSON first (handles all structured formats) + if line.startswith('{'): + try: + obj = json.loads(line) + return self._parse_json(obj, line) + except json.JSONDecodeError: + pass + + # Use learned delimited format if detected + if self._detected_format: + return self._parse_delimited(line) + + # Fall back to traditional text parsing + return self._parse_text(line) + + def _parse_delimited(self, line: str) -> ParsedEvent: + """ + Parse line using auto-detected delimiter format. + This is the intelligent parsing path for structured text logs. 
+ """ + fmt = self._detected_format + parts = [p.strip() for p in line.split(fmt['delimiter'])] + + # Extract fields by learned positions + timestamp = None + level = 'info' + component = 'system' + message = line + + if fmt['timestamp_idx'] is not None and fmt['timestamp_idx'] < len(parts): + ts_str = parts[fmt['timestamp_idx']] + timestamp = self._parse_timestamp_universal(ts_str) + + if fmt['level_idx'] is not None and fmt['level_idx'] < len(parts): + level = self._normalize_event_type(parts[fmt['level_idx']]) + + if fmt['component_idx'] is not None and fmt['component_idx'] < len(parts): + component = parts[fmt['component_idx']] + + if fmt['message_idx'] is not None and fmt['message_idx'] < len(parts): + # Message is everything from message_idx onwards (may be split by delimiter) + msg_start = fmt['message_idx'] + message = fmt['delimiter'].join(parts[msg_start:]) + + if timestamp is None: + timestamp = self._base_time + (self._line_number * 0.001) + + # Feed to drain3 for template mining (if available) + if self._template_miner: + result = self._template_miner.add_log_message(message) + # Could use result.get_template() for pattern analysis + + return ParsedEvent( + timestamp=timestamp, + event_type=level, + component=component, + data={"message": message}, + raw_line=line, + ) + + def _parse_timestamp_universal(self, ts_str: str) -> Optional[float]: + """ + Parse any timestamp format using dateparser (if available) or fallback regex. + """ + if not ts_str: + return None + + # Try dateparser first (handles almost any format) + if HAS_DATEPARSER: + try: + parsed = dateparser.parse(ts_str, settings={ + 'RETURN_AS_TIMEZONE_AWARE': False, + 'PREFER_DATES_FROM': 'past', + }) + if parsed: + return parsed.timestamp() + except Exception: + pass + + # Fallback to existing patterns + return self._parse_timestamp_string(ts_str, 'auto') + + def _parse_json(self, obj: dict, raw_line: str) -> ParsedEvent: + """ + Parse JSON object with recursive field discovery. 
+ Handles any nesting depth - finds fields wherever they are. + """ + # Recursively search for known fields + timestamp = self._find_field(obj, self.TIMESTAMP_ALIASES) + component = self._find_field(obj, self.COMPONENT_ALIASES) + event_type = self._find_field(obj, self.EVENT_TYPE_ALIASES) + message = self._find_field(obj, self.MESSAGE_ALIASES) + + # Parse timestamp + if timestamp is not None: + ts = self._parse_timestamp_value(timestamp) + else: + ts = self._base_time + (self._line_number * 0.001) + + # Normalize event type + if event_type: + event_type = self._normalize_event_type(str(event_type)) + else: + # Infer from message content + event_type = self._infer_severity(str(message) if message else str(obj)) + + # Default component + if not component: + component = "system" + else: + component = str(component) + + # Build data dict + data = self._flatten_to_data(obj, message) + + return ParsedEvent( + timestamp=ts, + event_type=event_type, + component=component, + data=data, + raw_line=raw_line, + ) + + def _find_field(self, obj: Any, aliases: set, depth: int = 0) -> Any: + """ + Recursively search for a field by any of its aliases. + Returns the first match found (breadth-first within each level). + """ + if depth > 10: # Prevent infinite recursion + return None + + if isinstance(obj, dict): + # Check this level first + for key in obj: + if key.lower() in aliases: + return obj[key] + + # Then recurse into nested dicts + for key, value in obj.items(): + if isinstance(value, dict): + result = self._find_field(value, aliases, depth + 1) + if result is not None: + return result + elif isinstance(value, list) and value and isinstance(value[0], dict): + # Check first item of list of dicts + result = self._find_field(value[0], aliases, depth + 1) + if result is not None: + return result + + return None + + def _flatten_to_data(self, obj: dict, message: Any = None) -> Dict[str, Any]: + """ + Flatten JSON object to data dict for visualization. 
+ Preserves important nested structures while making data accessible. + """ + data = {} + + # Set message + if message is not None: + if isinstance(message, dict): + data.update(message) + else: + data["message"] = str(message) + + # Extract key fields at any level + for key, value in obj.items(): + key_lower = key.lower() + + # Skip already processed fields + if key_lower in self.TIMESTAMP_ALIASES | self.MESSAGE_ALIASES: + continue + + # Include scalar values directly + if isinstance(value, (str, int, float, bool)) or value is None: + data[key] = value + # Include small dicts inline + elif isinstance(value, dict) and len(value) <= 5: + for k, v in value.items(): + if isinstance(v, (str, int, float, bool)): + data[f"{key}.{k}"] = v + # Summarize large structures + elif isinstance(value, dict): + data[f"_{key}_keys"] = list(value.keys())[:10] + elif isinstance(value, list): + data[f"_{key}_count"] = len(value) + + return data + + def _parse_text(self, line: str) -> ParsedEvent: + """ + Parse unstructured text log line. + Extracts timestamp, component, severity from text patterns. 
+ """ + timestamp = None + component = "system" + remaining = line + + # Try to extract timestamp + for pattern, fmt in self.TIMESTAMP_PATTERNS: + match = re.search(pattern, line, re.IGNORECASE) + if match: + ts_str = match.group(1) + timestamp = self._parse_timestamp_string(ts_str, fmt) + # Remove timestamp from remaining text + remaining = line[:match.start()] + line[match.end():] + break + + if timestamp is None: + timestamp = self._base_time + (self._line_number * 0.001) + + # Try to extract component + SEVERITY_WORDS = {'error', 'err', 'warn', 'warning', 'info', 'debug', 'trace', 'fatal', 'critical'} + for pattern in self.COMPONENT_PATTERNS: + match = re.search(pattern, remaining) + if match: + candidate = match.group(1) + # Don't treat severity keywords as components + if candidate.lower() not in SEVERITY_WORDS: + component = candidate + remaining = remaining[match.end():].strip() + break + + # Infer severity from text + event_type = self._infer_severity(line) + + # Clean up message + message = remaining.strip() + if message.startswith(':'): + message = message[1:].strip() + + return ParsedEvent( + timestamp=timestamp, + event_type=event_type, + component=component, + data={"message": message}, + raw_line=line, + ) + + def _parse_timestamp_value(self, value: Any) -> float: + """Parse timestamp from any format.""" + if isinstance(value, (int, float)): + # Unix timestamp - check if milliseconds + if value > 1e12: + return value / 1000 + return float(value) + + if isinstance(value, str): + return self._parse_timestamp_string(value, 'auto') + + return self._base_time + (self._line_number * 0.001) + + def _parse_timestamp_string(self, ts_str: str, fmt: str) -> float: + """Parse timestamp string to Unix timestamp.""" + try: + if fmt == 'unix' or (fmt == 'auto' and ts_str.replace('.', '').isdigit()): + val = float(ts_str) + return val / 1000 if val > 1e12 else val + + if fmt == 'iso' or fmt == 'auto': + # ISO 8601 + try: + dt = 
datetime.fromisoformat(ts_str.replace('Z', '+00:00')) + return dt.timestamp() + except: + pass + + if fmt == 'apache': + # Apache: 10/Oct/2000:13:55:36 +0700 + dt = datetime.strptime(ts_str.split()[0], "%d/%b/%Y:%H:%M:%S") + return dt.timestamp() + + if fmt == 'syslog': + # Syslog: Oct 10 13:55:36 (no year) + current_year = datetime.now().year + dt = datetime.strptime(f"{current_year} {ts_str}", "%Y %b %d %H:%M:%S") + return dt.timestamp() + + # Try common formats + for date_fmt in [ + "%Y-%m-%dT%H:%M:%S.%f", "%Y-%m-%dT%H:%M:%S", + "%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S", + "%Y/%m/%d %H:%M:%S", "%d-%m-%Y %H:%M:%S", + ]: + try: + dt = datetime.strptime(ts_str.split('+')[0].split('Z')[0], date_fmt) + return dt.timestamp() + except: + continue + except: + pass + + return self._base_time + (self._line_number * 0.001) + + def _normalize_event_type(self, event_type: str) -> str: + """Normalize event type to standard values.""" + et = event_type.lower().strip() + + # Map common variations + if et in ('err', 'exception', 'fail', 'failed', 'failure', 'fatal', 'critical', 'emergency', 'panic'): + return 'error' + if et in ('warn', 'caution', 'alert'): + return 'warning' + if et in ('information', 'notice', 'log'): + return 'info' + if et in ('trace', 'verbose', 'fine', 'finer', 'finest'): + return 'debug' + if et in ('state_change', 'transition', 'change'): + return 'state_change' + if et in ('checkpoint', 'save', 'snapshot'): + return 'checkpoint' + if et in ('progress', 'step', 'iteration'): + return 'progress' + if et in ('config', 'configuration', 'setting', 'setup'): + return 'config' + if et in ('metric', 'measure', 'stat', 'stats'): + return 'metric' + if et in ('anomaly', 'outlier', 'unusual'): + return 'anomaly' + + return et if et else 'info' + + def _infer_severity(self, text: str) -> str: + """Infer severity/event type from text content.""" + for severity, pattern in self.SEVERITY_PATTERNS.items(): + if re.search(pattern, text, re.IGNORECASE): + return 
class JSONLAdapter(LogAdapter):
    """
    Parse JSON Lines format (one JSON object per line).

    Expected fields (flexible):
    - timestamp/time/ts/@timestamp: Unix timestamp or ISO string
    - level/severity/type: Event type (info, error, warning, etc.)
    - component/service/logger/source: Which component
    - message/msg/data: Event data

    Also supports CASCADE's nested format:
    - {"event": {"timestamp": ..., "component": ..., "event_type": ...}, "metrics": {...}, "triage": {...}}
    """

    name = "jsonl"
    description = "JSON Lines (structured logs)"

    TIMESTAMP_FIELDS = ["timestamp", "time", "ts", "@timestamp", "datetime", "date"]
    LEVEL_FIELDS = ["level", "severity", "type", "log_level", "loglevel", "event_type"]
    COMPONENT_FIELDS = ["component", "service", "logger", "source", "module", "name"]
    MESSAGE_FIELDS = ["message", "msg", "data", "content", "text", "body", "raw_message", "raw"]

    def parse_line(self, line: str) -> Optional[ParsedEvent]:
        """Parse one JSONL line; returns None for blank or non-JSON lines."""
        if not line:
            return None

        try:
            obj = json.loads(line)
        except json.JSONDecodeError:
            return None

        # BUG FIX: a valid JSON scalar or array previously crashed here
        # ('"event" in 5' raises TypeError; a list reached .items() below).
        if not isinstance(obj, dict):
            return None

        # CASCADE tape format nests the event under an "event" key.
        if isinstance(obj.get("event"), dict):
            return self._parse_cascade_format(obj)

        return self._parse_standard_format(obj, line)

    def _parse_cascade_format(self, obj: dict) -> Optional[ParsedEvent]:
        """Parse CASCADE's nested tape format ({"event", "metrics", "triage"})."""
        evt = obj["event"]

        timestamp = evt.get("timestamp", time.time())
        event_type = evt.get("event_type", "info")
        component = evt.get("component", "system")

        data = {}
        if "raw" in evt:
            data["message"] = evt["raw"]
        elif "data" in evt and isinstance(evt["data"], dict):
            if "raw_message" in evt["data"]:
                data["message"] = evt["data"]["raw_message"]
            data.update(evt["data"])

        if "event_id" in evt:
            data["event_id"] = evt["event_id"]

        # Metrics/triage summaries use a "_" prefix so they cannot collide
        # with event payload keys.
        metrics = obj.get("metrics")
        if isinstance(metrics, dict):
            if "event_count" in metrics:
                data["_event_count"] = metrics["event_count"]
            if "health_status" in metrics:
                # assumes health_status is a dict — matches the CASCADE writer
                data["_health"] = metrics["health_status"].get("overall", "unknown")

        triage = obj.get("triage")
        if isinstance(triage, dict):
            data["_triage_status"] = triage.get("status", "UNKNOWN")
            data["_triage_action"] = triage.get("action", "")

        return ParsedEvent(
            timestamp=timestamp,
            event_type=event_type,
            component=component,
            data=data,
            raw_line=json.dumps(obj),
        )

    def _parse_standard_format(self, obj: dict, line: str) -> Optional[ParsedEvent]:
        """Parse a flat JSONL record using the known field-name lists."""
        timestamp = None
        for field in self.TIMESTAMP_FIELDS:
            if field in obj:
                timestamp = self._parse_timestamp(obj[field])
                break
        if timestamp is None:
            timestamp = time.time()

        event_type = "info"
        for field in self.LEVEL_FIELDS:
            if field in obj:
                event_type = str(obj[field]).lower()
                break

        component = "system"
        for field in self.COMPONENT_FIELDS:
            if field in obj:
                component = str(obj[field])
                break

        data = {}
        for field in self.MESSAGE_FIELDS:
            if field in obj:
                msg = obj[field]
                if isinstance(msg, dict):
                    data.update(msg)
                else:
                    data["message"] = str(msg)
                break

        # Carry every unrecognized field through.  The exclusion set is
        # built once instead of concatenating four lists for every key.
        known = (set(self.TIMESTAMP_FIELDS) | set(self.LEVEL_FIELDS)
                 | set(self.COMPONENT_FIELDS) | set(self.MESSAGE_FIELDS))
        for k, v in obj.items():
            if k not in known:
                data[k] = v

        return ParsedEvent(
            timestamp=timestamp,
            event_type=event_type,
            component=component,
            data=data,
            raw_line=line,
        )

    def _parse_timestamp(self, value: Any) -> float:
        """Convert a numeric or string timestamp to Unix seconds."""
        if isinstance(value, (int, float)):
            # > 1e12 ⇒ epoch milliseconds.  Cast so the int path honors
            # the declared float return type.
            return value / 1000 if value > 1e12 else float(value)

        if isinstance(value, str):
            try:
                return datetime.fromisoformat(value.replace("Z", "+00:00")).timestamp()
            except ValueError:
                pass

            # BUG FIX: the old fallback did value.split("+")[0].split("-")[0],
            # truncating ISO dates like "2024-01-01..." to "2024".  Only the
            # timezone suffix after "+" is stripped now.
            candidate = value.split("+")[0] if "+" in value else value
            for fmt in (
                "%Y-%m-%dT%H:%M:%S.%f",
                "%Y-%m-%dT%H:%M:%S",
                "%Y-%m-%d %H:%M:%S.%f",
                "%Y-%m-%d %H:%M:%S",
                "%d/%b/%Y:%H:%M:%S",
            ):
                try:
                    return datetime.strptime(candidate, fmt).timestamp()
                except ValueError:
                    continue

        return time.time()
class ApacheLogAdapter(LogAdapter):
    """
    Parse Apache Combined Log Format.

    Format: %h %l %u %t "%r" %>s %b "%{Referer}i" "%{User-agent}i"
    Example: 127.0.0.1 - - [10/Oct/2000:13:55:36 -0700] "GET /apache_pb.gif HTTP/1.0" 200 2326 "http://www.example.com/start.html" "Mozilla/4.08 [en] (Win98; I ;Nav)"
    """

    name = "apache"
    description = "Apache Combined Log Format"

    # Named groups restored (they were lost to extraction garbling); names
    # match the groupdict() keys used in parse_line below.
    PATTERN = re.compile(
        r'^(?P<ip>[\d\.]+)\s+'
        r'(?P<ident>\S+)\s+'
        r'(?P<user>\S+)\s+'
        r'\[(?P<timestamp>[^\]]+)\]\s+'
        r'"(?P<method>\w+)\s+(?P<path>\S+)\s+(?P<protocol>[^"]+)"\s+'
        r'(?P<status>\d+)\s+'
        r'(?P<size>\S+)'
        r'(?:\s+"(?P<referer>[^"]*)"\s+"(?P<useragent>[^"]*)")?'
    )

    def parse_line(self, line: str) -> Optional[ParsedEvent]:
        """Parse one access-log line; returns None when it does not match."""
        if not line:
            return None

        match = self.PATTERN.match(line)
        if not match:
            return None

        d = match.groupdict()

        # [10/Oct/2000:13:55:36 -0700] — the timezone offset is discarded.
        try:
            ts_str = d["timestamp"].split()[0]
            timestamp = datetime.strptime(ts_str, "%d/%b/%Y:%H:%M:%S").timestamp()
        except (ValueError, IndexError):
            timestamp = time.time()

        # Map the HTTP status class to an event type.
        status = int(d.get("status", 200))
        if status >= 500:
            event_type = "error"
        elif status >= 400:
            event_type = "warning"
        elif status >= 300:
            event_type = "redirect"
        else:
            event_type = "request"

        size = d.get("size", "-")
        return ParsedEvent(
            timestamp=timestamp,
            event_type=event_type,
            component=f"http:{d.get('method', 'GET')}",
            data={
                "ip": d.get("ip"),
                "method": d.get("method"),
                "path": d.get("path"),
                "status": status,
                "size": int(size) if size != "-" else 0,
                # Optional groups yield None (not "") when absent — coalesce.
                "referer": d.get("referer") or "",
                "user_agent": d.get("useragent") or "",
            },
            raw_line=line,
        )


class NginxLogAdapter(ApacheLogAdapter):
    """Parse Nginx access logs (default format matches Apache combined)."""

    name = "nginx"
    description = "Nginx Access Log Format"


class KubernetesLogAdapter(LogAdapter):
    """
    Parse Kubernetes events and pod logs.

    Handles:
    - kubectl get events output
    - Pod log format: timestamp stdout/stderr F message
    - JSON structured logs from pods
    """

    name = "kubernetes"
    description = "Kubernetes Events & Pod Logs"

    # Pod log (CRI): 2024-01-01T00:00:00.000000000Z stdout F message
    POD_LOG_PATTERN = re.compile(
        r'^(?P<timestamp>\d{4}-\d{2}-\d{2}T[\d:.]+Z?)\s+'
        r'(?P<stream>stdout|stderr)\s+'
        r'(?P<flag>\S+)\s+'
        r'(?P<message>.*)$'
    )

    # kubectl get events columns: LAST SEEN, TYPE, REASON, OBJECT, MESSAGE
    EVENT_PATTERN = re.compile(
        r'^(?P<lastseen>\S+)\s+'
        r'(?P<type>\S+)\s+'
        r'(?P<reason>\S+)\s+'
        r'(?P<object>\S+)\s+'
        r'(?P<message>.*)$'
    )

    def parse_line(self, line: str) -> Optional[ParsedEvent]:
        """Try JSON pod logs, then the CRI pod-log format, then kubectl events."""
        if not line:
            return None

        # Structured pod logs: delegate to the JSONL adapter and tag the
        # component with a k8s prefix.
        if line.startswith("{"):
            result = JSONLAdapter().parse_line(line)
            if result:
                result.component = f"k8s:{result.component}"
                return result

        match = self.POD_LOG_PATTERN.match(line)
        if match:
            d = match.groupdict()
            ts_raw = d["timestamp"].replace("Z", "+00:00")
            # BUG FIX: CRI timestamps carry nanosecond precision, which
            # fromisoformat rejects (pre-3.11) — trim to microseconds so
            # real timestamps are not silently replaced with time.time().
            ts_raw = re.sub(r'\.(\d{6})\d+', r'.\1', ts_raw)
            try:
                timestamp = datetime.fromisoformat(ts_raw).timestamp()
            except ValueError:
                timestamp = time.time()

            return ParsedEvent(
                timestamp=timestamp,
                event_type="error" if d["stream"] == "stderr" else "log",
                component="k8s:pod",
                data={"stream": d["stream"], "message": d["message"]},
                raw_line=line,
            )

        match = self.EVENT_PATTERN.match(line)
        if match:
            d = match.groupdict()
            event_type = d.get("type", "Normal").lower()
            # "warning" already passes through; only "normal" needs mapping.
            if event_type == "normal":
                event_type = "info"

            return ParsedEvent(
                # This tabular format carries no absolute timestamp.
                timestamp=time.time(),
                event_type=event_type,
                component=f"k8s:{d.get('object', 'unknown').split('/')[0]}",
                data={
                    "reason": d.get("reason"),
                    "object": d.get("object"),
                    "message": d.get("message"),
                },
                raw_line=line,
            )

        return None
class GenericLogAdapter(LogAdapter):
    """
    Parse generic timestamped logs.

    Attempts to extract:
    - Timestamp from beginning of line
    - Log level (INFO, ERROR, WARN, DEBUG, etc.)
    - Component name (often in brackets)
    - Message
    """

    name = "generic"
    description = "Generic Timestamped Logs"

    # (compiled pattern, strptime format or None) tried in order against
    # the start of the line.
    TIMESTAMP_PATTERNS = [
        (re.compile(r'^(\d{4}-\d{2}-\d{2}[T ]\d{2}:\d{2}:\d{2}(?:\.\d+)?)'), "%Y-%m-%dT%H:%M:%S"),
        (re.compile(r'^(\d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2})'), "%Y/%m/%d %H:%M:%S"),
        (re.compile(r'^(\w{3}\s+\d{1,2}\s+\d{2}:\d{2}:\d{2})'), None),  # syslog (no year)
        (re.compile(r'^(\d{10,13})'), None),  # epoch seconds or milliseconds
    ]

    LEVEL_PATTERN = re.compile(
        r'\b(DEBUG|INFO|NOTICE|WARN(?:ING)?|ERROR|CRIT(?:ICAL)?|FATAL|SEVERE|TRACE)\b', re.I)

    # Component in [brackets] or "prefix: component:" form.
    COMPONENT_PATTERN = re.compile(r'\[([^\]]+)\]|^[^:]+:\s*(\S+):')

    def parse_line(self, line: str) -> Optional[ParsedEvent]:
        """Parse one generic log line; blank lines yield None."""
        if not line:
            return None

        timestamp = time.time()
        remaining = line

        for pattern, fmt in self.TIMESTAMP_PATTERNS:
            match = pattern.match(line)
            if match:
                ts_str = match.group(1)
                try:
                    if fmt:
                        normalized = ts_str.split(".")[0].replace("T", " ")
                        timestamp = datetime.strptime(
                            normalized, fmt.replace("T", " ")).timestamp()
                    elif ts_str.isdigit():
                        epoch = int(ts_str)
                        timestamp = epoch / 1000 if epoch > 1e12 else epoch
                    # NOTE(review): the syslog pattern (fmt=None, non-digit)
                    # falls through and keeps time.time() — confirm intended.
                except ValueError:
                    pass
                remaining = line[match.end():].strip()
                break

        event_type = "info"
        level_match = self.LEVEL_PATTERN.search(remaining)
        if level_match:
            level = level_match.group(1).upper()
            if level in ("ERROR", "CRITICAL", "CRIT", "FATAL", "SEVERE"):
                event_type = "error"
            elif level in ("WARN", "WARNING"):
                event_type = "warning"
            elif level == "DEBUG":
                event_type = "debug"
            elif level == "TRACE":
                event_type = "trace"

        component = "system"
        comp_match = self.COMPONENT_PATTERN.search(remaining)
        if comp_match:
            component = comp_match.group(1) or comp_match.group(2) or "system"

        return ParsedEvent(
            timestamp=timestamp,
            event_type=event_type,
            component=component,
            data={"message": remaining},
            raw_line=line,
        )


class RegexAdapter(LogAdapter):
    r"""
    Parse logs using a custom regex pattern.

    The regex should have named groups:
    - timestamp (optional): Timestamp string
    - type/level (optional): Event type
    - component (optional): Component name
    - message (optional): Message or data

    Example pattern:
        r'^(?P<timestamp>\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}) (?P<level>\w+) (?P<component>\S+) (?P<message>.*)$'
    """

    name = "regex"
    description = "Custom Regex Pattern"

    def __init__(self, pattern: str, timestamp_format: Optional[str] = None):
        """
        Args:
            pattern: Regex with named groups (see class docstring).
            timestamp_format: strptime format for the 'timestamp' group;
                when omitted, ISO-8601 parsing is attempted.
        """
        super().__init__()
        self.pattern = re.compile(pattern)
        self.timestamp_format = timestamp_format

    def parse_line(self, line: str) -> Optional[ParsedEvent]:
        """Parse a line with the user-supplied pattern; None when no match."""
        if not line:
            return None

        match = self.pattern.match(line)
        if not match:
            return None

        d = match.groupdict()

        timestamp = time.time()
        ts_value = d.get("timestamp")
        if ts_value:
            try:
                if self.timestamp_format:
                    timestamp = datetime.strptime(
                        ts_value, self.timestamp_format).timestamp()
                else:
                    timestamp = datetime.fromisoformat(ts_value).timestamp()
            except ValueError:
                # Unparseable timestamps fall back to "now".
                pass

        event_type = (d.get("type") or d.get("level") or "info").lower()
        component = d.get("component") or d.get("source") or "system"
        data = {"message": d.get("message") or d.get("data") or ""}

        # Any extra named groups ride along in the data dict.
        reserved = ("timestamp", "type", "level", "component",
                    "source", "message", "data")
        for k, v in d.items():
            if k not in reserved and v:
                data[k] = v

        return ParsedEvent(
            timestamp=timestamp,
            event_type=event_type,
            component=component,
            data=data,
            raw_line=line,
        )
def auto_detect_adapter(sample_lines: List[str]) -> "LogAdapter":
    """
    Auto-detect the best adapter for a set of sample lines.

    In practice UniversalAdapter handles everything; this function exists
    for backwards compatibility and edge cases.

    Args:
        sample_lines: First N lines of the log file (currently unused).

    Returns:
        Most suitable LogAdapter instance (always UniversalAdapter).
    """
    return UniversalAdapter()


def detect_log_format(filepath: str) -> str:
    """
    Best-effort detection of a log file's format for reporting purposes.

    Returns one of: 'universal', 'jsonl', 'apache', 'kubernetes'
    (any read failure falls back to 'universal').
    """
    try:
        with open(filepath, "r", encoding="utf-8", errors="ignore") as f:
            head = [f.readline() for _ in range(20)]

        samples = [line.strip() for line in head if line.strip()][:5]
        if not samples:
            return "universal"

        # Mostly JSON objects ⇒ JSON Lines.  BUG FIX: the old threshold
        # (len // 2) was 0 for a single sample, so every one-line file was
        # classified as JSONL; require at least one JSON line.
        json_count = sum(1 for s in samples if s.startswith('{'))
        if json_count >= max(1, len(samples) // 2):
            return "jsonl"

        # Apache/Nginx timestamp: [10/Oct/2000:13:55:36 ...]
        if any(re.search(r'\[\d{2}/\w{3}/\d{4}:', s) for s in samples):
            return "apache"

        if any('"kind"' in s or 'namespace' in s.lower() for s in samples):
            return "kubernetes"

        return "universal"
    except OSError:
        # Narrowed from a bare except: only file-access problems expected.
        return "universal"
def detect_data_type(lines: List[str]) -> Dict[str, Any]:
    """
    Detect whether data looks like logs vs a dataset.

    Returns:
        {
            "type": "logs" | "dataset" | "mixed" | "unknown",
            "confidence": 0.0-1.0,
            "signals": ["what made us think this"],
            "recommendation": "Use Log Observatory" | "Use Dataset Forensics"
        }
    """
    if not lines:
        return {"type": "unknown", "confidence": 0.0, "signals": [], "recommendation": "No data"}

    # Sample up to 100 non-empty lines.
    samples = [line.strip() for line in lines[:100] if line.strip()]
    if not samples:
        return {"type": "unknown", "confidence": 0.0, "signals": [], "recommendation": "No data"}

    log_signals = []
    dataset_signals = []

    # === LOG SIGNALS ===

    timestamp_patterns = [
        r'\d{4}-\d{2}-\d{2}[T ]\d{2}:\d{2}:\d{2}',   # ISO
        r'\[\d{2}/\w{3}/\d{4}:\d{2}:\d{2}:\d{2}',    # Apache
        r'\w{3}\s+\d{1,2}\s+\d{2}:\d{2}:\d{2}',      # Syslog
    ]
    # BUG FIX: count each line at most once — the old per-pattern sum could
    # count a single line several times and overshoot the threshold.
    timestamp_matches = sum(
        1 for s in samples if any(re.search(p, s) for p in timestamp_patterns)
    )
    if timestamp_matches > len(samples) * 0.3:
        log_signals.append(f"timestamps_found:{timestamp_matches}")

    log_levels = r'\b(DEBUG|INFO|WARN|WARNING|ERROR|CRITICAL|FATAL|TRACE)\b'
    level_matches = sum(1 for s in samples if re.search(log_levels, s, re.IGNORECASE))
    if level_matches > len(samples) * 0.2:
        log_signals.append(f"log_levels:{level_matches}")

    # Component patterns: [component] or component:
    component_pattern = r'(\[[\w.-]+\]|^[\w.-]+:)'
    component_matches = sum(1 for s in samples if re.search(component_pattern, s))
    if component_matches > len(samples) * 0.2:
        log_signals.append(f"components:{component_matches}")

    # JSON with event-like keys.
    event_keys = ['timestamp', 'time', 'ts', 'level', 'message', 'msg', 'component',
                  'event', 'event_type', 'severity', 'logger', 'source']
    json_event_count = 0
    for s in samples:
        if s.startswith('{'):
            try:
                obj = json.loads(s)
            except json.JSONDecodeError:
                continue
            if isinstance(obj, dict) and any(k in obj for k in event_keys):
                json_event_count += 1
    if json_event_count > len(samples) * 0.3:
        log_signals.append(f"json_events:{json_event_count}")

    # HTTP methods / status codes.
    http_pattern = r'\b(GET|POST|PUT|DELETE|PATCH)\b|\s[1-5]\d{2}\s'
    http_matches = sum(1 for s in samples if re.search(http_pattern, s))
    if http_matches > len(samples) * 0.1:
        log_signals.append(f"http_traffic:{http_matches}")

    # === DATASET SIGNALS ===

    # CSV-like structure (consistent column count across the first rows).
    if ',' in samples[0]:
        comma_counts = [s.count(',') for s in samples[:10]]
        if len(set(comma_counts)) <= 2:
            dataset_signals.append(f"csv_structure:cols={comma_counts[0] + 1}")

    # JSON with data-like keys but no event keys.
    data_keys = ['id', 'name', 'title', 'text', 'content', 'label', 'category',
                 'value', 'price', 'count', 'score', 'rating', 'description',
                 'user', 'author', 'url', 'image', 'date', 'created']
    json_data_count = 0
    for s in samples:
        if s.startswith('{'):
            try:
                obj = json.loads(s)
            except json.JSONDecodeError:
                continue
            if (isinstance(obj, dict)
                    and any(k in obj for k in data_keys)
                    and not any(k in obj for k in event_keys)):
                json_data_count += 1
    if json_data_count > len(samples) * 0.3:
        dataset_signals.append(f"json_data:{json_data_count}")

    # Datasets often carry long text fields.
    long_text_count = sum(1 for s in samples if len(s) > 500)
    if long_text_count > len(samples) * 0.2:
        dataset_signals.append(f"long_text:{long_text_count}")

    # Number-dense lines also hint at tabular data.
    numeric_heavy = sum(1 for s in samples if len(re.findall(r'\d+\.?\d*', s)) > 5)
    if numeric_heavy > len(samples) * 0.5:
        dataset_signals.append(f"numeric_data:{numeric_heavy}")

    # === DECISION ===

    log_score = len(log_signals)
    data_score = len(dataset_signals)

    if log_score + data_score == 0:
        return {
            "type": "unknown",
            "confidence": 0.3,
            "signals": ["no clear signals"],
            "recommendation": "Try both - Logs tab and Dataset tab",
        }

    if log_score > data_score * 2:
        return {
            "type": "logs",
            "confidence": min(log_score / 5, 1.0),
            "signals": log_signals,
            "recommendation": "Use Log Observatory (Observe tab)",
        }
    if data_score > log_score * 2:
        return {
            "type": "dataset",
            "confidence": min(data_score / 4, 1.0),
            "signals": dataset_signals,
            "recommendation": "Use Dataset Forensics (Extract Ghost Log)",
        }
    return {
        "type": "mixed",
        "confidence": 0.5,
        "signals": log_signals + dataset_signals,
        "recommendation": "Data has both log and dataset characteristics - try Logs first",
    }
+""" + +import io +import json +import gzip +import zipfile +import tarfile +import bz2 +import tempfile +from pathlib import Path +from typing import List, Iterator, Optional, Tuple, Union, BinaryIO +from abc import ABC, abstractmethod +from dataclasses import dataclass + + +@dataclass +class ExtractionResult: + """Result of extracting log data from a file.""" + lines: List[str] + source_format: str + file_count: int = 1 # For archives + total_bytes: int = 0 + warnings: List[str] = None + + def __post_init__(self): + if self.warnings is None: + self.warnings = [] + + +class BaseExtractor(ABC): + """Base class for file format extractors.""" + + extensions: List[str] = [] + name: str = "base" + + @abstractmethod + def extract(self, file_path: str) -> ExtractionResult: + """Extract log lines from the file.""" + pass + + @abstractmethod + def extract_bytes(self, data: bytes, filename: str = "") -> ExtractionResult: + """Extract from raw bytes (for uploaded files).""" + pass + + +# ═══════════════════════════════════════════════════════════════════════════════ +# TEXT FORMATS +# ═══════════════════════════════════════════════════════════════════════════════ + +class TextExtractor(BaseExtractor): + """Extract from plain text files (.log, .txt, .out, etc.)""" + + extensions = [".log", ".txt", ".out", ".err", ".stdout", ".stderr"] + name = "text" + + def extract(self, file_path: str) -> ExtractionResult: + with open(file_path, "r", encoding="utf-8", errors="ignore") as f: + content = f.read() + lines = content.strip().split("\n") + return ExtractionResult( + lines=lines, + source_format="text", + total_bytes=len(content), + ) + + def extract_bytes(self, data: bytes, filename: str = "") -> ExtractionResult: + content = data.decode("utf-8", errors="ignore") + lines = content.strip().split("\n") + return ExtractionResult( + lines=lines, + source_format="text", + total_bytes=len(data), + ) + + +class JSONExtractor(BaseExtractor): + """Extract from JSON/JSONL files.""" + + 
class JSONExtractor(BaseExtractor):
    """Extract from JSON/JSONL files."""

    extensions = [".json", ".jsonl", ".ndjson"]
    name = "json"

    def extract(self, file_path: str) -> ExtractionResult:
        with open(file_path, "r", encoding="utf-8", errors="ignore") as handle:
            return self._process_content(handle.read())

    def extract_bytes(self, data: bytes, filename: str = "") -> ExtractionResult:
        return self._process_content(data.decode("utf-8", errors="ignore"))

    def _process_content(self, content: str) -> ExtractionResult:
        """Normalize JSON/JSONL text into one JSON document per line."""
        warnings = []
        raw_lines = content.strip().split("\n")

        # First attempt: JSONL — every non-blank line is its own document.
        lines, is_jsonl = self._try_jsonl(raw_lines)

        if not is_jsonl:
            # Second attempt: the whole content as a single JSON value.
            try:
                parsed = json.loads(content)
            except json.JSONDecodeError as exc:
                warnings.append(f"JSON parse error: {exc}")
                lines = raw_lines  # last resort: raw text lines
            else:
                lines = self._explode(parsed)

        return ExtractionResult(
            lines=lines,
            source_format="jsonl" if is_jsonl else "json",
            total_bytes=len(content),
            warnings=warnings,
        )

    def _try_jsonl(self, raw_lines):
        """Return (lines, True) when every non-blank line parses as JSON."""
        collected = []
        for raw in raw_lines:
            candidate = raw.strip()
            if not candidate:
                continue
            try:
                json.loads(candidate)
            except json.JSONDecodeError:
                return [], False
            collected.append(candidate)
        return collected, True

    def _explode(self, parsed):
        """Turn a parsed JSON value into a list of one-line JSON strings."""
        if isinstance(parsed, list):
            return [json.dumps(item) for item in parsed]
        if isinstance(parsed, dict):
            # A wrapper object may hold the records under a well-known key.
            for key in ("logs", "events", "records"):
                if key in parsed:
                    return [json.dumps(item) for item in parsed[key]]
            return [json.dumps(parsed)]
        return []
class XMLExtractor(BaseExtractor):
    """Extract from XML log files."""

    extensions = [".xml"]
    name = "xml"

    def extract(self, file_path: str) -> ExtractionResult:
        with open(file_path, "r", encoding="utf-8", errors="ignore") as handle:
            return self._process_content(handle.read())

    def extract_bytes(self, data: bytes, filename: str = "") -> ExtractionResult:
        return self._process_content(data.decode("utf-8", errors="ignore"))

    def _process_content(self, content: str) -> ExtractionResult:
        """Convert XML log entries to JSON lines (one per log element)."""
        lines = []
        warnings = []

        try:
            # NOTE(review): stdlib ElementTree is not hardened against
            # maliciously constructed XML (entity expansion); consider
            # defusedxml if inputs can be untrusted.
            import xml.etree.ElementTree as ET
            root = ET.fromstring(content)

            # Probe common wrapper tags for repeated log entries.
            for tag in ("log", "entry", "event", "record", "message", "item", "row"):
                elements = root.findall(f".//{tag}")
                if not elements:
                    continue
                for elem in elements:
                    record = {child.tag: child.text for child in elem}
                    if elem.text and elem.text.strip():
                        record["_text"] = elem.text.strip()
                    record.update(elem.attrib)
                    lines.append(json.dumps(record))
                break

            if not lines:
                # No recognizable log elements: collect all leaf text.
                lines = [node.text.strip() for node in root.iter()
                         if node.text and node.text.strip()]

        except Exception as exc:
            warnings.append(f"XML parse error: {exc}")
            lines = [raw.strip() for raw in content.split("\n") if raw.strip()]

        return ExtractionResult(
            lines=lines,
            source_format="xml",
            total_bytes=len(content),
            warnings=warnings,
        )


class YAMLExtractor(BaseExtractor):
    """Extract from YAML files (often used for K8s events)."""

    extensions = [".yaml", ".yml"]
    name = "yaml"

    def extract(self, file_path: str) -> ExtractionResult:
        with open(file_path, "r", encoding="utf-8", errors="ignore") as handle:
            return self._process_content(handle.read())

    def extract_bytes(self, data: bytes, filename: str = "") -> ExtractionResult:
        return self._process_content(data.decode("utf-8", errors="ignore"))

    def _process_content(self, content: str) -> ExtractionResult:
        """Flatten YAML documents (possibly "---"-separated) into JSON lines."""
        lines = []
        warnings = []

        try:
            import yaml

            for doc in yaml.safe_load_all(content):
                if doc is None:
                    continue
                if isinstance(doc, list):
                    lines.extend(json.dumps(item, default=str) for item in doc)
                elif isinstance(doc, dict):
                    if "items" in doc:
                        # Kubernetes-style List object.
                        lines.extend(json.dumps(item, default=str)
                                     for item in doc["items"])
                    else:
                        lines.append(json.dumps(doc, default=str))

        except ImportError:
            warnings.append("PyYAML not installed, treating as text")
            lines = [raw for raw in content.split("\n")
                     if raw.strip() and not raw.startswith("#")]
        except Exception as exc:
            warnings.append(f"YAML parse error: {exc}")
            lines = [raw for raw in content.split("\n") if raw.strip()]

        return ExtractionResult(
            lines=lines,
            source_format="yaml",
            total_bytes=len(content),
            warnings=warnings,
        )


# ═══════════════════════════════════════════════════════════════════════════════
# TABULAR FORMATS
# ═══════════════════════════════════════════════════════════════════════════════

class CSVExtractor(BaseExtractor):
    """Extract from CSV/TSV files."""

    extensions = [".csv", ".tsv"]
    name = "csv"

    def extract(self, file_path: str) -> ExtractionResult:
        with open(file_path, "r", encoding="utf-8", errors="ignore") as handle:
            return self._process_content(handle.read(), file_path)

    def extract_bytes(self, data: bytes, filename: str = "") -> ExtractionResult:
        return self._process_content(data.decode("utf-8", errors="ignore"), filename)

    def _process_content(self, content: str, filename: str = "") -> ExtractionResult:
        """Turn each CSV row into a JSON object keyed by the header row."""
        import csv

        lines = []
        warnings = []

        # Delimiter is inferred from the extension only (.tsv ⇒ tab).
        sep = "\t" if filename.endswith(".tsv") else ","

        try:
            for row in csv.DictReader(io.StringIO(content), delimiter=sep):
                lines.append(json.dumps(dict(row)))
        except Exception as exc:
            warnings.append(f"CSV parse error: {exc}")
            lines = [raw for raw in content.split("\n") if raw.strip()]

        return ExtractionResult(
            lines=lines,
            source_format="csv",
            total_bytes=len(content),
            warnings=warnings,
        )
warnings=warnings, + ) + + +class ParquetExtractor(BaseExtractor): + """Extract from Parquet files.""" + + extensions = [".parquet", ".pq"] + name = "parquet" + + def extract(self, file_path: str) -> ExtractionResult: + return self._process_file(file_path) + + def extract_bytes(self, data: bytes, filename: str = "") -> ExtractionResult: + # Write to temp file for pyarrow + with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as f: + f.write(data) + temp_path = f.name + + try: + result = self._process_file(temp_path) + finally: + Path(temp_path).unlink(missing_ok=True) + + return result + + def _process_file(self, file_path: str) -> ExtractionResult: + lines = [] + warnings = [] + total_bytes = Path(file_path).stat().st_size + + try: + import pyarrow.parquet as pq + + table = pq.read_table(file_path) + df_dict = table.to_pydict() + + # Convert to row-wise JSON + num_rows = len(next(iter(df_dict.values()))) if df_dict else 0 + for i in range(num_rows): + row = {k: v[i] for k, v in df_dict.items()} + lines.append(json.dumps(row, default=str)) + + except ImportError: + warnings.append("PyArrow not installed, cannot read Parquet") + except Exception as e: + warnings.append(f"Parquet read error: {e}") + + return ExtractionResult( + lines=lines, + source_format="parquet", + total_bytes=total_bytes, + warnings=warnings, + ) + + +class ExcelExtractor(BaseExtractor): + """Extract from Excel files.""" + + extensions = [".xlsx", ".xls"] + name = "excel" + + def extract(self, file_path: str) -> ExtractionResult: + return self._process_file(file_path) + + def extract_bytes(self, data: bytes, filename: str = "") -> ExtractionResult: + suffix = ".xlsx" if filename.endswith(".xlsx") else ".xls" + with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as f: + f.write(data) + temp_path = f.name + + try: + result = self._process_file(temp_path) + finally: + Path(temp_path).unlink(missing_ok=True) + + return result + + def _process_file(self, file_path: str) -> 
ExtractionResult: + lines = [] + warnings = [] + total_bytes = Path(file_path).stat().st_size + + try: + import pandas as pd + + # Read all sheets + xlsx = pd.ExcelFile(file_path) + + for sheet_name in xlsx.sheet_names: + df = pd.read_excel(xlsx, sheet_name=sheet_name) + + for _, row in df.iterrows(): + record = row.to_dict() + record["_sheet"] = sheet_name + # Clean NaN values + record = {k: (None if pd.isna(v) else v) for k, v in record.items()} + lines.append(json.dumps(record, default=str)) + + except ImportError: + warnings.append("Pandas not installed, cannot read Excel") + except Exception as e: + warnings.append(f"Excel read error: {e}") + + return ExtractionResult( + lines=lines, + source_format="excel", + total_bytes=total_bytes, + warnings=warnings, + ) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# COMPRESSED FORMATS +# ═══════════════════════════════════════════════════════════════════════════════ + +class GzipExtractor(BaseExtractor): + """Extract from gzip compressed files.""" + + extensions = [".gz", ".gzip"] + name = "gzip" + + def extract(self, file_path: str) -> ExtractionResult: + with gzip.open(file_path, "rb") as f: + data = f.read() + + # Determine inner format from filename + inner_name = file_path[:-3] if file_path.endswith(".gz") else file_path + return self._extract_inner(data, inner_name) + + def extract_bytes(self, data: bytes, filename: str = "") -> ExtractionResult: + decompressed = gzip.decompress(data) + inner_name = filename[:-3] if filename.endswith(".gz") else filename + return self._extract_inner(decompressed, inner_name) + + def _extract_inner(self, data: bytes, inner_name: str) -> ExtractionResult: + # Get appropriate extractor for inner content + extractor = get_extractor_for_file(inner_name) + if extractor and extractor.name != "gzip": + result = extractor.extract_bytes(data, inner_name) + result.source_format = f"gzip/{result.source_format}" + return result + + # Default to text + 
content = data.decode("utf-8", errors="ignore") + return ExtractionResult( + lines=content.strip().split("\n"), + source_format="gzip/text", + total_bytes=len(data), + ) + + +class ZipExtractor(BaseExtractor): + """Extract from ZIP archives.""" + + extensions = [".zip"] + name = "zip" + + def extract(self, file_path: str) -> ExtractionResult: + with open(file_path, "rb") as f: + return self.extract_bytes(f.read(), file_path) + + def extract_bytes(self, data: bytes, filename: str = "") -> ExtractionResult: + all_lines = [] + warnings = [] + file_count = 0 + + with zipfile.ZipFile(io.BytesIO(data)) as zf: + for name in zf.namelist(): + if name.endswith("/"): # Directory + continue + + try: + with zf.open(name) as f: + file_data = f.read() + + extractor = get_extractor_for_file(name) + if extractor: + result = extractor.extract_bytes(file_data, name) + all_lines.extend(result.lines) + warnings.extend(result.warnings) + else: + # Try as text + content = file_data.decode("utf-8", errors="ignore") + all_lines.extend(content.strip().split("\n")) + + file_count += 1 + except Exception as e: + warnings.append(f"Error extracting {name}: {e}") + + return ExtractionResult( + lines=all_lines, + source_format="zip", + file_count=file_count, + total_bytes=len(data), + warnings=warnings, + ) + + +class TarExtractor(BaseExtractor): + """Extract from TAR archives (.tar, .tar.gz, .tgz, .tar.bz2).""" + + extensions = [".tar", ".tar.gz", ".tgz", ".tar.bz2"] + name = "tar" + + def extract(self, file_path: str) -> ExtractionResult: + mode = "r:*" # Auto-detect compression + all_lines = [] + warnings = [] + file_count = 0 + + try: + with tarfile.open(file_path, mode) as tf: + for member in tf.getmembers(): + if not member.isfile(): + continue + + try: + f = tf.extractfile(member) + if f is None: + continue + + file_data = f.read() + extractor = get_extractor_for_file(member.name) + + if extractor: + result = extractor.extract_bytes(file_data, member.name) + all_lines.extend(result.lines) 
+ warnings.extend(result.warnings) + else: + content = file_data.decode("utf-8", errors="ignore") + all_lines.extend(content.strip().split("\n")) + + file_count += 1 + except Exception as e: + warnings.append(f"Error extracting {member.name}: {e}") + except Exception as e: + warnings.append(f"TAR open error: {e}") + + return ExtractionResult( + lines=all_lines, + source_format="tar", + file_count=file_count, + total_bytes=Path(file_path).stat().st_size, + warnings=warnings, + ) + + def extract_bytes(self, data: bytes, filename: str = "") -> ExtractionResult: + with tempfile.NamedTemporaryFile(suffix=".tar", delete=False) as f: + f.write(data) + temp_path = f.name + + try: + result = self.extract(temp_path) + finally: + Path(temp_path).unlink(missing_ok=True) + + result.total_bytes = len(data) + return result + + +class Bz2Extractor(BaseExtractor): + """Extract from bzip2 compressed files.""" + + extensions = [".bz2"] + name = "bz2" + + def extract(self, file_path: str) -> ExtractionResult: + with bz2.open(file_path, "rb") as f: + data = f.read() + + inner_name = file_path[:-4] if file_path.endswith(".bz2") else file_path + return self._extract_inner(data, inner_name) + + def extract_bytes(self, data: bytes, filename: str = "") -> ExtractionResult: + decompressed = bz2.decompress(data) + inner_name = filename[:-4] if filename.endswith(".bz2") else filename + return self._extract_inner(decompressed, inner_name) + + def _extract_inner(self, data: bytes, inner_name: str) -> ExtractionResult: + extractor = get_extractor_for_file(inner_name) + if extractor and extractor.name != "bz2": + result = extractor.extract_bytes(data, inner_name) + result.source_format = f"bz2/{result.source_format}" + return result + + content = data.decode("utf-8", errors="ignore") + return ExtractionResult( + lines=content.strip().split("\n"), + source_format="bz2/text", + total_bytes=len(data), + ) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# 
# DOCUMENT FORMATS
# ═══════════════════════════════════════════════════════════════════════════════

class PDFExtractor(BaseExtractor):
    """Extract text from PDF files."""

    extensions = [".pdf"]
    name = "pdf"

    def extract(self, file_path: str) -> ExtractionResult:
        with open(file_path, "rb") as f:
            return self.extract_bytes(f.read(), file_path)

    def extract_bytes(self, data: bytes, filename: str = "") -> ExtractionResult:
        """Extract page text, trying PDF libraries from best to worst."""
        lines = []
        warnings = []

        # Try multiple PDF libraries
        extracted = False

        # Try PyMuPDF (fitz) first - best quality
        if not extracted:
            try:
                import fitz  # PyMuPDF

                doc = fitz.open(stream=data, filetype="pdf")
                for page in doc:
                    text = page.get_text()
                    lines.extend(text.strip().split("\n"))
                doc.close()
                extracted = True
            except ImportError:
                pass
            except Exception as e:
                warnings.append(f"PyMuPDF error: {e}")

        # Try pdfplumber
        if not extracted:
            try:
                import pdfplumber

                with pdfplumber.open(io.BytesIO(data)) as pdf:
                    for page in pdf.pages:
                        text = page.extract_text() or ""
                        lines.extend(text.strip().split("\n"))
                extracted = True
            except ImportError:
                pass
            except Exception as e:
                warnings.append(f"pdfplumber error: {e}")

        # Try PyPDF2
        if not extracted:
            try:
                from PyPDF2 import PdfReader

                reader = PdfReader(io.BytesIO(data))
                for page in reader.pages:
                    text = page.extract_text() or ""
                    lines.extend(text.strip().split("\n"))
                extracted = True
            except ImportError:
                pass
            except Exception as e:
                warnings.append(f"PyPDF2 error: {e}")

        if not extracted:
            warnings.append("No PDF library available. Install: pip install pymupdf pdfplumber PyPDF2")

        # Filter empty lines and clean up
        lines = [l.strip() for l in lines if l.strip()]

        return ExtractionResult(
            lines=lines,
            source_format="pdf",
            total_bytes=len(data),
            warnings=warnings,
        )


# ═══════════════════════════════════════════════════════════════════════════════
# DATABASE FORMATS
# ═══════════════════════════════════════════════════════════════════════════════

class SQLiteExtractor(BaseExtractor):
    """Extract from SQLite database files."""

    extensions = [".sqlite", ".db", ".sqlite3"]
    name = "sqlite"

    def extract(self, file_path: str) -> ExtractionResult:
        return self._process_db(file_path)

    def extract_bytes(self, data: bytes, filename: str = "") -> ExtractionResult:
        # sqlite3 connects by path; spill the bytes to a temp file.
        with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
            f.write(data)
            temp_path = f.name

        try:
            result = self._process_db(temp_path)
        finally:
            Path(temp_path).unlink(missing_ok=True)

        result.total_bytes = len(data)
        return result

    def _process_db(self, file_path: str) -> ExtractionResult:
        """Dump rows of log-like tables (all tables if none look log-like)."""
        import sqlite3

        lines = []
        warnings = []

        try:
            conn = sqlite3.connect(file_path)
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()

            # Get all tables
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
            tables = [row[0] for row in cursor.fetchall()]

            # Look for log-like tables first
            log_tables = [t for t in tables if any(x in t.lower()
                          for x in ["log", "event", "audit", "trace", "message", "record"])]

            # If no log tables, use all tables
            target_tables = log_tables if log_tables else tables

            for table in target_tables:
                try:
                    # BUGFIX: quote the identifier SQL-style (double quotes,
                    # doubled internally).  The old "[{table}]" form broke out
                    # of the identifier for names containing "]", and the
                    # database file itself is untrusted input.
                    qname = '"' + table.replace('"', '""') + '"'
                    cursor.execute(f"SELECT * FROM {qname} LIMIT 10000")
                    columns = [desc[0] for desc in cursor.description]

                    for row in cursor.fetchall():
                        record = dict(zip(columns, row))
                        record["_table"] = table
                        lines.append(json.dumps(record, default=str))
                except Exception as e:
                    warnings.append(f"Error reading table {table}: {e}")

            conn.close()

        except Exception as e:
            warnings.append(f"SQLite error: {e}")

        return ExtractionResult(
            lines=lines,
            source_format="sqlite",
            total_bytes=Path(file_path).stat().st_size,
            warnings=warnings,
        )


# ═══════════════════════════════════════════════════════════════════════════════
# BINARY LOG FORMATS
# ═══════════════════════════════════════════════════════════════════════════════

class WindowsEventLogExtractor(BaseExtractor):
    """Extract from Windows Event Log files (.evtx)."""

    extensions = [".evtx"]
    name = "evtx"

    def extract(self, file_path: str) -> ExtractionResult:
        return self._process_evtx(file_path)

    def extract_bytes(self, data: bytes, filename: str = "") -> ExtractionResult:
        with tempfile.NamedTemporaryFile(suffix=".evtx", delete=False) as f:
            f.write(data)
            temp_path = f.name

        try:
            result = self._process_evtx(temp_path)
        finally:
            Path(temp_path).unlink(missing_ok=True)

        result.total_bytes = len(data)
        return result

    def _process_evtx(self, file_path: str) -> ExtractionResult:
        """Serialise each EVTX record to one JSON line."""
        lines = []
        warnings = []

        try:
            from evtx import PyEvtxParser

            parser = PyEvtxParser(file_path)
            for record in parser.records():
                try:
                    lines.append(json.dumps(record, default=str))
                except (TypeError, ValueError):
                    # BUGFIX: was a bare "except: pass" — skip only records
                    # json cannot serialise instead of swallowing everything
                    # (including KeyboardInterrupt).
                    continue

        except ImportError:
            warnings.append("evtx library not installed. Install: pip install evtx")
        except Exception as e:
            warnings.append(f"EVTX parse error: {e}")

        return ExtractionResult(
            lines=lines,
            source_format="evtx",
            total_bytes=Path(file_path).stat().st_size if Path(file_path).exists() else 0,
            warnings=warnings,
        )


# ═══════════════════════════════════════════════════════════════════════════════
# EXTRACTOR REGISTRY
# ═══════════════════════════════════════════════════════════════════════════════

# All available extractors (lookup goes through EXTENSION_MAP below)
EXTRACTORS: List[BaseExtractor] = [
    TextExtractor(),
    JSONExtractor(),
    XMLExtractor(),
    YAMLExtractor(),
    CSVExtractor(),
    ParquetExtractor(),
    ExcelExtractor(),
    GzipExtractor(),
    ZipExtractor(),
    TarExtractor(),
    Bz2Extractor(),
    PDFExtractor(),
    SQLiteExtractor(),
    WindowsEventLogExtractor(),
]

# Build extension -> extractor mapping
EXTENSION_MAP: dict = {}
for extractor in EXTRACTORS:
    for ext in extractor.extensions:
        EXTENSION_MAP[ext] = extractor


def get_extractor_for_file(filename: str) -> Optional[BaseExtractor]:
    """Get the appropriate extractor for a file based on extension."""
    path = Path(filename)

    # Handle compound extensions like .tar.gz
    # BUGFIX: compound suffixes are now lowercased like single suffixes,
    # so "LOGS.TAR.GZ" resolves to the tar extractor.
    suffixes = path.suffixes
    if len(suffixes) >= 2:
        compound = "".join(suffixes[-2:]).lower()
        if compound in EXTENSION_MAP:
            return EXTENSION_MAP[compound]

    # Single extension
    suffix = path.suffix.lower()
    return EXTENSION_MAP.get(suffix)


def extract_from_file(file_path: str) -> ExtractionResult:
    """
    Extract log lines from any supported file format.

    Args:
        file_path: Path to the file

    Returns:
        ExtractionResult with lines and metadata
    """
    extractor = get_extractor_for_file(file_path)

    if extractor is None:
        # Default to text extraction
        extractor = TextExtractor()

    return extractor.extract(file_path)


def extract_from_bytes(data: bytes, filename: str) -> ExtractionResult:
    """
    Extract log lines from raw bytes (e.g., uploaded file).

    Args:
        data: Raw file bytes
        filename: Original filename (for format detection)

    Returns:
        ExtractionResult with lines and metadata
    """
    extractor = get_extractor_for_file(filename)

    if extractor is None:
        extractor = TextExtractor()

    return extractor.extract_bytes(data, filename)


def get_supported_extensions() -> List[str]:
    """Get list of all supported file extensions."""
    return list(EXTENSION_MAP.keys())


def get_supported_formats() -> dict:
    """Get mapping of format name to extensions."""
    formats = {}
    for extractor in EXTRACTORS:
        formats[extractor.name] = extractor.extensions
    return formats


# --- cascade/system/folder_processor.py --------------------------------------
"""
CASCADE Folder Processor
Handle batch processing of multiple files in folders
"""

import os
import zipfile
import tempfile
from pathlib import Path
from typing import List, Dict, Any, Tuple
import pandas as pd

def process_folder_upload(files: List[Any]) -> Tuple[pd.DataFrame, Dict[str, Any]]:
    """
    Process multiple uploaded files and combine them

    Args:
        files: List of uploaded file objects from Gradio

    Returns:
        Tuple of (combined_dataframe, processing_summary)
    """
    if not files:
        return None, {"error": "No files provided"}

    all_data = []
    file_summary = []
    total_rows = 0

    for file_obj in files:
        # BUGFIX: bind file_name before the try block — previously a failure
        # before the assignment raised NameError inside the except handler,
        # masking the real error.
        file_name = str(getattr(file_obj, "name", "<unknown>"))
        try:
            # Get file path and info
            file_path = file_obj.name
            file_name = Path(file_path).name
            file_ext = Path(file_path).suffix.lower()

            # Read file based on extension
            df = None

            if file_ext == ".csv":
                df = pd.read_csv(file_path)
            elif file_ext == ".json":
                df = pd.read_json(file_path)
            elif file_ext == ".jsonl":
                df = pd.read_json(file_path, lines=True)
            elif file_ext == ".parquet":
                df = pd.read_parquet(file_path)
            elif file_ext in [".xlsx", ".xls"]:
                df = pd.read_excel(file_path)
            else:
                # For other formats, try to extract text
                from .file_extractors import extract_from_file
                result = extract_from_file(file_path)
                if result.lines:
                    df = pd.DataFrame([{"text": line, "source_file": file_name}
                                       for line in result.lines])
                else:
                    file_summary.append({
                        "file": file_name,
                        "status": "skipped",
                        "reason": "Unsupported format"
                    })
                    continue

            # Add source file column
            if df is not None and len(df) > 0:
                df["source_file"] = file_name
                all_data.append(df)

                file_summary.append({
                    "file": file_name,
                    "status": "success",
                    "rows": len(df),
                    "columns": len(df.columns)
                })
                total_rows += len(df)
            else:
                # BUGFIX: empty files used to vanish from the summary.
                file_summary.append({
                    "file": file_name,
                    "status": "skipped",
                    "reason": "File contained no rows"
                })

        except Exception as e:
            file_summary.append({
                "file": file_name,
                "status": "error",
                "error": str(e)
            })

    # Combine all data
    if all_data:
        combined_df = pd.concat(all_data, ignore_index=True)

        summary = {
            "total_files": len(files),
            "processed_files": len([s for s in file_summary if s["status"] == "success"]),
            "total_rows": total_rows,
            "file_details": file_summary
        }

        return combined_df, summary
    else:
        return None, {"error": "No files could be processed", "details": file_summary}

def process_zip_file(zip_path: str) -> Tuple[pd.DataFrame, Dict[str, Any]]:
    """
    Process a zip file containing multiple files

    Args:
        zip_path: Path to the zip file

    Returns:
        Tuple of (combined_dataframe, processing_summary)
    """
    class _MockFile:
        """Minimal stand-in for a Gradio upload object (only needs .name)."""
        def __init__(self, path):
            self.name = path

    with tempfile.TemporaryDirectory() as temp_dir:
        # Extract zip
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(temp_dir)

        # Find all extracted files (class hoisted out of the walk loop —
        # it was re-defined on every iteration)
        extracted_files = []
        for root, dirs, files in os.walk(temp_dir):
            for file in files:
                extracted_files.append(_MockFile(os.path.join(root, file)))

        # Process while the temp dir is still alive
        return process_folder_upload(extracted_files)
# --- cascade/system/moe_analyzer.py ------------------------------------------
"""
CASCADE MoE System Analyzer - Mixture of Expert Specialists.

The MoE routes system observations to domain-specific analysts based on
detected topology. Each specialist understands the causal patterns
unique to their domain.

Architecture:
    ParsedEvents → SystemClassifier (Router) → Specialist Analyzer → CausationGraph + Insights

Specialists:
    - MLTrainingSpecialist: loss curves, gradient health, convergence
    - WebServiceSpecialist: request flows, latency chains, error cascades
    - MicroservicesSpecialist: distributed traces, service dependencies
    - DatabaseSpecialist: query patterns, lock chains, transaction flows
    - ContainerSpecialist: pod lifecycles, resource pressure, scheduling
    - GenericSpecialist: fallback temporal analysis
"""

import re
import math
from typing import List, Dict, Any, Optional, Tuple, Set
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
from collections import defaultdict

from cascade.core.event import Event, CausationLink
from cascade.core.graph import CausationGraph
from cascade.system.adapter import ParsedEvent


# ═══════════════════════════════════════════════════════════════════════════════
# SYSTEM TOPOLOGY CLASSIFICATION
# ═══════════════════════════════════════════════════════════════════════════════

class SystemTopology:
    """Detected system topology with confidence scores."""

    ML_TRAINING = "ml_training"
    WEB_SERVICE = "web_service"
    MICROSERVICES = "microservices"
    DATABASE = "database"
    CONTAINER_ORCHESTRATION = "container_orchestration"
    MESSAGE_QUEUE = "message_queue"
    GENERIC = "generic"

    # Signal patterns for each topology: component names, event types,
    # metric keys and regex patterns that are characteristic of the domain.
    TOPOLOGY_SIGNALS: Dict[str, Dict[str, List[str]]] = {
        ML_TRAINING: {
            "components": ["model", "trainer", "optimizer", "scheduler", "checkpoint",
                           "dataloader", "loss", "gradient", "epoch", "batch"],
            "event_types": ["checkpoint", "training", "validation", "inference",
                            "gradient_update", "lr_schedule", "early_stop"],
            "metrics": ["loss", "accuracy", "lr", "grad_norm", "epoch", "batch",
                        "perplexity", "mfu", "tokens_per_sec"],
            "patterns": [r"epoch.*\d+", r"loss.*\d+\.\d+", r"step.*\d+", r"grad.*norm"]
        },
        WEB_SERVICE: {
            "components": ["nginx", "apache", "api", "gateway", "handler", "router",
                           "controller", "middleware", "auth", "session"],
            "event_types": ["request", "response", "redirect", "error", "timeout",
                            "rate_limit", "auth_failure", "cache_hit", "cache_miss"],
            "metrics": ["latency", "response_time", "status_code", "request_count",
                        "error_rate", "throughput", "p99", "p95"],
            "patterns": [r"GET|POST|PUT|DELETE", r"\d{3}\s", r"HTTP/\d\.\d", r"/api/"]
        },
        MICROSERVICES: {
            "components": ["service", "svc", "grpc", "rpc", "proxy", "envoy", "istio",
                           "consul", "eureka", "discovery"],
            "event_types": ["trace", "span", "call", "rpc_start", "rpc_end",
                            "circuit_breaker", "retry", "fallback", "timeout"],
            "metrics": ["trace_id", "span_id", "parent_span", "duration_ms",
                        "service_latency", "hop_count", "retry_count"],
            "patterns": [r"trace[-_]?id", r"span[-_]?id", r"correlation[-_]?id"]
        },
        DATABASE: {
            "components": ["mysql", "postgres", "postgresql", "mongodb", "redis",
                           "oracle", "sqlserver", "query", "transaction", "connection"],
            "event_types": ["query", "transaction_start", "transaction_commit",
                            "transaction_rollback", "deadlock", "lock_wait", "slow_query"],
            "metrics": ["query_time", "rows_affected", "rows_scanned", "lock_time",
                        "connections", "buffer_pool", "cache_hit_ratio"],
            "patterns": [r"SELECT|INSERT|UPDATE|DELETE", r"BEGIN|COMMIT|ROLLBACK"]
        },
        CONTAINER_ORCHESTRATION: {
            "components": ["kubernetes", "k8s", "docker", "pod", "container", "node",
                           "deployment", "kubelet", "scheduler", "controller"],
            "event_types": ["pod_scheduled", "container_started", "container_killed",
                            "oom_killed", "evicted", "scaling", "rolling_update"],
            "metrics": ["cpu_usage", "memory_usage", "cpu_limit", "memory_limit",
                        "restart_count", "ready_replicas", "available_replicas"],
            "patterns": [r"pod/", r"deployment/", r"namespace/", r"k8s\.io"]
        },
        MESSAGE_QUEUE: {
            "components": ["kafka", "rabbitmq", "sqs", "pubsub", "nats", "activemq",
                           "producer", "consumer", "queue", "topic", "broker"],
            "event_types": ["message_published", "message_consumed", "message_acked",
                            "message_nacked", "consumer_lag", "rebalance"],
            "metrics": ["queue_depth", "consumer_lag", "messages_per_sec",
                        "partition_offset", "consumer_group_lag"],
            "patterns": [r"topic[-_]", r"partition[-_]?\d+", r"offset[-_]?\d+"]
        },
    }


@dataclass
class TopologyClassification:
    """Result of system topology classification."""
    primary: str                    # Primary detected topology
    confidence: float               # 0.0 - 1.0
    all_scores: Dict[str, float]    # Scores for all topologies
    evidence: Dict[str, List[str]]  # What signals matched
    hybrid: bool = False            # Is this a hybrid system?
    secondary: Optional[str] = None  # Secondary topology if hybrid


class SystemClassifier:
    """
    MoE Router - Classifies system topology from parsed events.

    Examines components, event types, metrics, and text patterns
    to determine what kind of system produced these logs.
    """

    def __init__(self):
        self.signals = SystemTopology.TOPOLOGY_SIGNALS

    def classify(self, events: List[ParsedEvent]) -> TopologyClassification:
        """
        Classify the system topology from parsed events.

        Args:
            events: List of parsed events from adapter

        Returns:
            TopologyClassification with scores and evidence
        """
        if not events:
            return TopologyClassification(
                primary=SystemTopology.GENERIC,
                confidence=0.0,
                all_scores={},
                evidence={},
            )

        # Collect all signals from events
        all_components: Set[str] = set()
        all_event_types: Set[str] = set()
        all_metrics: Set[str] = set()
        all_text: List[str] = []

        for event in events:
            all_components.add(event.component.lower())
            all_event_types.add(event.event_type.lower())

            # BUGFIX: event.data can be None (specialists use "event.data or
            # {}"); the old code called event.data.get(...) outside the
            # isinstance guard and crashed with AttributeError.
            data = event.data if isinstance(event.data, dict) else {}

            # Extract metrics from data
            all_metrics.update(k.lower() for k in data.keys())

            # Collect raw text for pattern matching
            if data.get("message"):
                all_text.append(str(data["message"]))
            if data.get("raw"):
                all_text.append(str(data["raw"]))

        combined_text = " ".join(all_text)

        # Score each topology
        scores: Dict[str, float] = {}
        evidence: Dict[str, List[str]] = {}

        for topology, signals in self.signals.items():
            score, matched = self._score_topology(
                topology, signals,
                all_components, all_event_types, all_metrics, combined_text
            )
            scores[topology] = score
            evidence[topology] = matched

        # Determine primary topology
        if not scores or max(scores.values()) == 0:
            return TopologyClassification(
                primary=SystemTopology.GENERIC,
                confidence=0.0,
                all_scores=scores,
                evidence=evidence,
            )

        # Sort by score
        sorted_topologies = sorted(scores.items(), key=lambda x: x[1], reverse=True)
        primary = sorted_topologies[0][0]
        primary_score = sorted_topologies[0][1]

        # Normalize confidence to 0-1 (each of the 4 signal categories
        # contributes at most 1.0 to the raw score)
        max_possible = 4.0  # 4 signal types
        confidence = min(primary_score / max_possible, 1.0)

        # Check for hybrid (second topology has significant score)
        hybrid = False
        secondary = None
        if len(sorted_topologies) > 1:
            second_score = sorted_topologies[1][1]
            if second_score > primary_score * 0.5:  # Second is at least 50% of primary
                hybrid = True
                secondary = sorted_topologies[1][0]

        return TopologyClassification(
            primary=primary,
            confidence=confidence,
            all_scores=scores,
            evidence=evidence,
            hybrid=hybrid,
            secondary=secondary,
        )

    def _score_topology(
        self,
        topology: str,
        signals: Dict[str, List[str]],
        components: Set[str],
        event_types: Set[str],
        metrics: Set[str],
        text: str,
    ) -> Tuple[float, List[str]]:
        """Score how well events match a topology.

        Each signal category (components, event types, metrics, patterns)
        contributes the matched fraction of its signal list, so the raw
        score is in [0, 4].
        """
        score = 0.0
        matched = []

        # Component matches
        comp_matches = components & set(s.lower() for s in signals.get("components", []))
        if comp_matches:
            score += len(comp_matches) / max(len(signals.get("components", [1])), 1)
            matched.extend([f"component:{c}" for c in list(comp_matches)[:3]])

        # Event type matches
        type_matches = event_types & set(s.lower() for s in signals.get("event_types", []))
        if type_matches:
            score += len(type_matches) / max(len(signals.get("event_types", [1])), 1)
            matched.extend([f"type:{t}" for t in list(type_matches)[:3]])

        # Metric matches
        metric_signals = set(s.lower() for s in signals.get("metrics", []))
        metric_matches = metrics & metric_signals
        # Also check partial matches (e.g., "train_loss" contains "loss")
        for metric in metrics:
            for signal in metric_signals:
                if signal in metric and signal not in metric_matches:
                    metric_matches.add(signal)
        if metric_matches:
            score += len(metric_matches) / max(len(signals.get("metrics", [1])), 1)
            matched.extend([f"metric:{m}" for m in list(metric_matches)[:3]])

        # Pattern matches (unused "text_lower" local removed; the regexes
        # already match case-insensitively via re.IGNORECASE)
        pattern_matches = 0
        for pattern in signals.get("patterns", []):
            if re.search(pattern, text, re.IGNORECASE):
                pattern_matches += 1
                matched.append(f"pattern:{pattern[:20]}")
        if pattern_matches:
            score += pattern_matches / max(len(signals.get("patterns", [1])), 1)

        return score, matched


# ═══════════════════════════════════════════════════════════════════════════════
#
SPECIALIST ANALYZERS - Domain Experts +# ═══════════════════════════════════════════════════════════════════════════════ + +@dataclass +class AnalysisInsight: + """A single insight from specialist analysis.""" + category: str # "causal", "anomaly", "pattern", "recommendation" + severity: str # "info", "warning", "critical" + title: str + description: str + evidence: List[str] = field(default_factory=list) + affected_events: List[str] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + return { + "category": self.category, + "severity": self.severity, + "title": self.title, + "description": self.description, + "evidence": self.evidence, + "affected_events": self.affected_events, + } + + +@dataclass +class SpecialistAnalysis: + """Complete analysis from a specialist.""" + topology: str + specialist: str + confidence: float + insights: List[AnalysisInsight] = field(default_factory=list) + causal_links: List[CausationLink] = field(default_factory=list) + metrics_summary: Dict[str, Any] = field(default_factory=dict) + narrative: str = "" + + def to_dict(self) -> Dict[str, Any]: + return { + "topology": self.topology, + "specialist": self.specialist, + "confidence": self.confidence, + "insights": [i.to_dict() for i in self.insights], + "causal_links_count": len(self.causal_links), + "metrics_summary": self.metrics_summary, + "narrative": self.narrative, + } + + +class BaseSpecialist(ABC): + """Base class for domain specialist analyzers.""" + + name: str = "base" + topology: str = SystemTopology.GENERIC + + @abstractmethod + def analyze(self, events: List[ParsedEvent], graph: CausationGraph) -> SpecialistAnalysis: + """ + Analyze events and build causal understanding. 
+ + Args: + events: Parsed events from the system + graph: CausationGraph to populate with causal links + + Returns: + SpecialistAnalysis with insights and causal links + """ + pass + + def _find_temporal_chains( + self, + events: List[ParsedEvent], + window_ms: float = 1000.0 + ) -> List[Tuple[ParsedEvent, ParsedEvent]]: + """Find temporally close event pairs (potential causation).""" + chains = [] + sorted_events = sorted(events, key=lambda e: e.timestamp) + + for i, event in enumerate(sorted_events[:-1]): + for j in range(i + 1, min(i + 10, len(sorted_events))): + next_event = sorted_events[j] + time_delta = (next_event.timestamp - event.timestamp) * 1000 + if 0 < time_delta <= window_ms: + chains.append((event, next_event)) + elif time_delta > window_ms: + break + + return chains + + def _create_causal_link( + self, + from_event: ParsedEvent, + to_event: ParsedEvent, + causation_type: str, + strength: float = 0.5, + explanation: str = "", + ) -> CausationLink: + """Create a causal link between events.""" + return CausationLink( + from_event=from_event.event_hash, + to_event=to_event.event_hash, + causation_type=causation_type, + strength=strength, + explanation=explanation or f"{from_event.component} → {to_event.component}", + ) + + +class MLTrainingSpecialist(BaseSpecialist): + """ + Specialist for ML Training systems. 
+ + Understands: + - Training loops (epoch → batch → step) + - Loss dynamics and convergence + - Gradient health (vanishing/exploding) + - Checkpoint causation + - Learning rate schedules + """ + + name = "ml_training_specialist" + topology = SystemTopology.ML_TRAINING + + def analyze(self, events: List[ParsedEvent], graph: CausationGraph) -> SpecialistAnalysis: + analysis = SpecialistAnalysis( + topology=self.topology, + specialist=self.name, + confidence=0.0, + ) + + # Group events by epoch/step + epochs: Dict[int, List[ParsedEvent]] = defaultdict(list) + losses: List[Tuple[float, float]] = [] # (timestamp, loss) + grad_norms: List[Tuple[float, float]] = [] + checkpoints: List[ParsedEvent] = [] + + for event in events: + data = event.data or {} + + # Extract epoch + epoch = data.get("epoch") or data.get("ep") + if epoch is not None: + try: + epochs[int(epoch)].append(event) + except (ValueError, TypeError): + pass + + # Extract loss + loss = data.get("loss") or data.get("train_loss") + if loss is not None: + try: + losses.append((event.timestamp, float(loss))) + except (ValueError, TypeError): + pass + + # Extract gradient norm + grad = data.get("grad_norm") or data.get("gradient_norm") + if grad is not None: + try: + grad_norms.append((event.timestamp, float(grad))) + except (ValueError, TypeError): + pass + + # Find checkpoints + if event.event_type == "checkpoint" or "checkpoint" in event.component.lower(): + checkpoints.append(event) + + # Build causal chains: epoch flow + sorted_events = sorted(events, key=lambda e: e.timestamp) + prev_event = None + for event in sorted_events: + if prev_event and prev_event.component == event.component: + # Same component temporal chain + link = self._create_causal_link( + prev_event, event, + causation_type="temporal_sequence", + strength=0.7, + explanation=f"Sequential {event.component} events" + ) + analysis.causal_links.append(link) + prev_event = event + + # Checkpoint causation (checkpoint follows training events) + 
for checkpoint in checkpoints: + # Find training events just before checkpoint + for event in sorted_events: + if event.timestamp < checkpoint.timestamp: + time_delta = checkpoint.timestamp - event.timestamp + if time_delta < 60: # Within 60 seconds + link = self._create_causal_link( + event, checkpoint, + causation_type="checkpoint_trigger", + strength=0.8, + explanation="Training event triggered checkpoint" + ) + analysis.causal_links.append(link) + break # Only link to most recent + + # Analyze loss curve + if len(losses) >= 3: + losses_sorted = sorted(losses, key=lambda x: x[0]) + loss_values = [l[1] for l in losses_sorted] + + # Check convergence + if loss_values[-1] < loss_values[0]: + trend = "converging" + analysis.insights.append(AnalysisInsight( + category="pattern", + severity="info", + title="Training Converging", + description=f"Loss decreased from {loss_values[0]:.4f} to {loss_values[-1]:.4f}", + evidence=[f"loss_start={loss_values[0]:.4f}", f"loss_end={loss_values[-1]:.4f}"] + )) + else: + trend = "diverging" + analysis.insights.append(AnalysisInsight( + category="anomaly", + severity="warning", + title="Training May Be Diverging", + description=f"Loss increased from {loss_values[0]:.4f} to {loss_values[-1]:.4f}", + evidence=[f"loss_start={loss_values[0]:.4f}", f"loss_end={loss_values[-1]:.4f}"] + )) + + # Check for loss spikes + mean_loss = sum(loss_values) / len(loss_values) + for i, loss in enumerate(loss_values): + if loss > mean_loss * 3: + analysis.insights.append(AnalysisInsight( + category="anomaly", + severity="critical", + title="Loss Spike Detected", + description=f"Loss spiked to {loss:.4f} (mean: {mean_loss:.4f})", + evidence=[f"spike_value={loss:.4f}", f"mean={mean_loss:.4f}"] + )) + + # Analyze gradient health + if grad_norms: + grad_values = [g[1] for g in grad_norms] + mean_grad = sum(grad_values) / len(grad_values) + + if mean_grad < 1e-7: + analysis.insights.append(AnalysisInsight( + category="anomaly", + severity="critical", + 
title="Vanishing Gradients", + description=f"Mean gradient norm {mean_grad:.2e} is dangerously low", + evidence=[f"mean_grad_norm={mean_grad:.2e}"] + )) + elif mean_grad > 100: + analysis.insights.append(AnalysisInsight( + category="anomaly", + severity="critical", + title="Exploding Gradients", + description=f"Mean gradient norm {mean_grad:.2f} is very high", + evidence=[f"mean_grad_norm={mean_grad:.2f}"] + )) + + # Build narrative + analysis.confidence = min(len(events) / 50, 1.0) + analysis.metrics_summary = { + "epochs_observed": len(epochs), + "loss_samples": len(losses), + "gradient_samples": len(grad_norms), + "checkpoints": len(checkpoints), + } + + analysis.narrative = self._build_narrative(analysis, losses, grad_norms, epochs) + + return analysis + + def _build_narrative( + self, + analysis: SpecialistAnalysis, + losses: List[Tuple[float, float]], + grad_norms: List[Tuple[float, float]], + epochs: Dict[int, List[ParsedEvent]] + ) -> str: + """Build human-readable narrative.""" + parts = [f"🧠 **ML Training Analysis** (confidence: {analysis.confidence:.0%})"] + + if epochs: + parts.append(f"\n📊 Observed {len(epochs)} epochs") + if losses: + loss_values = [l[1] for l in losses] + parts.append(f"📉 Loss range: {min(loss_values):.4f} → {max(loss_values):.4f}") + if grad_norms: + grad_values = [g[1] for g in grad_norms] + parts.append(f"🔬 Gradient norm range: {min(grad_values):.2e} → {max(grad_values):.2e}") + + if analysis.insights: + critical = [i for i in analysis.insights if i.severity == "critical"] + warnings = [i for i in analysis.insights if i.severity == "warning"] + if critical: + parts.append(f"\n⚠️ **{len(critical)} critical issues detected**") + if warnings: + parts.append(f"⚡ {len(warnings)} warnings") + + parts.append(f"\n🔗 Identified {len(analysis.causal_links)} causal relationships") + + return "\n".join(parts) + + +class WebServiceSpecialist(BaseSpecialist): + """ + Specialist for Web Service systems. 
+ + Understands: + - Request → Response chains + - Error cascades (4xx → retry → 5xx) + - Latency bottlenecks + - Rate limiting effects + """ + + name = "web_service_specialist" + topology = SystemTopology.WEB_SERVICE + + def analyze(self, events: List[ParsedEvent], graph: CausationGraph) -> SpecialistAnalysis: + analysis = SpecialistAnalysis( + topology=self.topology, + specialist=self.name, + confidence=0.0, + ) + + # Categorize events + requests: List[ParsedEvent] = [] + errors: List[ParsedEvent] = [] + latencies: List[float] = [] + status_codes: Dict[int, int] = defaultdict(int) + + for event in events: + data = event.data or {} + + # Identify requests + if event.event_type in ["request", "response"] or "request" in str(data).lower(): + requests.append(event) + + # Identify errors + if event.event_type in ["error", "exception", "failure"]: + errors.append(event) + + # Extract status codes + status = data.get("status") or data.get("status_code") or data.get("http_status") + if status: + try: + status_codes[int(status)] += 1 + except (ValueError, TypeError): + pass + + # Extract latency + latency = data.get("latency") or data.get("response_time") or data.get("duration") + if latency: + try: + latencies.append(float(latency)) + except (ValueError, TypeError): + pass + + # Build request → error causal chains + sorted_events = sorted(events, key=lambda e: e.timestamp) + for i, event in enumerate(sorted_events): + if event.event_type in ["error", "exception", "failure"]: + # Look for preceding request + for j in range(i - 1, max(i - 10, -1), -1): + prev = sorted_events[j] + if prev.event_type in ["request", "info"]: + time_delta = event.timestamp - prev.timestamp + if time_delta < 30: # Within 30 seconds + link = self._create_causal_link( + prev, event, + causation_type="request_failure", + strength=0.8, + explanation="Request led to error response" + ) + analysis.causal_links.append(link) + break + + # Analyze error patterns + error_count = len(errors) + total_count 
= len(events) + if total_count > 0: + error_rate = error_count / total_count + if error_rate > 0.1: + analysis.insights.append(AnalysisInsight( + category="anomaly", + severity="critical" if error_rate > 0.3 else "warning", + title="High Error Rate", + description=f"Error rate is {error_rate:.1%} ({error_count}/{total_count} events)", + evidence=[f"error_rate={error_rate:.1%}"] + )) + + # Analyze status codes + error_statuses = sum(v for k, v in status_codes.items() if k >= 400) + if error_statuses > 0: + analysis.insights.append(AnalysisInsight( + category="pattern", + severity="warning" if error_statuses > 10 else "info", + title="HTTP Errors Detected", + description=f"Found {error_statuses} 4xx/5xx responses", + evidence=[f"{k}: {v}" for k, v in status_codes.items() if k >= 400] + )) + + # Analyze latency + if latencies: + avg_latency = sum(latencies) / len(latencies) + max_latency = max(latencies) + p95_idx = int(len(latencies) * 0.95) + p95 = sorted(latencies)[p95_idx] if p95_idx < len(latencies) else max_latency + + analysis.metrics_summary["avg_latency"] = avg_latency + analysis.metrics_summary["max_latency"] = max_latency + analysis.metrics_summary["p95_latency"] = p95 + + if p95 > avg_latency * 3: + analysis.insights.append(AnalysisInsight( + category="anomaly", + severity="warning", + title="Latency Outliers", + description=f"P95 latency ({p95:.0f}ms) is 3x+ average ({avg_latency:.0f}ms)", + evidence=[f"avg={avg_latency:.0f}ms", f"p95={p95:.0f}ms"] + )) + + analysis.confidence = min(len(events) / 100, 1.0) + analysis.metrics_summary["total_requests"] = len(requests) + analysis.metrics_summary["total_errors"] = error_count + analysis.metrics_summary["status_codes"] = dict(status_codes) + + analysis.narrative = self._build_narrative(analysis, status_codes, latencies) + + return analysis + + def _build_narrative( + self, + analysis: SpecialistAnalysis, + status_codes: Dict[int, int], + latencies: List[float] + ) -> str: + parts = [f"🌐 **Web Service 
Analysis** (confidence: {analysis.confidence:.0%})"] + + total_requests = sum(status_codes.values()) + if total_requests: + success = sum(v for k, v in status_codes.items() if 200 <= k < 400) + parts.append(f"\n📊 {total_requests} requests, {success} successful") + + if latencies: + parts.append(f"⏱️ Avg latency: {sum(latencies)/len(latencies):.0f}ms") + + if analysis.insights: + critical = [i for i in analysis.insights if i.severity == "critical"] + if critical: + parts.append(f"\n🚨 **{len(critical)} critical issues**") + + parts.append(f"\n🔗 {len(analysis.causal_links)} causal chains identified") + + return "\n".join(parts) + + +class MicroservicesSpecialist(BaseSpecialist): + """Specialist for distributed microservices systems.""" + + name = "microservices_specialist" + topology = SystemTopology.MICROSERVICES + + def analyze(self, events: List[ParsedEvent], graph: CausationGraph) -> SpecialistAnalysis: + analysis = SpecialistAnalysis( + topology=self.topology, + specialist=self.name, + confidence=0.0, + ) + + # Group by trace_id + traces: Dict[str, List[ParsedEvent]] = defaultdict(list) + services: Set[str] = set() + + for event in events: + data = event.data or {} + trace_id = data.get("trace_id") or data.get("traceId") or data.get("correlation_id") + if trace_id: + traces[str(trace_id)].append(event) + services.add(event.component) + + # Build service dependency graph from traces + service_calls: Dict[Tuple[str, str], int] = defaultdict(int) + + for trace_id, trace_events in traces.items(): + sorted_trace = sorted(trace_events, key=lambda e: e.timestamp) + for i in range(len(sorted_trace) - 1): + from_svc = sorted_trace[i].component + to_svc = sorted_trace[i + 1].component + if from_svc != to_svc: + service_calls[(from_svc, to_svc)] += 1 + + # Create causal link + link = self._create_causal_link( + sorted_trace[i], sorted_trace[i + 1], + causation_type="service_call", + strength=0.9, + explanation=f"Trace {trace_id[:8]}... 
call chain" + ) + analysis.causal_links.append(link) + + # Identify hot paths + if service_calls: + hottest = max(service_calls.items(), key=lambda x: x[1]) + analysis.insights.append(AnalysisInsight( + category="pattern", + severity="info", + title="Hot Service Path", + description=f"{hottest[0][0]} → {hottest[0][1]} called {hottest[1]} times", + evidence=[f"call_count={hottest[1]}"] + )) + + analysis.confidence = min(len(traces) / 20, 1.0) + analysis.metrics_summary = { + "services": list(services), + "traces": len(traces), + "service_calls": len(service_calls), + } + + analysis.narrative = f"🔀 **Microservices Analysis**\n{len(services)} services, {len(traces)} traces\n🔗 {len(analysis.causal_links)} call chains" + + return analysis + + +class GenericSpecialist(BaseSpecialist): + """Fallback specialist for unrecognized systems.""" + + name = "generic_specialist" + topology = SystemTopology.GENERIC + + def analyze(self, events: List[ParsedEvent], graph: CausationGraph) -> SpecialistAnalysis: + analysis = SpecialistAnalysis( + topology=self.topology, + specialist=self.name, + confidence=0.3, # Low confidence for generic + ) + + # Basic temporal chaining + sorted_events = sorted(events, key=lambda e: e.timestamp) + components = defaultdict(list) + + for event in sorted_events: + components[event.component].append(event) + + # Chain events within same component + for comp, comp_events in components.items(): + for i in range(len(comp_events) - 1): + link = self._create_causal_link( + comp_events[i], comp_events[i + 1], + causation_type="temporal", + strength=0.5, + explanation=f"Temporal sequence in {comp}" + ) + analysis.causal_links.append(link) + + # Find error cascades + errors = [e for e in sorted_events if e.event_type in ["error", "exception", "failure", "warning"]] + if errors: + analysis.insights.append(AnalysisInsight( + category="pattern", + severity="warning" if len(errors) > 5 else "info", + title="Error Events Detected", + description=f"Found 
{len(errors)} error/warning events", + evidence=[f"error_count={len(errors)}"] + )) + + analysis.metrics_summary = { + "total_events": len(events), + "components": len(components), + "error_events": len(errors), + } + + analysis.narrative = f"📋 **Generic Analysis**\n{len(events)} events across {len(components)} components\n🔗 {len(analysis.causal_links)} temporal chains" + + return analysis + + +# ═══════════════════════════════════════════════════════════════════════════════ +# MoE ANALYZER - The Router + Specialists Combined +# ═══════════════════════════════════════════════════════════════════════════════ + +class MoEAnalyzer: + """ + Mixture of Experts System Analyzer. + + Routes system observations to domain-specific specialists based on + detected topology. Combines classification + specialist analysis + into a unified analysis pipeline. + + Usage: + analyzer = MoEAnalyzer() + result = analyzer.analyze(parsed_events) + + print(result.classification) # What system was detected + print(result.analysis) # Deep specialist analysis + print(result.graph) # Populated CausationGraph + """ + + def __init__(self): + self.classifier = SystemClassifier() + + # Register specialists + self.specialists: Dict[str, BaseSpecialist] = { + SystemTopology.ML_TRAINING: MLTrainingSpecialist(), + SystemTopology.WEB_SERVICE: WebServiceSpecialist(), + SystemTopology.MICROSERVICES: MicroservicesSpecialist(), + SystemTopology.GENERIC: GenericSpecialist(), + # Add more as needed: + # SystemTopology.DATABASE: DatabaseSpecialist(), + # SystemTopology.CONTAINER_ORCHESTRATION: ContainerSpecialist(), + } + + self.default_specialist = GenericSpecialist() + + def analyze(self, events: List[ParsedEvent]) -> 'MoEAnalysisResult': + """ + Analyze events through the MoE pipeline. + + 1. Classify system topology + 2. Route to appropriate specialist + 3. Build causation graph + 4. 
Return combined analysis + + Args: + events: Parsed events from UniversalAdapter + + Returns: + MoEAnalysisResult with classification, analysis, and graph + """ + # Step 1: Classify + classification = self.classifier.classify(events) + + # Step 2: Create causation graph + graph = CausationGraph() + + # Add all events to graph + for event in events: + cascade_event = Event( + timestamp=event.timestamp, + event_type=event.event_type, + component=event.component, + data={ + **event.data, + "hash": event.event_hash, + "parent_hash": event.parent_hash, + }, + ) + graph.add_event(cascade_event) + + # Step 3: Route to specialist + specialist = self.specialists.get( + classification.primary, + self.default_specialist + ) + + analysis = specialist.analyze(events, graph) + + # Add causal links to graph + for link in analysis.causal_links: + graph.add_link(link) + + # Step 4: If hybrid, also run secondary specialist + secondary_analysis = None + if classification.hybrid and classification.secondary: + secondary_specialist = self.specialists.get( + classification.secondary, + self.default_specialist + ) + if secondary_specialist.name != specialist.name: + secondary_analysis = secondary_specialist.analyze(events, graph) + for link in secondary_analysis.causal_links: + graph.add_link(link) + + return MoEAnalysisResult( + classification=classification, + primary_analysis=analysis, + secondary_analysis=secondary_analysis, + graph=graph, + events=events, + ) + + def get_available_specialists(self) -> List[str]: + """List all available specialists.""" + return list(self.specialists.keys()) + + +@dataclass +class MoEAnalysisResult: + """Complete result from MoE analysis pipeline.""" + + classification: TopologyClassification + primary_analysis: SpecialistAnalysis + secondary_analysis: Optional[SpecialistAnalysis] + graph: CausationGraph + events: List[ParsedEvent] + + def to_dict(self) -> Dict[str, Any]: + """Serialize for JSON/display.""" + return { + "classification": { + 
"primary": self.classification.primary, + "confidence": self.classification.confidence, + "hybrid": self.classification.hybrid, + "secondary": self.classification.secondary, + "all_scores": self.classification.all_scores, + "evidence": self.classification.evidence, + }, + "primary_analysis": self.primary_analysis.to_dict(), + "secondary_analysis": self.secondary_analysis.to_dict() if self.secondary_analysis else None, + "graph_stats": self.graph.get_stats(), + "event_count": len(self.events), + } + + def get_narrative(self) -> str: + """Get combined narrative from all analyses.""" + parts = [] + + # Classification summary + parts.append(f"## 🔍 System Classification") + parts.append(f"**Detected:** {self.classification.primary.replace('_', ' ').title()}") + parts.append(f"**Confidence:** {self.classification.confidence:.0%}") + + if self.classification.hybrid: + parts.append(f"**Hybrid with:** {self.classification.secondary.replace('_', ' ').title()}") + + # Evidence + if self.classification.evidence.get(self.classification.primary): + parts.append(f"\n**Evidence:** {', '.join(self.classification.evidence[self.classification.primary][:5])}") + + # Primary analysis + parts.append(f"\n---\n{self.primary_analysis.narrative}") + + # Secondary analysis + if self.secondary_analysis: + parts.append(f"\n---\n{self.secondary_analysis.narrative}") + + # Graph stats + stats = self.graph.get_stats() + parts.append(f"\n---\n## 🕸️ Causation Graph") + parts.append(f"- **Events:** {stats['event_count']}") + parts.append(f"- **Causal Links:** {stats['link_count']}") + parts.append(f"- **Root Causes:** {stats['root_count']}") + parts.append(f"- **Leaf Effects:** {stats['leaf_count']}") + + return "\n".join(parts) + + def get_all_insights(self) -> List[AnalysisInsight]: + """Get all insights from all analyses.""" + insights = list(self.primary_analysis.insights) + if self.secondary_analysis: + insights.extend(self.secondary_analysis.insights) + return insights + + def 
get_critical_insights(self) -> List[AnalysisInsight]: + """Get only critical severity insights.""" + return [i for i in self.get_all_insights() if i.severity == "critical"] diff --git a/cascade/system/observer.py b/cascade/system/observer.py new file mode 100644 index 0000000000000000000000000000000000000000..171ec8b8aa57c8fef0bf20d73d774da2d0b48bb6 --- /dev/null +++ b/cascade/system/observer.py @@ -0,0 +1,311 @@ +""" +CASCADE System Observer - Process logs into CASCADE events. + +Observes log files or streams and emits CASCADE-compatible events +with full hash chain provenance. +""" + +import time +import json +import hashlib +from pathlib import Path +from typing import Optional, Dict, Any, List, Generator, Union +from dataclasses import dataclass, field + +from cascade.system.adapter import ( + LogAdapter, ParsedEvent, auto_detect_adapter, + JSONLAdapter, ApacheLogAdapter, GenericLogAdapter, +) +from cascade.core.event import Event +from cascade.analysis.metrics import MetricsEngine + + +@dataclass +class SystemObservation: + """Result of observing a log source.""" + + source: str # File path or stream name + adapter_name: str + events: List[ParsedEvent] = field(default_factory=list) + + # Hash chain + merkle_root: str = "" + chain_length: int = 0 + + # Statistics + event_counts: Dict[str, int] = field(default_factory=dict) # type -> count + component_counts: Dict[str, int] = field(default_factory=dict) # component -> count + time_range: tuple = (0.0, 0.0) # (min_ts, max_ts) + + # Errors + parse_errors: int = 0 + + def compute_merkle_root(self) -> str: + """Compute Merkle root of all event hashes.""" + if not self.events: + return "" + + hashes = [e.event_hash for e in self.events] + + # Build Merkle tree + while len(hashes) > 1: + if len(hashes) % 2 == 1: + hashes.append(hashes[-1]) # Duplicate last if odd + + new_level = [] + for i in range(0, len(hashes), 2): + combined = hashes[i] + hashes[i + 1] + new_hash = 
hashlib.sha256(combined.encode()).hexdigest()[:16] + new_level.append(new_hash) + hashes = new_level + + self.merkle_root = hashes[0] if hashes else "" + self.chain_length = len(self.events) + return self.merkle_root + + def to_summary(self) -> Dict[str, Any]: + """Generate summary for display.""" + return { + "source": self.source, + "adapter": self.adapter_name, + "total_events": len(self.events), + "merkle_root": self.merkle_root, + "chain_length": self.chain_length, + "event_types": self.event_counts, + "components": self.component_counts, + "time_range": { + "start": self.time_range[0], + "end": self.time_range[1], + "duration_sec": self.time_range[1] - self.time_range[0] if self.time_range[1] > 0 else 0, + }, + "parse_errors": self.parse_errors, + } + + +class SystemObserver: + """ + Observe system logs and emit CASCADE events. + + This is the bridge between external system logs and CASCADE visualization. + """ + + def __init__(self, adapter: LogAdapter = None): + """ + Args: + adapter: Log adapter to use. If None, will auto-detect. + """ + self.adapter = adapter + self.observations: List[SystemObservation] = [] + self.metrics_engine = MetricsEngine() + + def observe_file(self, filepath: str, adapter: LogAdapter = None) -> Generator[Event, None, SystemObservation]: + """ + Observe a log file and emit CASCADE events. 
+ + Args: + filepath: Path to log file + adapter: Override adapter (auto-detect if None) + + Yields: + CASCADE Event objects + + Returns: + SystemObservation with summary and provenance + """ + path = Path(filepath) + if not path.exists(): + raise FileNotFoundError(f"Log file not found: {filepath}") + + # Auto-detect adapter if needed + if adapter is None: + with open(filepath, "r", encoding="utf-8", errors="ignore") as f: + sample = [f.readline() for _ in range(20)] + adapter = auto_detect_adapter(sample) + + observation = SystemObservation( + source=str(filepath), + adapter_name=adapter.name, + ) + + min_ts = float("inf") + max_ts = 0.0 + + # Parse file + for parsed in adapter.parse_file(filepath): + observation.events.append(parsed) + + # Update stats + observation.event_counts[parsed.event_type] = observation.event_counts.get(parsed.event_type, 0) + 1 + observation.component_counts[parsed.component] = observation.component_counts.get(parsed.component, 0) + 1 + + min_ts = min(min_ts, parsed.timestamp) + max_ts = max(max_ts, parsed.timestamp) + + # Convert to CASCADE Event and yield + cascade_event = Event( + timestamp=parsed.timestamp, + event_type=parsed.event_type, + component=parsed.component, + data={ + **parsed.data, + "hash": parsed.event_hash, + "parent_hash": parsed.parent_hash, + "source": "system_log", + "adapter": adapter.name, + }, + ) + + self.metrics_engine.ingest(cascade_event) + yield cascade_event + + # Finalize observation + observation.time_range = (min_ts if min_ts != float("inf") else 0, max_ts) + observation.compute_merkle_root() + self.observations.append(observation) + + return observation + + def observe_lines(self, lines: List[str], source_name: str = "input", adapter: LogAdapter = None) -> Generator[Event, None, SystemObservation]: + """ + Observe log lines (e.g., from text input or upload). 
+ + Args: + lines: List of log lines + source_name: Name for this source + adapter: Override adapter (auto-detect if None) + + Yields: + CASCADE Event objects + """ + if adapter is None: + adapter = auto_detect_adapter(lines[:20]) + + observation = SystemObservation( + source=source_name, + adapter_name=adapter.name, + ) + + min_ts = float("inf") + max_ts = 0.0 + + for parsed in adapter.parse_lines(lines): + observation.events.append(parsed) + + observation.event_counts[parsed.event_type] = observation.event_counts.get(parsed.event_type, 0) + 1 + observation.component_counts[parsed.component] = observation.component_counts.get(parsed.component, 0) + 1 + + min_ts = min(min_ts, parsed.timestamp) + max_ts = max(max_ts, parsed.timestamp) + + cascade_event = Event( + timestamp=parsed.timestamp, + event_type=parsed.event_type, + component=parsed.component, + data={ + **parsed.data, + "hash": parsed.event_hash, + "parent_hash": parsed.parent_hash, + "source": "system_log", + "adapter": adapter.name, + }, + ) + + self.metrics_engine.ingest(cascade_event) + yield cascade_event + + observation.time_range = (min_ts if min_ts != float("inf") else 0, max_ts) + observation.compute_merkle_root() + self.observations.append(observation) + + return observation + + def get_all_events_for_viz(self) -> List[Dict[str, Any]]: + """Get all events formatted for CASCADE visualization.""" + all_events = [] + for obs in self.observations: + for parsed in obs.events: + all_events.append({ + "event": parsed.to_cascade_event(), + "metrics": self.metrics_engine.summary(), + "triage": self.metrics_engine.triage(), + }) + return all_events + + def get_provenance_summary(self) -> Dict[str, Any]: + """Get provenance summary for all observations.""" + return { + "observations": [obs.to_summary() for obs in self.observations], + "total_events": sum(len(obs.events) for obs in self.observations), + "total_sources": len(self.observations), + } + + +def observe_log_file(filepath: str) -> Generator[Event, 
None, SystemObservation]: + """ + Convenience function to observe a log file. + + Usage: + for event in observe_log_file("access.log"): + print(event) + """ + observer = SystemObserver() + return observer.observe_file(filepath) + + +def observe_log_stream(lines: List[str], source: str = "stream") -> Generator[Event, None, SystemObservation]: + """ + Convenience function to observe log lines. + + Usage: + lines = log_text.split("\\n") + for event in observe_log_stream(lines): + print(event) + """ + observer = SystemObserver() + return observer.observe_lines(lines, source) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# TAPE WRITING - Write observations to tape for playback +# ═══════════════════════════════════════════════════════════════════════════════ + +def write_system_tape(observation: SystemObservation, tape_dir: str = "./logs") -> str: + """ + Write system observation to tape file for playback. + + Args: + observation: SystemObservation to write + tape_dir: Directory for tape files + + Returns: + Path to tape file + """ + tape_path = Path(tape_dir) + tape_path.mkdir(parents=True, exist_ok=True) + + session_id = int(time.time()) + filename = tape_path / f"system_tape_{session_id}.jsonl" + + with open(filename, "w", encoding="utf-8") as f: + # Write header + header = { + "seq": 0, + "ts": time.time(), + "type": "header", + "source": observation.source, + "adapter": observation.adapter_name, + "merkle_root": observation.merkle_root, + "chain_length": observation.chain_length, + } + f.write(json.dumps(header) + "\n") + + # Write events + for i, parsed in enumerate(observation.events, 1): + record = { + "seq": i, + "ts": parsed.timestamp, + "event": parsed.to_cascade_event(), + } + f.write(json.dumps(record, default=str) + "\n") + + return str(filename) diff --git a/cascade/system/omnidirectional_analyzer.py b/cascade/system/omnidirectional_analyzer.py new file mode 100644 index 
0000000000000000000000000000000000000000..23a8f04d7c0d29fd6487c3ffe06bb8aeb76ca34c --- /dev/null +++ b/cascade/system/omnidirectional_analyzer.py @@ -0,0 +1,406 @@ +""" +CASCADE Omnidirectional Analyzer +The complete circuit: Repo ↔ Dataset ↔ Logs ↔ Architecture ↔ Verification +""" + +import pandas as pd +from typing import Dict, List, Any, Tuple, Optional +from datetime import datetime +import json + +from .repo_ingester import ingest_repository +from .universal_extractor import extract_from_files +from cascade.forensics import DataForensics +from cascade.logging import get_log_manager, log + + +class OmnidirectionalAnalyzer: + """ + Complete system for omni-directional engineering analysis + Connects repositories to their operational evidence + """ + + def __init__(self): + self.logger = get_log_manager() + self.repo_data = None + self.runtime_data = None + self.analysis_results = {} + + def analyze_complete_system(self, + repo_source: str, + runtime_logs: Optional[List[str]] = None, + runtime_datasets: Optional[List[Any]] = None) -> Dict[str, Any]: + """ + Complete omni-directional analysis + + Args: + repo_source: Repository path/URL or uploaded files + runtime_logs: Actual runtime logs + runtime_datasets: Runtime datasets/files + + Returns: + Complete analysis results + """ + log("OmnidirectionalAnalyzer", "Starting complete system analysis", + context=f"Repo: {repo_source}", + impact="HIGH") + + # Step 1: Ingest repository + self.repo_data, repo_summary = self._ingest_repository(repo_source) + + # Step 2: Process runtime evidence + self.runtime_data, runtime_summary = self._process_runtime_evidence( + runtime_logs, runtime_datasets + ) + + # Step 3: Generate expected patterns from repo + expected_patterns = self._generate_expected_patterns() + + # Step 4: Extract actual patterns from runtime + actual_patterns = self._extract_actual_patterns() + + # Step 5: Compare and find convergence/divergence + comparison = self._compare_patterns(expected_patterns, 
actual_patterns) + + # Step 6: Generate insights + insights = self._generate_insights(comparison) + + results = { + "repository": { + "data": self.repo_data, + "summary": repo_summary + }, + "runtime": { + "data": self.runtime_data, + "summary": runtime_summary + }, + "expected_patterns": expected_patterns, + "actual_patterns": actual_patterns, + "comparison": comparison, + "insights": insights, + "timestamp": datetime.now().isoformat() + } + + self.analysis_results = results + return results + + def _ingest_repository(self, repo_source: str) -> Tuple[pd.DataFrame, Dict[str, Any]]: + """Ingest repository into analyzable format""" + log("RepoIngest", "Ingesting repository", + context=f"Source: {repo_source}", + impact="MEDIUM") + + # Handle different input types + if isinstance(repo_source, str) and repo_source.startswith(("http://", "https://", "git@")): + # Remote repository + df, summary = ingest_repository(repo_source, include_history=True) + elif isinstance(repo_source, list): + # Uploaded files + df, summary = extract_from_files(repo_source) + summary["source_type"] = "uploaded_files" + else: + # Local path + df, summary = ingest_repository(repo_source, include_history=True) + + log("RepoIngest", "Repository ingested successfully", + context=f"Files: {summary.get('total_files', 0)}, Lines: {summary.get('total_lines', 0)}", + impact="LOW") + + return df, summary + + def _process_runtime_evidence(self, + logs: Optional[List[str]], + datasets: Optional[List[Any]]) -> Tuple[pd.DataFrame, Dict[str, Any]]: + """Process runtime logs and datasets""" + log("RuntimeProcessor", "Processing runtime evidence", + context=f"Logs: {len(logs or [])}, Datasets: {len(datasets or [])}", + impact="MEDIUM") + + all_data = [] + summary = {"sources": []} + + # Process logs + if logs: + log_records = [] + for i, log_line in enumerate(logs): + log_records.append({ + "content": log_line, + "source_type": "runtime_log", + "source_file": f"log_{i}", + "line_number": i + }) + 
all_data.extend(log_records) + summary["sources"].append({"type": "logs", "count": len(logs)}) + + # Process datasets + if datasets: + for dataset in datasets: + # Use universal extractor + df, ds_summary = extract_from_files(dataset) + if df is not None: + df["source_type"] = "runtime_dataset" + all_data.append(df) + summary["sources"].append({"type": "dataset", "records": len(df)}) + + # Combine all runtime data + if all_data: + runtime_df = pd.concat(all_data, ignore_index=True) + summary["total_records"] = len(runtime_df) + else: + runtime_df = pd.DataFrame() + summary["total_records"] = 0 + + return runtime_df, summary + + def _generate_expected_patterns(self) -> Dict[str, Any]: + """Generate expected operational patterns from repository""" + log("PatternGenerator", "Generating expected patterns from repository", + impact="MEDIUM") + + patterns = { + "expected_functions": [], + "expected_configs": [], + "expected_dependencies": [], + "expected_operations": [], + "architecture_indicators": {} + } + + if self.repo_data is not None: + # Extract function names (expected operations) + if 'functions' in self.repo_data.columns: + all_functions = [] + for func_list in self.repo_data['functions'].dropna(): + if isinstance(func_list, str): + try: + funcs = json.loads(func_list) + all_functions.extend([f['name'] for f in funcs]) + except: + pass + patterns["expected_functions"] = list(set(all_functions)) + + # Find configuration files + config_files = self.repo_data[self.repo_data['file_type'] == 'config'] + patterns["expected_configs"] = config_files['file_path'].tolist() + + # Extract dependencies + if 'imports' in self.repo_data.columns: + all_imports = [] + for import_list in self.repo_data['imports'].dropna(): + if isinstance(import_list, str): + try: + imports = json.loads(import_list) + all_imports.extend(imports) + except: + pass + patterns["expected_dependencies"] = list(set(all_imports)) + + # Architecture indicators + patterns["architecture_indicators"] = { + 
"has_tests": "test" in self.repo_data['file_type'].values, + "has_ci_cd": "cicd" in self.repo_data['file_type'].values, + "main_language": self.repo_data['language'].mode().iloc[0] if not self.repo_data.empty else "unknown", + "complexity_score": self.repo_data['complexity'].sum() if 'complexity' in self.repo_data.columns else 0 + } + + return patterns + + def _extract_actual_patterns(self) -> Dict[str, Any]: + """Extract actual patterns from runtime evidence""" + log("PatternExtractor", "Extracting actual patterns from runtime", + impact="MEDIUM") + + patterns = { + "actual_operations": [], + "actual_errors": [], + "actual_dependencies": [], + "system_calls": [], + "data_flows": [] + } + + if self.runtime_data is not None and not self.runtime_data.empty: + # Run forensics on runtime data + forensics = DataForensics() + report = forensics.analyze(self.runtime_data) + + # Extract operations from ghost log + patterns["actual_operations"] = [ + op.operation for op in report.ghost_log.operations + ] + + # Extract security concerns as errors + patterns["actual_errors"] = [ + concern['issue'] for concern in report.security_concerns + ] + + # Extract tech fingerprints as dependencies + patterns["actual_dependencies"] = [ + fp.technology for fp in report.fingerprints + ] + + return patterns + + def _compare_patterns(self, expected: Dict[str, Any], actual: Dict[str, Any]) -> Dict[str, Any]: + """Compare expected vs actual patterns""" + log("PatternComparator", "Comparing expected vs actual patterns", + impact="HIGH") + + comparison = { + "convergence": {}, + "divergence": {}, + "coverage_metrics": {}, + "anomalies": [] + } + + # Function coverage + expected_funcs = set(expected.get("expected_functions", [])) + actual_funcs = set(actual.get("actual_operations", [])) + + comparison["convergence"]["functions"] = list(expected_funcs & actual_funcs) + comparison["divergence"]["missing_functions"] = list(expected_funcs - actual_funcs) + 
comparison["divergence"]["unexpected_functions"] = list(actual_funcs - expected_funcs) + + # Dependency analysis + expected_deps = set(expected.get("expected_dependencies", [])) + actual_deps = set(actual.get("actual_dependencies", [])) + + comparison["convergence"]["dependencies"] = list(expected_deps & actual_deps) + comparison["divergence"]["missing_dependencies"] = list(expected_deps - actual_deps) + comparison["divergence"]["unexpected_dependencies"] = list(actual_deps - expected_deps) + + # Coverage metrics + comparison["coverage_metrics"] = { + "function_coverage": len(comparison["convergence"]["functions"]) / max(len(expected_funcs), 1), + "dependency_coverage": len(comparison["convergence"]["dependencies"]) / max(len(expected_deps), 1), + "implementation_fidelity": self._calculate_fidelity(expected, actual) + } + + # Detect anomalies + comparison["anomalies"] = self._detect_anomalies(expected, actual) + + return comparison + + def _calculate_fidelity(self, expected: Dict[str, Any], actual: Dict[str, Any]) -> float: + """Calculate implementation fidelity score""" + # Simple heuristic based on convergence + total_expected = len(expected.get("expected_functions", [])) + len(expected.get("expected_dependencies", [])) + total_converged = len(self._compare_patterns(expected, actual)["convergence"].get("functions", [])) + \ + len(self._compare_patterns(expected, actual)["convergence"].get("dependencies", [])) + + return total_converged / max(total_expected, 1) + + def _detect_anomalies(self, expected: Dict[str, Any], actual: Dict[str, Any]) -> List[Dict[str, Any]]: + """Detect system anomalies""" + anomalies = [] + + # Check for unexpected operations + unexpected_ops = set(actual.get("actual_operations", [])) - set(expected.get("expected_functions", [])) + if unexpected_ops: + anomalies.append({ + "type": "unexpected_operations", + "description": f"Found {len(unexpected_ops)} operations not in repository", + "items": list(unexpected_ops)[:5] + }) + + # Check for 
errors + if actual.get("actual_errors"): + anomalies.append({ + "type": "runtime_errors", + "description": f"Found {len(actual['actual_errors'])} errors in runtime", + "items": actual["actual_errors"][:3] + }) + + return anomalies + + def _generate_insights(self, comparison: Dict[str, Any]) -> Dict[str, Any]: + """Generate actionable insights from comparison""" + insights = { + "overall_score": 0.0, + "recommendations": [], + "risk_assessment": {}, + "architecture_validation": {} + } + + # Calculate overall score + coverage = comparison["coverage_metrics"] + insights["overall_score"] = ( + coverage.get("function_coverage", 0) * 0.4 + + coverage.get("dependency_coverage", 0) * 0.3 + + coverage.get("implementation_fidelity", 0) * 0.3 + ) + + # Generate recommendations + if coverage["function_coverage"] < 0.8: + insights["recommendations"].append( + "Consider implementing missing functions for better coverage" + ) + + if comparison["divergence"]["unexpected_dependencies"]: + insights["recommendations"].append( + "Review unexpected dependencies - may indicate hidden requirements" + ) + + # Risk assessment + insights["risk_assessment"] = { + "complexity_risk": "high" if coverage["implementation_fidelity"] < 0.5 else "low", + "maintenance_risk": "medium" if len(comparison["divergence"]["missing_functions"]) > 5 else "low", + "security_risk": "high" if any(a["type"] == "runtime_errors" for a in comparison["anomalies"]) else "low" + } + + return insights + + def generate_report(self) -> str: + """Generate comprehensive analysis report""" + if not self.analysis_results: + return "No analysis results available. Run analyze_complete_system() first." 
+ + results = self.analysis_results + + report = f""" +# Omnidirectional Engineering Analysis Report +Generated: {results['timestamp']} + +## Executive Summary +- Overall Implementation Fidelity: {results['insights']['overall_score']:.1%} +- Repository Files Analyzed: {results['repository']['summary'].get('total_files', 0)} +- Runtime Evidence Records: {results['runtime']['summary'].get('total_records', 0)} + +## Convergence Analysis ✅ +### Matching Elements +- Functions: {len(results['comparison']['convergence']['functions'])} +- Dependencies: {len(results['comparison']['convergence']['dependencies'])} + +## Divergence Analysis ⚠️ +### Missing from Runtime +- Functions: {len(results['comparison']['divergence']['missing_functions'])} +- Dependencies: {len(results['comparison']['divergence']['missing_dependencies'])} + +### Unexpected in Runtime +- Operations: {len(results['comparison']['divergence']['unexpected_functions'])} +- Dependencies: {len(results['comparison']['divergence']['unexpected_dependencies'])} + +## Risk Assessment +- Complexity Risk: {results['insights']['risk_assessment']['complexity_risk'].upper()} +- Maintenance Risk: {results['insights']['risk_assessment']['maintenance_risk'].upper()} +- Security Risk: {results['insights']['risk_assessment']['security_risk'].upper()} + +## Recommendations +{chr(10).join(f"- {r}" for r in results['insights']['recommendations'])} + +## Anomalies Detected +{chr(10).join(f"- {a['type']}: {a['description']}" for a in results['comparison']['anomalies'])} + +--- +*This analysis proves the connection between repository intent and runtime reality.* +""" + + return report + + +def analyze_omnidirectional(repo_source: str, + runtime_logs: Optional[List[str]] = None, + runtime_datasets: Optional[List[Any]] = None) -> Dict[str, Any]: + """ + Convenience function for complete omni-directional analysis + """ + analyzer = OmnidirectionalAnalyzer() + return analyzer.analyze_complete_system(repo_source, runtime_logs, 
runtime_datasets) diff --git a/cascade/system/repo_ingester.py b/cascade/system/repo_ingester.py new file mode 100644 index 0000000000000000000000000000000000000000..3eb619776de11cb218f2ea74ec071bc19fa75af5 --- /dev/null +++ b/cascade/system/repo_ingester.py @@ -0,0 +1,344 @@ +""" +CASCADE Repository Ingester +Convert entire repositories into analyzable datasets +For Omnidirectional Engineering - closing the causation loop +""" + +import os +import json +import hashlib +from pathlib import Path +from typing import Dict, List, Any, Tuple +import pandas as pd +from datetime import datetime + +# Git operations +try: + import git + GIT_AVAILABLE = True +except ImportError: + GIT_AVAILABLE = False + print("⚠️ GitPython not installed. Install with: pip install GitPython") + +# Code analysis +try: + import ast + AST_AVAILABLE = True +except ImportError: + AST_AVAILABLE = False + + +class RepoIngester: + """ + Convert repository into structured dataset for analysis + """ + + def __init__(self): + self.repo_data = { + "files": [], + "structure": {}, + "dependencies": {}, + "commits": [], + "metrics": {} + } + + def ingest_repo(self, repo_path: str, include_history: bool = False) -> pd.DataFrame: + """ + Ingest entire repository into structured dataset + + Args: + repo_path: Path to repository (local or remote URL) + include_history: Whether to analyze git history + + Returns: + DataFrame with repository content and metadata + """ + # Handle remote URLs + if repo_path.startswith(("http://", "https://", "git@")): + repo_path = self._clone_repo(repo_path) + + # Analyze repository + self._analyze_structure(repo_path) + self._extract_files(repo_path) + + if include_history and GIT_AVAILABLE: + self._analyze_history(repo_path) + + # Convert to dataset + df = self._create_dataset() + + return df, self._generate_summary() + + def _clone_repo(self, repo_url: str) -> str: + """Clone remote repository to temporary directory""" + import tempfile + + temp_dir = 
tempfile.mkdtemp(prefix="cascade_repo_") + + if GIT_AVAILABLE: + git.Repo.clone_from(repo_url, temp_dir) + + return temp_dir + + def _analyze_structure(self, repo_path: str): + """Analyze repository structure""" + repo_path = Path(repo_path) + + # Build directory tree + structure = {} + for root, dirs, files in os.walk(repo_path): + # Skip hidden directories and common build dirs + dirs[:] = [d for d in dirs if not d.startswith('.') and + d not in ['node_modules', '__pycache__', 'target', 'build']] + + rel_path = Path(root).relative_to(repo_path) + + for file in files: + if not file.startswith('.'): + file_path = rel_path / file + structure[str(file_path)] = { + "type": "file", + "size": os.path.getsize(os.path.join(root, file)), + "extension": Path(file).suffix.lower() + } + + self.repo_data["structure"] = structure + + def _extract_files(self, repo_path: str): + """Extract content from all files""" + repo_path = Path(repo_path) + + for file_path in self.repo_data["structure"].keys(): + full_path = repo_path / file_path + + try: + with open(full_path, 'r', encoding='utf-8', errors='ignore') as f: + content = f.read() + + file_info = { + "path": file_path, + "content": content, + "size": len(content), + "lines": len(content.splitlines()), + "language": self._detect_language(file_path), + "type": self._classify_file(file_path, content), + "hash": hashlib.md5(content.encode()).hexdigest()[:16] + } + + # Extract code-specific info + if file_info["language"] == "python" and AST_AVAILABLE: + file_info.update(self._analyze_python(content)) + + self.repo_data["files"].append(file_info) + + except Exception as e: + self.repo_data["files"].append({ + "path": file_path, + "content": "", + "error": str(e), + "type": "binary" + }) + + def _detect_language(self, file_path: str) -> str: + """Detect programming language from file extension""" + ext = Path(file_path).suffix.lower() + + language_map = { + ".py": "python", + ".js": "javascript", + ".ts": "typescript", + ".java": 
"java", + ".cpp": "cpp", + ".c": "c", + ".h": "c", + ".cs": "csharp", + ".go": "go", + ".rs": "rust", + ".php": "php", + ".rb": "ruby", + ".swift": "swift", + ".kt": "kotlin", + ".scala": "scala", + ".r": "r", + ".m": "matlab", + ".sh": "shell", + ".sql": "sql", + ".html": "html", + ".css": "css", + ".scss": "scss", + ".less": "less", + ".xml": "xml", + ".yaml": "yaml", + ".yml": "yaml", + ".json": "json", + ".md": "markdown", + ".txt": "text", + ".dockerfile": "dockerfile", + "dockerfile": "dockerfile" + } + + return language_map.get(ext, "unknown") + + def _classify_file(self, file_path: str, content: str) -> str: + """Classify file type based on path and content""" + path_lower = file_path.lower() + + # Configuration files + if any(x in path_lower for x in ["config", "settings", ".env", "ini", "toml", "yaml", "yml"]): + return "config" + + # Documentation + if path_lower.endswith((".md", ".rst", ".txt")): + return "documentation" + + # Tests + if "test" in path_lower or "spec" in path_lower: + return "test" + + # Dependencies + if any(x in path_lower for x in ["requirements", "package", "pipfile", "yarn", "pom.xml"]): + return "dependencies" + + # CI/CD + if any(x in path_lower for x in [".github", "gitlab", "jenkins", "dockerfile"]): + return "cicd" + + # Code + if self._detect_language(file_path) != "unknown": + return "code" + + return "other" + + def _analyze_python(self, content: str) -> Dict[str, Any]: + """Analyze Python code structure""" + try: + tree = ast.parse(content) + + functions = [] + classes = [] + imports = [] + + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef): + functions.append({ + "name": node.name, + "line": node.lineno, + "args": [arg.arg for arg in node.args.args] + }) + elif isinstance(node, ast.ClassDef): + classes.append({ + "name": node.name, + "line": node.lineno + }) + elif isinstance(node, ast.Import): + imports.extend([alias.name for alias in node.names]) + elif isinstance(node, ast.ImportFrom): + 
imports.append(f"from {node.module}") + + return { + "functions": functions, + "classes": classes, + "imports": imports, + "complexity": len(functions) + len(classes) + } + + except: + return {"functions": [], "classes": [], "imports": [], "complexity": 0} + + def _analyze_history(self, repo_path: str): + """Analyze git history for patterns""" + try: + repo = git.Repo(repo_path) + + commits = [] + for commit in repo.iter_commits(max_count=100): + commits.append({ + "hash": commit.hexsha[:8], + "message": commit.message.strip(), + "author": commit.author.name, + "date": datetime.fromtimestamp(commit.committed_date).isoformat(), + "files_changed": len(commit.stats.files), + "insertions": commit.stats.total["insertions"], + "deletions": commit.stats.total["deletions"] + }) + + self.repo_data["commits"] = commits + + except Exception as e: + print(f"Could not analyze git history: {e}") + + def _create_dataset(self) -> pd.DataFrame: + """Create structured dataset from repository""" + records = [] + + for file_info in self.repo_data["files"]: + # Split content into manageable chunks + content = file_info.get("content", "") + + # Create records for analysis + record = { + # File metadata + "file_path": file_info["path"], + "file_type": file_info["type"], + "language": file_info.get("language", "unknown"), + "file_size": file_info.get("size", 0), + "file_hash": file_info.get("hash", ""), + + # Content analysis + "content": content, + "line_count": file_info.get("lines", 0), + + # Code-specific metrics + "function_count": len(file_info.get("functions", [])), + "class_count": len(file_info.get("classes", [])), + "import_count": len(file_info.get("imports", [])), + "complexity": file_info.get("complexity", 0), + + # Timestamp + "ingestion_timestamp": datetime.now().isoformat(), + + # Source type + "source_type": "repository" + } + + records.append(record) + + return pd.DataFrame(records) + + def _generate_summary(self) -> Dict[str, Any]: + """Generate repository analysis 
summary""" + files = self.repo_data["files"] + + summary = { + "total_files": len(files), + "languages": {}, + "file_types": {}, + "total_lines": sum(f.get("lines", 0) for f in files), + "total_functions": sum(len(f.get("functions", [])) for f in files), + "total_classes": sum(len(f.get("classes", [])) for f in files), + "commits_analyzed": len(self.repo_data.get("commits", [])) + } + + # Count languages + for f in files: + lang = f.get("language", "unknown") + summary["languages"][lang] = summary["languages"].get(lang, 0) + 1 + ftype = f.get("type", "other") + summary["file_types"][ftype] = summary["file_types"].get(ftype, 0) + 1 + + return summary + + +def ingest_repository(repo_path: str, include_history: bool = False) -> Tuple[pd.DataFrame, Dict[str, Any]]: + """ + Convenience function to ingest repository + + Args: + repo_path: Path or URL to repository + include_history: Include git history analysis + + Returns: + Tuple of (DataFrame, summary) + """ + ingester = RepoIngester() + return ingester.ingest_repo(repo_path, include_history) diff --git a/cascade/system/universal_extractor.py b/cascade/system/universal_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..3fc77a951fd281eb3a2c86f259896ef7fb1406da --- /dev/null +++ b/cascade/system/universal_extractor.py @@ -0,0 +1,340 @@ +""" +CASCADE Universal File Extractor +Powered by Apache Tika - Professional document processing +Handles ANY file format with proper metadata and content extraction +""" + +import os +import json +import tempfile +from pathlib import Path +from typing import List, Dict, Any, Tuple, Optional +import pandas as pd +import hashlib +from datetime import datetime + +# Try to import Apache Tika (professional solution) +try: + from tika import parser + TIKA_AVAILABLE = True +except ImportError: + TIKA_AVAILABLE = False + print("⚠️ Apache Tika not installed. 
Install with: pip install tika") + +# Fallback extractors +try: + import fitz # PyMuPDF + PDF_AVAILABLE = True +except ImportError: + PDF_AVAILABLE = False + +try: + import pdfplumber + PDFPLUMBER_AVAILABLE = True +except ImportError: + PDFPLUMBER_AVAILABLE = False + +try: + from PyPDF2 import PdfReader + PYPDF2_AVAILABLE = True +except ImportError: + PYPDF2_AVAILABLE = False + +try: + import docx + DOCX_AVAILABLE = True +except ImportError: + DOCX_AVAILABLE = False + +try: + import openpyxl + XLSX_AVAILABLE = True +except ImportError: + XLSX_AVAILABLE = False + +try: + import pandas as pd + PANDAS_AVAILABLE = True +except ImportError: + PANDAS_AVAILABLE = False + + +class UniversalExtractor: + """ + Professional file extractor using Apache Tika + Can handle ANY file format known to man + """ + + def __init__(self): + self.session = None + if TIKA_AVAILABLE: + # Start Tika server if not running + parser.from_buffer('') + + def extract_file(self, file_path: str) -> Dict[str, Any]: + """ + Extract content and metadata from ANY file + + Returns: + Dict with: + - content: Full text content + - metadata: File metadata + - file_info: Basic file info + - error: Error message if any + """ + result = { + "content": "", + "metadata": {}, + "file_info": self._get_file_info(file_path), + "error": None + } + + try: + # Use Apache Tika if available (best option) + if TIKA_AVAILABLE: + parsed = parser.from_file(file_path, service_url='http://localhost:9998') + result["content"] = parsed.get("content", "") + result["metadata"] = parsed.get("metadata", {}) + + # Add Tika-specific metadata + if result["metadata"]: + result["metadata"]["extractor"] = "Apache Tika" + result["metadata"]["extraction_timestamp"] = datetime.now().isoformat() + + # Fallback to format-specific extractors + else: + result = self._fallback_extract(file_path, result) + + except Exception as e: + result["error"] = str(e) + # Try fallback if Tika fails + if TIKA_AVAILABLE: + result = 
self._fallback_extract(file_path, result) + + return result + + def _get_file_info(self, file_path: str) -> Dict[str, Any]: + """Get basic file information""" + path = Path(file_path) + + # Calculate file hash + hash_md5 = hashlib.md5() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + + return { + "name": path.name, + "extension": path.suffix.lower(), + "size": path.stat().st_size, + "hash_md5": hash_md5.hexdigest(), + "modified": datetime.fromtimestamp(path.stat().st_mtime).isoformat() + } + + def _fallback_extract(self, file_path: str, result: Dict[str, Any]) -> Dict[str, Any]: + """Fallback extraction without Tika""" + ext = Path(file_path).suffix.lower() + + # PDF files + if ext == ".pdf": + content = self._extract_pdf(file_path) + if content: + result["content"] = content + result["metadata"]["extractor"] = "PDF fallback" + + # Office documents + elif ext in [".docx", ".doc"]: + content = self._extract_docx(file_path) + if content: + result["content"] = content + result["metadata"]["extractor"] = "DOCX fallback" + + # Excel files + elif ext in [".xlsx", ".xls"]: + content = self._extract_excel(file_path) + if content: + result["content"] = content + result["metadata"]["extractor"] = "Excel fallback" + + # Images with OCR (if available) + elif ext in [".jpg", ".jpeg", ".png", ".tiff", ".bmp"]: + content = self._extract_image(file_path) + if content: + result["content"] = content + result["metadata"]["extractor"] = "Image OCR fallback" + + # Code files + elif ext in [".py", ".js", ".java", ".cpp", ".c", ".h", ".css", ".html", ".xml", ".json", ".yaml", ".yml"]: + with open(file_path, "r", encoding="utf-8", errors="ignore") as f: + result["content"] = f.read() + result["metadata"]["extractor"] = "Text reader" + + return result + + def _extract_pdf(self, file_path: str) -> Optional[str]: + """Extract text from PDF using multiple methods""" + content = "" + + # Try PyMuPDF first (best quality) + if 
PDF_AVAILABLE: + try: + doc = fitz.open(file_path) + for page in doc: + content += page.get_text() + "\n" + doc.close() + if content.strip(): + return content + except: + pass + + # Try pdfplumber + if PDFPLUMBER_AVAILABLE: + try: + import pdfplumber + with pdfplumber.open(file_path) as pdf: + for page in pdf.pages: + text = page.extract_text() or "" + content += text + "\n" + if content.strip(): + return content + except: + pass + + # Try PyPDF2 + if PYPDF2_AVAILABLE: + try: + reader = PdfReader(file_path) + for page in reader.pages: + text = page.extract_text() or "" + content += text + "\n" + if content.strip(): + return content + except: + pass + + return content if content.strip() else None + + def _extract_docx(self, file_path: str) -> Optional[str]: + """Extract text from DOCX""" + if DOCX_AVAILABLE: + try: + doc = docx.Document(file_path) + content = "" + for paragraph in doc.paragraphs: + content += paragraph.text + "\n" + return content if content.strip() else None + except: + pass + return None + + def _extract_excel(self, file_path: str) -> Optional[str]: + """Extract text from Excel""" + if XLSX_AVAILABLE and PANDAS_AVAILABLE: + try: + # Read all sheets + content = "" + excel_file = pd.ExcelFile(file_path) + for sheet_name in excel_file.sheet_names: + df = pd.read_excel(file_path, sheet_name=sheet_name) + content += f"\n=== Sheet: {sheet_name} ===\n" + content += df.to_string() + "\n" + return content if content.strip() else None + except: + pass + return None + + def _extract_image(self, file_path: str) -> Optional[str]: + """Extract text from image using OCR (if available)""" + # Try OCR if pytesseract is available + try: + import pytesseract + from PIL import Image + + image = Image.open(file_path) + text = pytesseract.image_to_string(image) + return text if text.strip() else None + except: + return None + + def process_folder(self, folder_files: List[Any]) -> Tuple[pd.DataFrame, Dict[str, Any]]: + """ + Process multiple files and create a unified 
dataset + + Args: + folder_files: List of uploaded file objects + + Returns: + Tuple of (DataFrame with all content, processing_summary) + """ + all_records = [] + file_summary = [] + + for file_obj in folder_files: + try: + # Extract from file + extracted = self.extract_file(file_obj.name) + + # Create record + record = { + "file_name": extracted["file_info"]["name"], + "file_extension": extracted["file_info"]["extension"], + "file_size": extracted["file_info"]["size"], + "file_hash": extracted["file_info"]["hash_md5"], + "content": extracted["content"], + "extractor": extracted["metadata"].get("extractor", "unknown"), + "extraction_timestamp": datetime.now().isoformat(), + "error": extracted["error"] + } + + # Add metadata as JSON + if extracted["metadata"]: + record["metadata"] = json.dumps(extracted["metadata"]) + + all_records.append(record) + + # Summary + file_summary.append({ + "file": extracted["file_info"]["name"], + "status": "success" if extracted["content"] else "failed", + "content_length": len(extracted["content"]), + "extractor": extracted["metadata"].get("extractor", "unknown"), + "error": extracted["error"] + }) + + except Exception as e: + file_summary.append({ + "file": getattr(file_obj, 'name', 'unknown'), + "status": "error", + "error": str(e) + }) + + # Create DataFrame + if all_records: + df = pd.DataFrame(all_records) + + summary = { + "total_files": len(folder_files), + "processed": len([s for s in file_summary if s["status"] == "success"]), + "failed": len([s for s in file_summary if s["status"] != "success"]), + "total_content_chars": df["content"].str.len().sum(), + "file_details": file_summary + } + + return df, summary + + return None, {"error": "No files processed", "details": file_summary} + + +# Convenience function +def extract_from_files(file_list: List[Any]) -> Tuple[pd.DataFrame, Dict[str, Any]]: + """ + Extract content from multiple files using the universal extractor + + Args: + file_list: List of file objects + + Returns: + 
Tuple of (DataFrame, summary) + """ + extractor = UniversalExtractor() + return extractor.process_folder(file_list) diff --git a/cascade/torch_hook.py b/cascade/torch_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..305a3638b824991d3037a1594b384d3e29080710 --- /dev/null +++ b/cascade/torch_hook.py @@ -0,0 +1,483 @@ +""" +Cascade PyTorch Hook - Deep Neural Network Instrumentation. + +This is the missing piece: direct integration with PyTorch training loops +to capture what stdout never shows: +- Per-layer gradient norms +- Weight statistics +- Activation patterns +- Attention maps +- Memory allocation + +Usage: + from cascade.torch_hook import CascadeHook + + model = YourModel() + hook = CascadeHook(model, monitor) + + # Training loop + for batch in dataloader: + loss = model(batch) + loss.backward() # Hook automatically captures gradients + optimizer.step() + + # Hook auto-logs per-layer stats to monitor +""" + +from typing import Dict, Any, Optional, List, Callable +from dataclasses import dataclass +import weakref + +try: + import torch + import torch.nn as nn + TORCH_AVAILABLE = True +except ImportError: + TORCH_AVAILABLE = False + torch = None + nn = None + + +@dataclass +class LayerStats: + """Statistics for a single layer.""" + name: str + param_count: int + grad_norm: Optional[float] = None + grad_mean: Optional[float] = None + grad_std: Optional[float] = None + weight_norm: Optional[float] = None + weight_mean: Optional[float] = None + weight_std: Optional[float] = None + activation_norm: Optional[float] = None + activation_mean: Optional[float] = None + + +class CascadeHook: + """ + PyTorch hook for deep instrumentation. 
+ + Captures per-layer metrics that stdout logging misses: + - Gradient flow through each layer + - Weight evolution + - Activation statistics + - Memory usage + + Example: + >>> from cascade import Monitor + >>> from cascade.torch_hook import CascadeHook + >>> + >>> monitor = Monitor() + >>> model = nn.Sequential(...) + >>> hook = CascadeHook(model, monitor) + >>> + >>> # Training happens... + >>> # Hook automatically captures: + >>> # - grad_norm/layer_0, grad_norm/layer_1, ... + >>> # - weight_norm/layer_0, ... + >>> # - activation_mean/layer_0, ... + """ + + def __init__( + self, + model: "nn.Module", + monitor: Optional[Any] = None, + track_gradients: bool = True, + track_weights: bool = True, + track_activations: bool = False, # Can be expensive + layer_filter: Optional[Callable[[str, "nn.Module"], bool]] = None, + ): + if not TORCH_AVAILABLE: + raise ImportError("PyTorch required for CascadeHook. pip install torch") + + self.model = model + self.monitor = monitor + self.track_gradients = track_gradients + self.track_weights = track_weights + self.track_activations = track_activations + self.layer_filter = layer_filter or self._default_filter + + self._handles: List[Any] = [] + self._layer_stats: Dict[str, LayerStats] = {} + self._step = 0 + + # Register hooks + self._register_hooks() + + def _default_filter(self, name: str, module: "nn.Module") -> bool: + """Default: track Linear, Conv, and Attention layers.""" + return isinstance(module, ( + nn.Linear, + nn.Conv1d, nn.Conv2d, nn.Conv3d, + nn.MultiheadAttention, + nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, + nn.Embedding, + )) + + def _register_hooks(self): + """Register forward and backward hooks on tracked layers.""" + for name, module in self.model.named_modules(): + if self.layer_filter(name, module): + # Count params + param_count = sum(p.numel() for p in module.parameters()) + self._layer_stats[name] = LayerStats(name=name, param_count=param_count) + + # Gradient hook + if self.track_gradients: + 
handle = module.register_full_backward_hook( + self._make_grad_hook(name) + ) + self._handles.append(handle) + + # Activation hook + if self.track_activations: + handle = module.register_forward_hook( + self._make_activation_hook(name) + ) + self._handles.append(handle) + + def _make_grad_hook(self, layer_name: str): + """Create gradient hook for a specific layer.""" + def hook(module, grad_input, grad_output): + stats = self._layer_stats[layer_name] + + # Get gradient from output + if grad_output and grad_output[0] is not None: + grad = grad_output[0] + if grad.numel() > 0: + stats.grad_norm = grad.norm().item() + stats.grad_mean = grad.mean().item() + stats.grad_std = grad.std().item() + + return hook + + def _make_activation_hook(self, layer_name: str): + """Create activation hook for a specific layer.""" + def hook(module, input, output): + stats = self._layer_stats[layer_name] + + if isinstance(output, torch.Tensor): + stats.activation_norm = output.norm().item() + stats.activation_mean = output.mean().item() + + return hook + + def capture_weights(self): + """Capture current weight statistics.""" + for name, module in self.model.named_modules(): + if name in self._layer_stats: + stats = self._layer_stats[name] + + # Get weight tensor + if hasattr(module, 'weight') and module.weight is not None: + w = module.weight.data + stats.weight_norm = w.norm().item() + stats.weight_mean = w.mean().item() + stats.weight_std = w.std().item() + + def step(self, extra_data: Optional[Dict[str, Any]] = None): + """ + Call after each training step to log metrics. + + Args: + extra_data: Additional data to include (loss, lr, etc.) 
+ """ + self._step += 1 + + if self.track_weights: + self.capture_weights() + + # Build event data + data = {"step": self._step} + + if extra_data: + data.update(extra_data) + + # Add per-layer stats + for layer_name, stats in self._layer_stats.items(): + prefix = layer_name.replace(".", "_") + + if stats.grad_norm is not None: + data[f"grad_norm/{prefix}"] = stats.grad_norm + if stats.grad_mean is not None: + data[f"grad_mean/{prefix}"] = stats.grad_mean + if stats.weight_norm is not None: + data[f"weight_norm/{prefix}"] = stats.weight_norm + if stats.activation_mean is not None: + data[f"activation_mean/{prefix}"] = stats.activation_mean + + # Aggregate metrics + grad_norms = [s.grad_norm for s in self._layer_stats.values() if s.grad_norm is not None] + if grad_norms: + data["grad_norm_min"] = min(grad_norms) + data["grad_norm_max"] = max(grad_norms) + data["grad_norm_mean"] = sum(grad_norms) / len(grad_norms) + + weight_norms = [s.weight_norm for s in self._layer_stats.values() if s.weight_norm is not None] + if weight_norms: + data["weight_norm_total"] = sum(weight_norms) + + # Log to monitor + if self.monitor: + self.monitor.observe(data, event_type="training_step", component="torch_hook") + + return data + + def get_layer_report(self) -> Dict[str, Dict[str, Any]]: + """Get current stats for all tracked layers.""" + return { + name: { + "param_count": stats.param_count, + "grad_norm": stats.grad_norm, + "weight_norm": stats.weight_norm, + "activation_mean": stats.activation_mean, + } + for name, stats in self._layer_stats.items() + } + + def detect_issues(self) -> List[str]: + """Quick check for common issues.""" + issues = [] + + for name, stats in self._layer_stats.items(): + # Vanishing gradients + if stats.grad_norm is not None and stats.grad_norm < 1e-7: + issues.append(f"Vanishing gradient in {name}: {stats.grad_norm:.2e}") + + # Exploding gradients + if stats.grad_norm is not None and stats.grad_norm > 100: + issues.append(f"Exploding gradient in 
{name}: {stats.grad_norm:.2f}") + + # Dead layer (no gradient flow) + if stats.grad_norm == 0: + issues.append(f"Dead layer (zero gradient): {name}") + + # Weight explosion + if stats.weight_norm is not None and stats.weight_norm > 1000: + issues.append(f"Large weights in {name}: {stats.weight_norm:.2f}") + + return issues + + def remove(self): + """Remove all hooks.""" + for handle in self._handles: + handle.remove() + self._handles.clear() + + def __del__(self): + self.remove() + + @property + def tracked_layers(self) -> List[str]: + """List of tracked layer names.""" + return list(self._layer_stats.keys()) + + @property + def total_params(self) -> int: + """Total parameters in tracked layers.""" + return sum(s.param_count for s in self._layer_stats.values()) + + # ========================================================================= + # BRANCHING: From observation to understanding to action + # ========================================================================= + + def trace_anomaly(self, metric_name: str = "loss") -> Dict[str, Any]: + """ + BACKWARD BRANCH: When something goes wrong, which layer caused it? + + Correlates metric anomaly with per-layer gradient behavior. + Returns the likely culprit layer(s). 
+ """ + if not self.monitor or not self.monitor.metrics: + return {"culprit": None, "reason": "No monitor data"} + + # Get anomalies from the metric + anomalies = self.monitor.metrics.recent_anomalies + if not anomalies: + return {"culprit": None, "reason": "No anomalies detected"} + + # Find layers with extreme gradients at the time of anomaly + suspects = [] + for name, stats in self._layer_stats.items(): + if stats.grad_norm is not None: + if stats.grad_norm < 1e-7: + suspects.append({"layer": name, "issue": "vanishing", "grad_norm": stats.grad_norm}) + elif stats.grad_norm > 50: + suspects.append({"layer": name, "issue": "exploding", "grad_norm": stats.grad_norm}) + + if suspects: + # Sort by severity + suspects.sort(key=lambda x: abs(x["grad_norm"]) if x["issue"] == "exploding" else -x["grad_norm"], reverse=True) + return { + "culprit": suspects[0]["layer"], + "issue": suspects[0]["issue"], + "all_suspects": suspects, + "recommendation": self._recommend_fix(suspects[0]) + } + + return {"culprit": None, "reason": "No layer anomalies found"} + + def _recommend_fix(self, suspect: Dict[str, Any]) -> str: + """Generate actionable recommendation.""" + if suspect["issue"] == "exploding": + return f"Gradient explosion in {suspect['layer']}. Try: lower LR, add gradient clipping, check for NaN inputs." + elif suspect["issue"] == "vanishing": + return f"Vanishing gradient in {suspect['layer']}. Try: residual connections, different activation, layer norm." + return "Unknown issue" + + def predict_failure(self, lookahead: int = 5) -> Dict[str, Any]: + """ + FORWARD BRANCH: Predict if training is about to fail. + + Uses gradient trends to predict explosion/vanishing before it happens. 
+ """ + if not self.monitor or not self.monitor.metrics: + return {"risk": "unknown", "reason": "No history"} + + warnings = [] + + for name, stats in self._layer_stats.items(): + # Check gradient history via monitor + grad_key = f"grad_norm/{name.replace('.', '_')}" + series = self.monitor.metrics.get_metric(grad_key) + + if series and series.count >= 5: + trend = series.trend() + roc = series.rate_of_change() + + if trend == "rising" and roc and roc > 0: + # Project forward + projected = series.current + (roc * lookahead) + if projected > 100: + warnings.append({ + "layer": name, + "prediction": "explosion", + "current": series.current, + "projected": projected, + "steps_until": int(100 / roc) if roc > 0 else None + }) + + elif trend == "falling" and series.current < 0.001: + warnings.append({ + "layer": name, + "prediction": "vanishing", + "current": series.current, + "trend": "falling" + }) + + if warnings: + return { + "risk": "high", + "warnings": warnings, + "action": "Consider intervention now" + } + + return {"risk": "low", "warnings": [], "action": "Continue monitoring"} + + def suggest_intervention(self) -> Optional[Dict[str, Any]]: + """ + FORWARD BRANCH: Suggest specific parameter changes. + + Based on current state, recommend concrete actions. 
+ """ + prediction = self.predict_failure() + + if prediction["risk"] != "high": + return None + + interventions = [] + + for warning in prediction["warnings"]: + if warning["prediction"] == "explosion": + interventions.append({ + "action": "reduce_lr", + "factor": 0.5, + "reason": f"Gradient explosion predicted in {warning['layer']}" + }) + interventions.append({ + "action": "add_grad_clip", + "value": 1.0, + "reason": "Prevent gradient explosion" + }) + + elif warning["prediction"] == "vanishing": + interventions.append({ + "action": "increase_lr", + "factor": 1.5, + "reason": f"Vanishing gradient in {warning['layer']}" + }) + + return { + "interventions": interventions, + "urgency": "high" if len(interventions) > 1 else "medium" + } + + def get_attention_pattern(self, layer_name: str) -> Optional[Dict[str, Any]]: + """ + DEEP BRANCH: Extract attention patterns (for transformer layers). + + Returns attention entropy, sparsity, positional bias. + """ + for name, module in self.model.named_modules(): + if name == layer_name and isinstance(module, nn.MultiheadAttention): + # Would need forward hook with attention weights + # This is a stub showing the branch exists + return { + "layer": layer_name, + "type": "attention", + "note": "Full implementation requires attention weight capture" + } + return None + + def find_dead_neurons(self, threshold: float = 0.01) -> List[Dict[str, Any]]: + """ + DEEP BRANCH: Find neurons that never activate. + + Dead neurons = wasted parameters = pruning candidates. + """ + dead = [] + + for name, stats in self._layer_stats.items(): + if stats.activation_mean is not None: + if abs(stats.activation_mean) < threshold: + dead.append({ + "layer": name, + "activation_mean": stats.activation_mean, + "recommendation": "Consider pruning or reinitializing" + }) + + return dead + + def branch_report(self) -> Dict[str, Any]: + """ + Full branch analysis: backward, forward, and deep. 
+ """ + return { + "backward": { + "anomaly_trace": self.trace_anomaly(), + }, + "forward": { + "failure_prediction": self.predict_failure(), + "suggested_interventions": self.suggest_intervention(), + }, + "deep": { + "dead_neurons": self.find_dead_neurons(), + "layer_health": self.detect_issues(), + }, + "meta": { + "tracked_layers": len(self._layer_stats), + "total_params": self.total_params, + "step": self._step, + } + } + + +# Convenience function +def instrument(model: "nn.Module", monitor=None, **kwargs) -> CascadeHook: + """ + Quick instrumentation of a PyTorch model. + + Usage: + hook = cascade.torch_hook.instrument(model, monitor) + """ + return CascadeHook(model, monitor, **kwargs) diff --git a/cascade/viz/__init__.py b/cascade/viz/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..463910463e1a79da95fa5ff02600ba6314ed87bf --- /dev/null +++ b/cascade/viz/__init__.py @@ -0,0 +1,32 @@ +""" +Cascade Tape Utilities - JSONL event storage and playback. + +This module provides utilities for working with tape files: +- load_tape_file: Load events from a JSONL tape file +- find_latest_tape: Find the most recent tape file in a directory +- list_tape_files: List all available tape files +- write_tape_event: Write an event to a tape file +- create_tape_path: Generate a timestamped tape file path +- PlaybackBuffer: Buffer for playback with timing control + +For visualization, use your preferred tools to consume the tape files +or connect to the event_queue in observe.py/listen.py. 
+""" + +from cascade.viz.tape import ( + load_tape_file, + find_latest_tape, + list_tape_files, + write_tape_event, + create_tape_path, + PlaybackBuffer, +) + +__all__ = [ + "load_tape_file", + "find_latest_tape", + "list_tape_files", + "write_tape_event", + "create_tape_path", + "PlaybackBuffer", +] diff --git a/cascade/viz/tape.py b/cascade/viz/tape.py new file mode 100644 index 0000000000000000000000000000000000000000..02514b38741e4b0bbc544d73c9b235cd735c0ab2 --- /dev/null +++ b/cascade/viz/tape.py @@ -0,0 +1,213 @@ +""" +CASCADE Tape Utilities - JSONL tape file handling. + +Tape files are the primary persistence format for CASCADE observations. +They use JSONL (JSON Lines) format for easy streaming and append operations. + +TAPE FILES: +- Model Observatory: logs/cascade_tape_*.jsonl +- Data Unity: logs/unity_tape_*.jsonl +- Provenance: logs/provenance_tape_*.jsonl + +Usage: + from cascade.viz.tape import load_tape_file, find_latest_tape, list_tape_files + + # Load a specific tape + events = load_tape_file("logs/cascade_tape_1234567890.jsonl") + + # Find the most recent tape + latest = find_latest_tape(tape_type="cascade") + + # List all available tapes + all_tapes = list_tape_files() +""" + +import json +from typing import List, Dict, Any, Optional +from dataclasses import dataclass, field +from pathlib import Path + + +def load_tape_file(tape_path: str) -> List[Dict[str, Any]]: + """ + Load events from a tape file (JSONL format). 
+ + Works with all tape types: + - Model Observatory tapes: logs/cascade_tape_*.jsonl + - Data Unity tapes: logs/unity_tape_*.jsonl + - Provenance tapes: logs/provenance_tape_*.jsonl + + Args: + tape_path: Path to the .jsonl tape file + + Returns: + List of event records in chronological order + """ + events = [] + with open(tape_path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + try: + events.append(json.loads(line)) + except json.JSONDecodeError: + pass # Skip malformed lines + return events + + +def write_tape_event(tape_path: str, event: Dict[str, Any]) -> None: + """ + Append a single event to a tape file. + + Args: + tape_path: Path to the .jsonl tape file + event: Event dict to append + """ + with open(tape_path, "a", encoding="utf-8") as f: + f.write(json.dumps(event) + "\n") + + +def find_latest_tape(log_dir: str = "./logs", tape_type: str = "cascade") -> Optional[str]: + """ + Find the most recent tape file of a given type. + + Args: + log_dir: Directory containing tape files + tape_type: "cascade" for Model Observatory, "unity" for Data Unity, + "provenance" for provenance tapes + + Returns: + Path to the latest tape file, or None if not found + """ + log_path = Path(log_dir) + if not log_path.exists(): + return None + + pattern = f"{tape_type}_tape_*.jsonl" + tapes = list(log_path.glob(pattern)) + + if not tapes: + return None + + # Sort by modification time (most recent first) + tapes.sort(key=lambda p: p.stat().st_mtime, reverse=True) + return str(tapes[0]) + + +def list_tape_files(log_dir: str = "./logs") -> Dict[str, List[str]]: + """ + List all available tape files by type. 
def create_tape_path(log_dir: str = "./logs", tape_type: str = "cascade", session_id: Optional[int] = None) -> str:
    """Build (and ensure the directory for) a new tape file path.

    Args:
        log_dir: Directory that will hold the tape file (created if missing).
        tape_type: Filename prefix ("cascade", "unity", "provenance").
        session_id: Session identifier; defaults to the current Unix time.

    Returns:
        Path string of the form ``<log_dir>/<tape_type>_tape_<session_id>.jsonl``.
    """
    import time

    directory = Path(log_dir)
    directory.mkdir(parents=True, exist_ok=True)

    sid = int(time.time()) if session_id is None else session_id
    return str(directory / f"{tape_type}_tape_{sid}.jsonl")
+ """ + events: List[PlaybackEvent] = field(default_factory=list) + current_index: int = 0 + is_complete: bool = False + + def add(self, event: PlaybackEvent): + """Add an event to the buffer.""" + self.events.append(event) + + def get_events_up_to(self, index: int) -> List[PlaybackEvent]: + """Get all events up to and including the given index.""" + return self.events[:index + 1] + + def __len__(self): + return len(self.events) + + def __iter__(self): + return iter(self.events) + + @classmethod + def from_tape(cls, tape_path: str) -> "PlaybackBuffer": + """ + Create a PlaybackBuffer from a tape file. + + Args: + tape_path: Path to JSONL tape file + + Returns: + PlaybackBuffer populated with events from the tape + """ + buffer = cls() + records = load_tape_file(tape_path) + + for record in records: + event_data = record.get("event", record) + buffer.add(PlaybackEvent( + timestamp=event_data.get("timestamp", 0), + event_type=event_data.get("event_type", "unknown"), + data=event_data, + )) + + buffer.is_complete = True + return buffer diff --git a/cascade/web3_pin.py b/cascade/web3_pin.py new file mode 100644 index 0000000000000000000000000000000000000000..69b3bb8d890a43b10f0ae17013ace2aacbc7776e --- /dev/null +++ b/cascade/web3_pin.py @@ -0,0 +1,88 @@ +""" +Pin lattice to web3.storage (Filecoin-backed permanence). 
def verify_availability(cid: str, timeout: int = 30) -> bool:
    """Check whether *cid* is reachable through at least one public IPFS gateway.

    Issues a HEAD request to each gateway in turn and reports success on
    the first HTTP 200 response.

    Args:
        cid: Content identifier to probe.
        timeout: Per-gateway request timeout in seconds.

    Returns:
        True if any gateway answered 200, False otherwise.
    """
    gateways = [
        f"https://w3s.link/ipfs/{cid}",
        f"https://ipfs.io/ipfs/{cid}",
        f"https://dweb.link/ipfs/{cid}",
    ]

    for gateway in gateways:
        try:
            resp = requests.head(gateway, timeout=timeout)
        # FIX: was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit; only network-level failures
        # should make us fall through to the next gateway.
        except requests.RequestException:
            continue
        # NOTE(review): HEAD without allow_redirects counts a gateway
        # 30x redirect as unavailable — confirm whether that's intended.
        if resp.status_code == 200:
            return True
    return False
+ results = pin_lattice(args.token) + + print(f"\n=== Pinned {len(results)} files ===\n") + + if args.verify: + print("Verifying availability (may take a minute)...\n") + for name, cid in results.items(): + available = verify_availability(cid) + status = "✓ LIVE" if available else "⏳ propagating" + print(f" {name}: {status}") + print(f" https://w3s.link/ipfs/{cid}") + + print("\n=== Layer 2 complete ===") diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..78b2575a425d1d2ba9eb588329c02efd59791697 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,63 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "cascade-lattice" +dynamic = ["version"] +description = "Universal AI provenance layer — cryptographic receipts for every call, with HOLD inference halt protocol" +readme = "README.md" +license = "MIT" +authors = [ + { name = "Jeff Towers" } +] +keywords = [ + "ai", "ml", "provenance", "observability", "llm", "tracing", + "cryptographic", "receipts", "monitoring", "hold-protocol" +] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: System :: Monitoring", +] +requires-python = ">=3.8" +dependencies = [ + "pyyaml>=6.0", + "requests>=2.28.0", +] + +[project.optional-dependencies] +torch = ["torch>=1.9.0"] +web3 = ["web3>=6.0.0"] +all = [ + "torch>=1.9.0", + "web3>=6.0.0", + "anthropic>=0.18.0", + "openai>=1.0.0", + "litellm>=1.0.0", + "huggingface-hub>=0.20.0", +] + +[project.urls] +Homepage = 
"https://github.com/Yufok1/cascade-lattice" +Documentation = "https://github.com/Yufok1/cascade-lattice#readme" +Repository = "https://github.com/Yufok1/cascade-lattice" +Issues = "https://github.com/Yufok1/cascade-lattice/issues" + +[project.scripts] +cascade = "cascade.cli_main:main" + +[tool.hatch.version] +path = "cascade/__init__.py" + +[tool.hatch.build.targets.wheel] +packages = ["cascade"]