# Source: Cascade / cascade/data/observer.py
# Uploaded by: tostido
# Commit message: Initial commit - cascade-lattice 0.5.4
# Commit: 77bcbf1
"""
Dataset Observer
The main interface for observing datasets.
Provides context managers for tracking ingest, transform, and consume operations.
"""
import hashlib
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Generator, List, Optional, Union
from .entities import (
DatasetEntity, Activity, Agent, Relationship, RelationType,
ActivityType, AgentType, create_system_agent, create_model_agent, create_user_agent
)
from .provenance import ProvenanceGraph
from .schema import SchemaObserver, DatasetSchema, hash_content
@dataclass
class ObservationContext:
    """
    Bookkeeping for a single in-flight observation.

    Holds the activity being observed and the observer that owns the
    provenance graph, and collects the dataset entities registered as
    inputs and outputs while the activity runs.
    """
    activity: Activity
    observer: "DatasetObserver"
    _inputs: List[DatasetEntity] = field(default_factory=list)
    _outputs: List[DatasetEntity] = field(default_factory=list)

    def input(self, dataset, name: str = None, **kwargs) -> DatasetEntity:
        """
        Declare a dataset that the activity reads.

        Args:
            dataset: Either the ID (``str``) of an entity already stored in
                the observer's graph, or a dataset object that is passed to
                ``observer.observe_dataset`` to create a new entity.
            name: Optional name forwarded to ``observe_dataset``
                (ignored when ``dataset`` is an entity ID).
            **kwargs: Extra keyword arguments forwarded to
                ``observe_dataset`` (ignored for entity IDs).

        Returns:
            The looked-up or newly created DatasetEntity.

        Raises:
            ValueError: If ``dataset`` is a string and the graph has no
                entity with that ID.
        """
        if isinstance(dataset, str):
            # A string is interpreted as the ID of an existing entity.
            source = self.observer.graph.get_entity(dataset)
            if not source:
                raise ValueError(f"Entity not found: {dataset}")
        else:
            # Anything else is observed to produce a fresh entity.
            source = self.observer.observe_dataset(dataset, name=name, **kwargs)

        # Registration is the same for both paths.
        self._inputs.append(source)
        self.activity.add_input(source.id)
        self.observer.graph.link_usage(self.activity.id, source.id)
        return source

    def output(self, dataset, name: str = None, **kwargs) -> DatasetEntity:
        """
        Declare a dataset that the activity produced.

        The dataset is always observed as a new entity. The entity is linked
        to the activity as generated by it, and linked as derived from every
        input registered on this context so far.

        Args:
            dataset: Dataset object passed to ``observer.observe_dataset``.
            name: Optional name forwarded to ``observe_dataset``.
            **kwargs: Extra keyword arguments forwarded to ``observe_dataset``.

        Returns:
            The newly created DatasetEntity.
        """
        produced = self.observer.observe_dataset(dataset, name=name, **kwargs)
        self._outputs.append(produced)
        self.activity.add_output(produced.id)

        graph = self.observer.graph
        # Generation edge: entity -> activity.
        graph.link_generation(produced.id, self.activity.id)
        # Derivation edges: entity -> each input seen so far.
        for upstream in self._inputs:
            graph.link_derivation(produced.id, upstream.id)
        return produced

    @property
    def inputs(self) -> List[DatasetEntity]:
        """Entities registered through ``input()``, in registration order."""
        return self._inputs

    @property
    def outputs(self) -> List[DatasetEntity]:
        """Entities registered through ``output()``, in registration order."""
        return self._outputs
