| """ | |
| Dataset Observer | |
| The main interface for observing datasets. | |
| Provides context managers for tracking ingest, transform, and consume operations. | |
| """ | |
| import hashlib | |
| import time | |
| from contextlib import contextmanager | |
| from dataclasses import dataclass, field | |
| from typing import Any, Callable, Dict, Generator, List, Optional, Union | |
| from .entities import ( | |
| DatasetEntity, Activity, Agent, Relationship, RelationType, | |
| ActivityType, AgentType, create_system_agent, create_model_agent, create_user_agent | |
| ) | |
| from .provenance import ProvenanceGraph | |
| from .schema import SchemaObserver, DatasetSchema, hash_content | |
@dataclass
class ObservationContext:
    """
    Context for an ongoing observation.

    Used within context managers to track inputs/outputs.
    """
    activity: Activity
    observer: "DatasetObserver"
    _inputs: List[DatasetEntity] = field(default_factory=list)
    _outputs: List[DatasetEntity] = field(default_factory=list)

    def input(self, dataset, name: str = None, **kwargs) -> DatasetEntity:
        """
        Register an input dataset.

        Args:
            dataset: HuggingFace Dataset, DatasetDict, or entity ID
            name: Optional name override
            **kwargs: Additional entity attributes

        Returns:
            The created or retrieved DatasetEntity
        """
        # If string, assume it's an existing entity ID
        if isinstance(dataset, str):
            entity = self.observer.graph.get_entity(dataset)
            if entity:
                self._inputs.append(entity)
                self.activity.add_input(entity.id)
                self.observer.graph.link_usage(self.activity.id, entity.id)
                return entity
            else:
                raise ValueError(f"Entity not found: {dataset}")

        # Otherwise, observe the dataset
        entity = self.observer.observe_dataset(dataset, name=name, **kwargs)
        self._inputs.append(entity)
        self.activity.add_input(entity.id)
        self.observer.graph.link_usage(self.activity.id, entity.id)
        return entity

    def output(self, dataset, name: str = None, **kwargs) -> DatasetEntity:
        """
        Register an output dataset.

        Args:
            dataset: HuggingFace Dataset, DatasetDict, or dict
            name: Optional name override
            **kwargs: Additional entity attributes

        Returns:
            The created DatasetEntity
        """
        entity = self.observer.observe_dataset(dataset, name=name, **kwargs)
        self._outputs.append(entity)
        self.activity.add_output(entity.id)
        # Link generation
        self.observer.graph.link_generation(entity.id, self.activity.id)
        # Link derivation from all inputs
        for input_entity in self._inputs:
            self.observer.graph.link_derivation(entity.id, input_entity.id)
        return entity

    def inputs(self) -> List[DatasetEntity]:
        return self._inputs

    def outputs(self) -> List[DatasetEntity]:
        return self._outputs

class DatasetObserver:
    """
    Observer for dataset operations.

    Tracks:
        - Dataset loading (ingest)
        - Transformations (filter, map, join, etc.)
        - Consumption (training, inference)

    Example:
        observer = DatasetObserver()

        with observer.observe_ingest("squad") as ctx:
            ds = load_dataset("squad")
            ctx.output(ds)

        with observer.observe_transform("filter_english") as ctx:
            ctx.input(ds)
            filtered = ds.filter(lambda x: x["lang"] == "en")
            ctx.output(filtered)

        chain = observer.export_provenance()
    """

    def __init__(
        self,
        name: str = "default",
        agent: Agent = None,
    ):
        """
        Initialize observer.

        Args:
            name: Name for the provenance graph
            agent: Default agent for activities (defaults to the graph's system agent)
        """
        self.graph = ProvenanceGraph(name=name)
        self.schema_observer = SchemaObserver()

        # Use the provided agent or the graph's default system agent
        if agent:
            self._default_agent = agent
            self.graph.add_agent(agent)
        else:
            # Use the graph's already-created system agent
            self._default_agent = self.graph._system_agent

        # Entity counter for unique IDs
        self._counter = 0

    def _next_id(self, prefix: str) -> str:
        """Generate a unique ID."""
        self._counter += 1
        return f"{prefix}:{int(time.time() * 1000)}:{self._counter:04d}"
    # ───────────────────────────────────────────────────────────────────────
    # DATASET OBSERVATION
    # ───────────────────────────────────────────────────────────────────────

    def observe_dataset(
        self,
        dataset,
        name: str = None,
        source_type: str = None,
        source_uri: str = None,
        version: str = None,
        license_id: str = None,
        license_url: str = None,
        **kwargs,
    ) -> DatasetEntity:
        """
        Observe a dataset and create an entity.

        Args:
            dataset: HuggingFace Dataset, DatasetDict, DataFrame, or dict
            name: Name for the entity
            source_type: Type of source (hf_hub, local, etc.)
            source_uri: URI of the source
            version: Version string
            license_id: SPDX license identifier (e.g., "MIT", "CC-BY-4.0")
            license_url: URL to the license text
            **kwargs: Additional attributes

        Returns:
            DatasetEntity representing the dataset
        """
        # Infer name if not provided
        if name is None:
            if hasattr(dataset, 'info') and hasattr(dataset.info, 'dataset_name'):
                name = dataset.info.dataset_name
            elif hasattr(dataset, 'config_name'):
                name = dataset.config_name
            else:
                name = f"dataset_{self._counter + 1}"

        # Try to extract license from HuggingFace dataset info
        if license_id is None and hasattr(dataset, 'info'):
            info = dataset.info
            if hasattr(info, 'license') and info.license:
                license_id = info.license

        # Observe schema
        schema = self._observe_schema(dataset)

        # Compute content hash
        content_hash = self._compute_content_hash(dataset)

        # Get record count and splits
        record_count, splits = self._get_counts(dataset)

        # Infer source
        if source_type is None:
            source_type = self._infer_source_type(dataset)

        # Create entity
        entity = DatasetEntity(
            id=self._next_id("entity"),
            name=name,
            content_hash=content_hash,
            schema_hash=schema.hash() if schema else None,
            version=version,
            source_type=source_type,
            source_uri=source_uri,
            license_id=license_id,
            license_url=license_url,
            record_count=record_count,
            splits=splits,
            attributes={
                "schema": schema.to_dict() if schema else None,
                **kwargs,
            },
        )

        # Add to graph
        self.graph.add_entity(entity)
        return entity
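
    # Minimal usage sketch for observe_dataset (illustrative only; uses the
    # in-memory dict path so no external libraries are needed, and the names
    # "toy_corpus" / "CC-BY-4.0" are arbitrary examples):
    #
    #     observer = DatasetObserver()
    #     entity = observer.observe_dataset(
    #         {"text": ["a", "b"], "label": [0, 1]},
    #         name="toy_corpus",
    #         license_id="CC-BY-4.0",
    #     )
    #     # entity.content_hash and entity.schema_hash identify this snapshot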

    def register_agent(self, name: str, agent_type: str = "software", version: str = None) -> Agent:
        """
        Register a new agent in the provenance graph.

        Args:
            name: Name of the agent
            agent_type: Type of agent (software, model, person, etc.)
            version: Optional version string

        Returns:
            The created Agent
        """
        if agent_type == "model":
            agent = create_model_agent(name, version=version)
        elif agent_type == "system":
            agent = create_system_agent(name, version=version)
        elif agent_type == "person":
            agent = create_user_agent(name)
        else:
            # Default to software agent or generic
            try:
                type_enum = AgentType(agent_type)
            except ValueError:
                type_enum = AgentType.SOFTWARE
            agent = Agent(
                id=f"agent:{type_enum.value}:{name.replace(' ', '_').lower()}",
                agent_type=type_enum,
                name=name,
                version=version,
            )
        self.graph.add_agent(agent)
        return agent

    def _observe_schema(self, dataset) -> Optional[DatasetSchema]:
        """Extract schema from dataset."""
        try:
            # HuggingFace Dataset
            if hasattr(dataset, 'features'):
                return self.schema_observer.observe_hf_dataset(dataset)
            # Pandas DataFrame
            if hasattr(dataset, 'dtypes') and hasattr(dataset, 'columns'):
                return self.schema_observer.observe_pandas(dataset)
            # Dict
            if isinstance(dataset, dict):
                # Check if it's columnar (dict of lists)
                if all(isinstance(v, list) for v in dataset.values()):
                    return self.schema_observer.observe_dict(dataset)
            return None
        except Exception as e:
            # Don't fail observation if schema extraction fails
            print(f"Warning: Could not extract schema: {e}")
            return None

    def _compute_content_hash(self, dataset) -> str:
        """Compute content hash of dataset."""
        try:
            return hash_content(dataset)
        except Exception:
            # Fallback to timestamp-based hash
            return hashlib.sha256(str(time.time()).encode()).hexdigest()

    def _get_counts(self, dataset) -> tuple:
        """Get record count and split counts."""
        record_count = None
        splits = {}
        try:
            # HuggingFace DatasetDict
            if hasattr(dataset, 'keys') and hasattr(dataset, '__getitem__'):
                for split_name in dataset.keys():
                    split_ds = dataset[split_name]
                    if hasattr(split_ds, '__len__'):
                        splits[split_name] = len(split_ds)
                record_count = sum(splits.values()) if splits else None
            # Single dataset
            elif hasattr(dataset, '__len__'):
                record_count = len(dataset)
        except Exception:
            pass
        return record_count, splits

    def _infer_source_type(self, dataset) -> str:
        """Infer source type from dataset."""
        # HuggingFace Dataset
        if hasattr(dataset, '_info'):
            return "hf_dataset"
        # Pandas
        if hasattr(dataset, 'dtypes'):
            return "pandas"
        # Dict
        if isinstance(dataset, dict):
            return "dict"
        return "unknown"

    # ───────────────────────────────────────────────────────────────────────
    # CONTEXT MANAGERS
    # ───────────────────────────────────────────────────────────────────────

    @contextmanager
    def observe_ingest(
        self,
        name: str,
        source_uri: str = None,
        agent: Agent = None,
        **kwargs,
    ) -> Generator[ObservationContext, None, None]:
        """
        Observe a dataset ingest operation.

        Args:
            name: Name of the ingest operation
            source_uri: URI of the data source
            agent: Agent performing the ingest
            **kwargs: Additional activity parameters

        Yields:
            ObservationContext for registering inputs/outputs

        Example:
            with observer.observe_ingest("load_squad", source_uri="hf://squad") as ctx:
                ds = load_dataset("squad")
                ctx.output(ds, name="squad")
        """
        activity = Activity(
            id=self._next_id("activity"),
            activity_type=ActivityType.INGEST,
            name=name,
            agent_id=(agent or self._default_agent).id,
            parameters={"source_uri": source_uri, **kwargs},
        )
        activity.start()
        ctx = ObservationContext(activity=activity, observer=self)
        try:
            yield ctx
        finally:
            activity.end()
            self.graph.add_activity(activity)
            self.graph.link_association(activity.id, activity.agent_id)

    @contextmanager
    def observe_transform(
        self,
        name: str,
        transform_type: str = None,
        agent: Agent = None,
        **kwargs,
    ) -> Generator[ObservationContext, None, None]:
        """
        Observe a dataset transformation.

        Args:
            name: Name of the transform
            transform_type: Type of transform (filter, map, join, etc.)
            agent: Agent performing the transform
            **kwargs: Additional activity parameters

        Yields:
            ObservationContext for registering inputs/outputs

        Example:
            with observer.observe_transform("filter_english") as ctx:
                ctx.input(ds)
                filtered = ds.filter(lambda x: x["lang"] == "en")
                ctx.output(filtered)
        """
        activity = Activity(
            id=self._next_id("activity"),
            activity_type=ActivityType.TRANSFORM,
            name=name,
            agent_id=(agent or self._default_agent).id,
            parameters={"transform_type": transform_type, **kwargs},
        )
        activity.start()
        ctx = ObservationContext(activity=activity, observer=self)
        try:
            yield ctx
        finally:
            activity.end()
            self.graph.add_activity(activity)
            self.graph.link_association(activity.id, activity.agent_id)

    @contextmanager
    def observe_consume(
        self,
        name: str,
        model_id: str = None,
        consume_type: str = "train",
        agent: Agent = None,
        **kwargs,
    ) -> Generator[ObservationContext, None, None]:
        """
        Observe dataset consumption (training, inference).

        Args:
            name: Name of the consumption operation
            model_id: ID of the model consuming the data
            consume_type: Type of consumption (train, evaluate, inference)
            agent: Agent performing the consumption
            **kwargs: Additional activity parameters

        Yields:
            ObservationContext for registering inputs/outputs

        Example:
            with observer.observe_consume("train_qa_model", model_id="bert-base") as ctx:
                ctx.input(train_ds)
                model = train(train_ds)
                # Model provenance now links to data provenance!
        """
        # Create a model agent if a model_id is provided
        if model_id and agent is None:
            agent = create_model_agent(model_id)
            self.graph.add_agent(agent)

        activity_type = {
            "train": ActivityType.TRAIN,
            "evaluate": ActivityType.EVALUATE,
            "inference": ActivityType.INFERENCE,
        }.get(consume_type, ActivityType.TRAIN)

        activity = Activity(
            id=self._next_id("activity"),
            activity_type=activity_type,
            name=name,
            agent_id=(agent or self._default_agent).id,
            parameters={"model_id": model_id, "consume_type": consume_type, **kwargs},
        )
        activity.start()
        ctx = ObservationContext(activity=activity, observer=self)
        try:
            yield ctx
        finally:
            activity.end()
            self.graph.add_activity(activity)
            self.graph.link_association(activity.id, activity.agent_id)

    @contextmanager
    def observe_entity_resolution(
        self,
        name: str,
        model_id: str = None,
        threshold: float = None,
        agent: Agent = None,
        **kwargs,
    ) -> Generator[ObservationContext, None, None]:
        """
        Observe an entity resolution / data unity operation.

        Args:
            name: Name of the operation
            model_id: Embedding model used
            threshold: Similarity threshold
            agent: Agent performing the operation
            **kwargs: Additional parameters

        Yields:
            ObservationContext for registering inputs/outputs

        Example:
            with observer.observe_entity_resolution("match_patients_claims") as ctx:
                ctx.input(patients_ds)
                ctx.input(claims_ds)
                unified = run_unity(patients_ds, claims_ds)
                ctx.output(unified)
        """
        if model_id and agent is None:
            agent = create_model_agent(model_id)
            self.graph.add_agent(agent)

        activity = Activity(
            id=self._next_id("activity"),
            activity_type=ActivityType.ENTITY_RESOLUTION,
            name=name,
            agent_id=(agent or self._default_agent).id,
            parameters={
                "model_id": model_id,
                "threshold": threshold,
                **kwargs,
            },
        )
        activity.start()
        ctx = ObservationContext(activity=activity, observer=self)
        try:
            yield ctx
        finally:
            activity.end()
            self.graph.add_activity(activity)
            self.graph.link_association(activity.id, activity.agent_id)

    # ───────────────────────────────────────────────────────────────────────
    # EXPORT
    # ───────────────────────────────────────────────────────────────────────

    def export_provenance(self) -> ProvenanceGraph:
        """Export the provenance graph."""
        return self.graph

    def to_dict(self) -> Dict[str, Any]:
        """Export observation state to a dictionary."""
        return {
            "graph": self.graph.to_dict(),
            "counter": self._counter,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "DatasetObserver":
        """Load an observer from a dictionary."""
        observer = cls()
        observer.graph = ProvenanceGraph.from_dict(data["graph"])
        observer._counter = data.get("counter", 0)
        return observer
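
    # Persistence round-trip sketch (assumes ProvenanceGraph.to_dict/from_dict
    # are symmetric, as the code above implies):
    #
    #     state = observer.to_dict()
    #     restored = DatasetObserver.from_dict(state)
    #     # restored.graph now holds the same entities/activities as observer.graph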

    # ───────────────────────────────────────────────────────────────────────
    # STATISTICS
    # ───────────────────────────────────────────────────────────────────────

    def stats(self) -> Dict[str, Any]:
        """Get observer statistics."""
        return {
            "graph": self.graph.stats,
            "root_hash": self.graph.root_hash,
        }

    # ───────────────────────────────────────────────────────────────────────
    # LICENSE TRACKING
    # ───────────────────────────────────────────────────────────────────────

    def check_license_compatibility(
        self,
        entity_ids: List[str],
        target_license: str = None,
    ):
        """
        Check license compatibility for deriving from entities.

        Args:
            entity_ids: List of source entity IDs
            target_license: Intended SPDX license for the derived work

        Returns:
            LicenseCompatibility result

        Example:
            result = observer.check_license_compatibility(
                ["entity:123", "entity:456"],
                target_license="MIT",
            )
            if not result.compatible:
                print(f"Issues: {result.issues}")
        """
        from .license import check_license_compatibility

        sources = []
        for entity_id in entity_ids:
            entity = self.graph.get_entity(entity_id)
            if entity:
                license_id = entity.license_id or "unknown"
                sources.append((entity_id, license_id))
        return check_license_compatibility(sources, target_license)

    def get_derived_license(self, entity_ids: List[str]):
        """
        Get the appropriate license for a work derived from entities.

        Args:
            entity_ids: List of source entity IDs

        Returns:
            SPDXLicense for the derived work
        """
        from .license import get_derived_license

        licenses = []
        for entity_id in entity_ids:
            entity = self.graph.get_entity(entity_id)
            if entity and entity.license_id:
                licenses.append(entity.license_id)
        return get_derived_license(licenses) if licenses else None

    def generate_attribution(self, entity_ids: List[str] = None) -> str:
        """
        Generate attribution text for entities.

        Args:
            entity_ids: List of entity IDs (defaults to all entities)

        Returns:
            Markdown attribution text
        """
        from .license import LicenseAnalyzer

        analyzer = LicenseAnalyzer()
        if entity_ids is None:
            entities = self.graph.list_entities()
        else:
            entities = [
                self.graph.get_entity(eid) for eid in entity_ids
                if self.graph.get_entity(eid)
            ]
        sources = [
            (e.id, e.license_id or "unknown", e.name)
            for e in entities
        ]
        return analyzer.generate_attribution(sources)

    def __repr__(self) -> str:
        return f"DatasetObserver({self.graph})"