class DatasetObserver:
    """
    Observer for dataset operations.

    Tracks:
    - Dataset loading (ingest)
    - Transformations (filter, map, join, etc.)
    - Consumption (training, inference)

    Observed datasets are stored as ``DatasetEntity`` objects and observed
    operations as ``Activity`` objects in ``self.graph``.

    Example:
        observer = DatasetObserver()

        with observer.observe_ingest("squad") as ctx:
            ds = load_dataset("squad")
            ctx.output(ds)

        with observer.observe_transform("filter_english") as ctx:
            ctx.input(ds)
            filtered = ds.filter(lambda x: x["lang"] == "en")
            ctx.output(filtered)

        chain = observer.export_provenance()
    """

    def __init__(
        self,
        name: str = "default",
        agent: Optional[Agent] = None,
    ):
        """
        Initialize observer.

        Args:
            name: Name for the provenance graph
            agent: Default agent for activities (defaults to graph's system agent)
        """
        self.graph = ProvenanceGraph(name=name)
        self.schema_observer = SchemaObserver()

        # Use provided agent or the graph's default system agent
        if agent:
            self._default_agent = agent
            self.graph.add_agent(agent)
        else:
            # Use the graph's already-created system agent
            self._default_agent = self.graph._system_agent

        # Counter feeding _next_id(); shared by entity and activity IDs
        self._counter = 0

    def _next_id(self, prefix: str) -> str:
        """
        Generate a unique ID of the form ``<prefix>:<epoch-ms>:<counter>``.

        The counter is incremented on every call and zero-padded to at least
        four digits.
        """
        self._counter += 1
        return f"{prefix}:{int(time.time() * 1000)}:{self._counter:04d}"

    # ═══════════════════════════════════════════════════════════════════════════
    # DATASET OBSERVATION
    # ═══════════════════════════════════════════════════════════════════════════

    def _infer_name(self, dataset) -> str:
        """
        Pick a name for a dataset when the caller supplied none.

        Tries ``dataset.info.dataset_name`` and then ``dataset.config_name``,
        skipping any that is missing, ``None`` or empty, and finally falls
        back to ``dataset_<next counter value>``.
        """
        info = getattr(dataset, 'info', None)
        from_info = getattr(info, 'dataset_name', None)
        if from_info:
            return from_info
        from_config = getattr(dataset, 'config_name', None)
        if from_config:
            return from_config
        return f"dataset_{self._counter + 1}"

    def observe_dataset(
        self,
        dataset,
        name: Optional[str] = None,
        source_type: Optional[str] = None,
        source_uri: Optional[str] = None,
        version: Optional[str] = None,
        license_id: Optional[str] = None,
        license_url: Optional[str] = None,
        **kwargs,
    ) -> DatasetEntity:
        """
        Observe a dataset and create an entity.

        Args:
            dataset: HuggingFace Dataset, DatasetDict, DataFrame, or dict
            name: Name for the entity; inferred from the dataset when omitted
            source_type: Type of source (hf_hub, local, etc.); inferred when omitted
            source_uri: URI of the source
            version: Version string
            license_id: SPDX license identifier (e.g., "MIT", "CC-BY-4.0");
                taken from ``dataset.info.license`` when omitted and available
            license_url: URL to the license text
            **kwargs: Additional attributes, stored alongside the schema in
                the entity's ``attributes`` dict

        Returns:
            DatasetEntity representing the dataset (already added to the graph)
        """
        # Infer name if not provided
        if name is None:
            name = self._infer_name(dataset)

        # Try to extract license from the dataset's info object, if any
        if license_id is None:
            info_license = getattr(getattr(dataset, 'info', None), 'license', None)
            if info_license:
                license_id = info_license

        # Observe schema
        schema = self._observe_schema(dataset)

        # Compute content hash
        content_hash = self._compute_content_hash(dataset)

        # Get record count and splits
        record_count, splits = self._get_counts(dataset)

        # Infer source
        if source_type is None:
            source_type = self._infer_source_type(dataset)

        # Create entity
        entity = DatasetEntity(
            id=self._next_id("entity"),
            name=name,
            content_hash=content_hash,
            schema_hash=schema.hash() if schema else None,
            version=version,
            source_type=source_type,
            source_uri=source_uri,
            license_id=license_id,
            license_url=license_url,
            record_count=record_count,
            splits=splits,
            attributes={
                "schema": schema.to_dict() if schema else None,
                **kwargs,
            },
        )

        # Add to graph
        self.graph.add_entity(entity)
        return entity

    def register_agent(
        self,
        name: str,
        agent_type: str = "software",
        version: Optional[str] = None,
    ) -> Agent:
        """
        Register a new agent in the provenance graph.

        Args:
            name: Name of the agent
            agent_type: Type of agent (software, model, person, etc.).
                Values that are not a valid ``AgentType`` fall back to
                ``AgentType.SOFTWARE``.
            version: Optional version string (not used for "person" agents)

        Returns:
            The created Agent
        """
        if agent_type == "model":
            agent = create_model_agent(name, version=version)
        elif agent_type == "system":
            agent = create_system_agent(name, version=version)
        elif agent_type == "person":
            agent = create_user_agent(name)
        else:
            # Default to software agent or generic
            try:
                type_enum = AgentType(agent_type)
            except ValueError:
                type_enum = AgentType.SOFTWARE
            agent = Agent(
                id=f"agent:{type_enum.value}:{name.replace(' ', '_').lower()}",
                agent_type=type_enum,
                name=name,
                version=version
            )
        self.graph.add_agent(agent)
        return agent

    def _observe_schema(self, dataset) -> Optional[DatasetSchema]:
        """
        Extract a schema from the dataset, or return None.

        Dispatches on duck-typed attributes: ``features`` (HuggingFace),
        ``dtypes`` + ``columns`` (pandas), or a dict whose values are all
        lists. Any exception raised during extraction is reported with a
        printed warning and turned into ``None`` so observation can continue.
        """
        try:
            # HuggingFace Dataset
            if hasattr(dataset, 'features'):
                return self.schema_observer.observe_hf_dataset(dataset)

            # Pandas DataFrame
            if hasattr(dataset, 'dtypes') and hasattr(dataset, 'columns'):
                return self.schema_observer.observe_pandas(dataset)

            # Dict
            if isinstance(dataset, dict):
                # Check if it's columnar (dict of lists)
                if all(isinstance(v, list) for v in dataset.values()):
                    return self.schema_observer.observe_dict(dataset)

            return None
        except Exception as e:
            # Don't fail observation if schema extraction fails
            print(f"Warning: Could not extract schema: {e}")
            return None

    def _compute_content_hash(self, dataset) -> str:
        """
        Compute content hash of dataset.

        If ``hash_content`` raises, a SHA-256 of the current timestamp is
        returned instead, so the fallback value is not reproducible.
        """
        try:
            return hash_content(dataset)
        except Exception:
            # Fallback to timestamp-based hash
            return hashlib.sha256(str(time.time()).encode()).hexdigest()

    def _get_counts(self, dataset) -> tuple:
        """
        Get record count and split counts.

        Returns:
            ``(record_count, splits)`` where ``splits`` maps split name to
            size (empty when the dataset has no splits) and ``record_count``
            is ``None`` when no size could be determined. Errors while
            measuring are swallowed; whatever was gathered so far is returned.
        """
        record_count = None
        splits = {}
        try:
            if hasattr(dataset, 'dtypes') and hasattr(dataset, 'columns'):
                # DataFrame-like: keys() yields column names, not splits, so
                # it must be handled before the mapping branch below or every
                # column would be counted as a split.
                record_count = len(dataset)
            # Mapping of split name -> sized dataset (e.g. DatasetDict)
            elif hasattr(dataset, 'keys') and hasattr(dataset, '__getitem__'):
                for split_name in dataset.keys():
                    split_ds = dataset[split_name]
                    if hasattr(split_ds, '__len__'):
                        splits[split_name] = len(split_ds)
                record_count = sum(splits.values()) if splits else None
            # Single dataset
            elif hasattr(dataset, '__len__'):
                record_count = len(dataset)
        except Exception:
            # Best effort: counts are optional metadata
            pass
        return record_count, splits

    def _infer_source_type(self, dataset) -> str:
        """
        Infer source type from dataset.

        Returns "hf_dataset" for objects with an ``_info`` attribute,
        "pandas" for objects with ``dtypes``, "dict" for dict instances and
        "unknown" otherwise (checked in that order).
        """
        # HuggingFace Dataset
        if hasattr(dataset, '_info'):
            return "hf_dataset"
        # Pandas
        if hasattr(dataset, 'dtypes'):
            return "pandas"
        # Dict
        if isinstance(dataset, dict):
            return "dict"
        return "unknown"

    # ═══════════════════════════════════════════════════════════════════════════
    # CONTEXT MANAGERS
    # ═══════════════════════════════════════════════════════════════════════════

    @contextmanager
    def _observe_activity(
        self,
        activity_type,
        name: str,
        agent: Optional[Agent],
        parameters: Dict[str, Any],
    ) -> Generator[ObservationContext, None, None]:
        """
        Shared body of all ``observe_*`` context managers.

        Creates and starts an activity attributed to ``agent`` (or the
        default agent), yields an ObservationContext for it, and on exit
        ends the activity, adds it to the graph and links it to its agent.
        The exit steps run in a ``finally`` block, so they also happen when
        the ``with`` body raises.
        """
        activity = Activity(
            id=self._next_id("activity"),
            activity_type=activity_type,
            name=name,
            agent_id=(agent or self._default_agent).id,
            parameters=parameters,
        )
        activity.start()
        ctx = ObservationContext(activity=activity, observer=self)
        try:
            yield ctx
        finally:
            activity.end()
            # The activity is added to the graph only after it has ended.
            self.graph.add_activity(activity)
            self.graph.link_association(activity.id, activity.agent_id)

    @contextmanager
    def observe_ingest(
        self,
        name: str,
        source_uri: Optional[str] = None,
        agent: Optional[Agent] = None,
        **kwargs,
    ) -> Generator[ObservationContext, None, None]:
        """
        Observe a dataset ingest operation.

        Args:
            name: Name of the ingest operation
            source_uri: URI of the data source
            agent: Agent performing the ingest
            **kwargs: Additional activity parameters

        Yields:
            ObservationContext for registering inputs/outputs

        Example:
            with observer.observe_ingest("load_squad", source_uri="hf://squad") as ctx:
                ds = load_dataset("squad")
                ctx.output(ds, name="squad")
        """
        parameters = {"source_uri": source_uri, **kwargs}
        with self._observe_activity(ActivityType.INGEST, name, agent, parameters) as ctx:
            yield ctx

    @contextmanager
    def observe_transform(
        self,
        name: str,
        transform_type: Optional[str] = None,
        agent: Optional[Agent] = None,
        **kwargs,
    ) -> Generator[ObservationContext, None, None]:
        """
        Observe a dataset transformation.

        Args:
            name: Name of the transform
            transform_type: Type of transform (filter, map, join, etc.)
            agent: Agent performing the transform
            **kwargs: Additional activity parameters

        Yields:
            ObservationContext for registering inputs/outputs

        Example:
            with observer.observe_transform("filter_english") as ctx:
                ctx.input(ds)
                filtered = ds.filter(lambda x: x["lang"] == "en")
                ctx.output(filtered)
        """
        parameters = {"transform_type": transform_type, **kwargs}
        with self._observe_activity(ActivityType.TRANSFORM, name, agent, parameters) as ctx:
            yield ctx

    @contextmanager
    def observe_consume(
        self,
        name: str,
        model_id: Optional[str] = None,
        consume_type: str = "train",
        agent: Optional[Agent] = None,
        **kwargs,
    ) -> Generator[ObservationContext, None, None]:
        """
        Observe dataset consumption (training, inference).

        Args:
            name: Name of the consumption operation
            model_id: ID of the model consuming the data. When given without
                ``agent``, a model agent is created for it and added to the graph.
            consume_type: Type of consumption (train, evaluate, inference);
                unrecognised values are recorded as a TRAIN activity
            agent: Agent performing the consumption
            **kwargs: Additional activity parameters

        Yields:
            ObservationContext for registering inputs/outputs

        Example:
            with observer.observe_consume("train_qa_model", model_id="bert-base") as ctx:
                ctx.input(train_ds)
                model = train(train_ds)
                # Model provenance now links to data provenance!
        """
        # Create model agent if model_id provided
        if model_id and agent is None:
            agent = create_model_agent(model_id)
            self.graph.add_agent(agent)

        activity_type = {
            "train": ActivityType.TRAIN,
            "evaluate": ActivityType.EVALUATE,
            "inference": ActivityType.INFERENCE,
        }.get(consume_type, ActivityType.TRAIN)

        parameters = {"model_id": model_id, "consume_type": consume_type, **kwargs}
        with self._observe_activity(activity_type, name, agent, parameters) as ctx:
            yield ctx

    @contextmanager
    def observe_entity_resolution(
        self,
        name: str,
        model_id: Optional[str] = None,
        threshold: Optional[float] = None,
        agent: Optional[Agent] = None,
        **kwargs,
    ) -> Generator[ObservationContext, None, None]:
        """
        Observe entity resolution / data unity operation.

        Args:
            name: Name of the operation
            model_id: Embedding model used. When given without ``agent``, a
                model agent is created for it and added to the graph.
            threshold: Similarity threshold
            agent: Agent performing the operation
            **kwargs: Additional parameters

        Yields:
            ObservationContext for registering inputs/outputs

        Example:
            with observer.observe_entity_resolution("match_patients_claims") as ctx:
                ctx.input(patients_ds)
                ctx.input(claims_ds)
                unified = run_unity(patients_ds, claims_ds)
                ctx.output(unified)
        """
        if model_id and agent is None:
            agent = create_model_agent(model_id)
            self.graph.add_agent(agent)

        parameters = {
            "model_id": model_id,
            "threshold": threshold,
            **kwargs,
        }
        with self._observe_activity(
            ActivityType.ENTITY_RESOLUTION, name, agent, parameters
        ) as ctx:
            yield ctx

    # ═══════════════════════════════════════════════════════════════════════════
    # EXPORT
    # ═══════════════════════════════════════════════════════════════════════════

    def export_provenance(self) -> ProvenanceGraph:
        """Export the provenance graph (the live object, not a copy)."""
        return self.graph

    def to_dict(self) -> Dict[str, Any]:
        """Export observation state (graph and ID counter) to dictionary."""
        return {
            "graph": self.graph.to_dict(),
            "counter": self._counter,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "DatasetObserver":
        """
        Load observer from dictionary.

        Args:
            data: Dictionary with a "graph" entry (as produced by
                ``to_dict``) and an optional "counter" entry (default 0).

        Returns:
            A new observer whose graph and counter come from ``data``.
        """
        observer = cls()
        observer.graph = ProvenanceGraph.from_dict(data["graph"])
        observer._counter = data.get("counter", 0)

        # cls() bound the default agent to the system agent of the graph that
        # was just replaced; re-point it at the restored graph's system agent
        # when that graph exposes one.
        restored_agent = getattr(observer.graph, "_system_agent", None)
        if restored_agent is not None:
            observer._default_agent = restored_agent
        return observer

    # ═══════════════════════════════════════════════════════════════════════════
    # STATISTICS
    # ═══════════════════════════════════════════════════════════════════════════

    @property
    def stats(self) -> Dict[str, Any]:
        """Get observer statistics: the graph's ``stats`` and ``root_hash``."""
        return {
            "graph": self.graph.stats,
            "root_hash": self.graph.root_hash,
        }

    # ═══════════════════════════════════════════════════════════════════════════
    # LICENSE TRACKING
    # ═══════════════════════════════════════════════════════════════════════════

    def check_license_compatibility(
        self,
        entity_ids: List[str],
        target_license: Optional[str] = None,
    ):
        """
        Check license compatibility for deriving from entities.

        Entity IDs that are not in the graph are skipped; entities without a
        license are passed on with the license "unknown".

        Args:
            entity_ids: List of source entity IDs
            target_license: Intended SPDX license for derived work

        Returns:
            LicenseCompatibility result

        Example:
            result = observer.check_license_compatibility(
                ["entity:123", "entity:456"],
                target_license="MIT"
            )
            if not result.compatible:
                print(f"Issues: {result.issues}")
        """
        from .license import check_license_compatibility

        sources = []
        for entity_id in entity_ids:
            entity = self.graph.get_entity(entity_id)
            if entity:
                sources.append((entity_id, entity.license_id or "unknown"))

        return check_license_compatibility(sources, target_license)

    def get_derived_license(self, entity_ids: List[str]):
        """
        Get the appropriate license for a work derived from entities.

        Args:
            entity_ids: List of source entity IDs

        Returns:
            SPDXLicense for the derived work, or None when none of the
            entities is found with a license set
        """
        from .license import get_derived_license

        licenses = []
        for entity_id in entity_ids:
            entity = self.graph.get_entity(entity_id)
            if entity and entity.license_id:
                licenses.append(entity.license_id)

        return get_derived_license(licenses) if licenses else None

    def generate_attribution(self, entity_ids: Optional[List[str]] = None) -> str:
        """
        Generate attribution text for entities.

        Args:
            entity_ids: List of entity IDs (defaults to all entities).
                IDs that are not in the graph are skipped.

        Returns:
            Markdown attribution text
        """
        from .license import LicenseAnalyzer

        analyzer = LicenseAnalyzer()

        if entity_ids is None:
            entities = self.graph.list_entities()
        else:
            # Look each ID up once and keep only the ones that resolve
            looked_up = (self.graph.get_entity(eid) for eid in entity_ids)
            entities = [entity for entity in looked_up if entity]

        sources = [
            (e.id, e.license_id or "unknown", e.name)
            for e in entities
        ]
        return analyzer.generate_attribution(sources)

    def __repr__(self) -> str:
        return f"DatasetObserver({self.graph})"