Spaces:
Configuration error
Configuration error
| """ | |
| Schema Observer | |
| Observes and hashes dataset schemas/features. | |
| Works with HuggingFace datasets Features, Pandas DataFrames, and raw dicts. | |
| """ | |
| import hashlib | |
| import json | |
| from dataclasses import dataclass, field | |
| from typing import Any, Dict, List, Optional, Union | |
@dataclass
class FieldSchema:
    """Schema for a single field/column.

    Only ``name`` and ``dtype`` are required. The remaining attributes are
    optional details; ``to_dict`` emits them only when they carry
    information, so unset options do not change the serialization or
    ``hash()`` of a field.
    """

    name: str
    dtype: str  # Normalized type name

    # Type details
    nullable: bool = True
    is_list: bool = False
    list_inner_type: Optional[str] = None

    # For ClassLabel / categorical columns
    is_categorical: bool = False
    categories: Optional[List[str]] = None
    num_categories: Optional[int] = None

    # For nested structures
    nested_fields: Optional[Dict[str, "FieldSchema"]] = None

    # For arrays/tensors
    shape: Optional[tuple] = None

    # Constraints
    min_value: Optional[float] = None
    max_value: Optional[float] = None
    pattern: Optional[str] = None  # Regex for strings

    # Metadata
    description: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict description of this field, suitable for ``json.dumps``.

        ``name``, ``dtype`` and ``nullable`` are always present; every other
        attribute is included only when set. Nested fields are serialized
        recursively.
        """
        result: Dict[str, Any] = {
            "name": self.name,
            "dtype": self.dtype,
            "nullable": self.nullable,
        }
        if self.is_list:
            result["is_list"] = True
            result["list_inner_type"] = self.list_inner_type
        if self.is_categorical:
            result["is_categorical"] = True
            result["categories"] = self.categories
            result["num_categories"] = self.num_categories
        if self.nested_fields:
            result["nested_fields"] = {
                k: v.to_dict() for k, v in self.nested_fields.items()
            }
        if self.shape:
            result["shape"] = self.shape
        # Constraints are part of the structure: without them a change in
        # bounds or pattern would be invisible to hash()/diff. Compare with
        # `is not None` because 0 is a meaningful bound.
        for key, value in (
            ("min_value", self.min_value),
            ("max_value", self.max_value),
            ("pattern", self.pattern),
        ):
            if value is not None:
                result[key] = value
        if self.description:
            result["description"] = self.description
        return result

    def hash(self) -> str:
        """Return the first 16 hex chars of the SHA-256 of this field's structure."""
        # sort_keys makes the digest independent of dict insertion order.
        content = json.dumps(self.to_dict(), sort_keys=True)
        return hashlib.sha256(content.encode()).hexdigest()[:16]
@dataclass
class DatasetSchema:
    """Complete schema for a dataset: its fields plus dataset-level metadata.

    ``fields`` maps field name to its schema and preserves insertion order.
    """

    # Field name -> schema; add_field keys entries by FieldSchema.name.
    fields: Dict[str, "FieldSchema"] = field(default_factory=dict)

    # Dataset-level metadata
    primary_key: Optional[List[str]] = None
    foreign_keys: Dict[str, str] = field(default_factory=dict)  # field -> target

    # Source info, e.g. "arrow", "pandas", "dict"
    source_format: Optional[str] = None

    def add_field(self, field_schema: "FieldSchema") -> None:
        """Add a field, replacing any existing field with the same name."""
        self.fields[field_schema.name] = field_schema

    def field_names(self) -> List[str]:
        """Return field names in insertion order."""
        return list(self.fields.keys())

    def num_fields(self) -> int:
        """Return the number of fields."""
        return len(self.fields)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict representation including all metadata."""
        return {
            "fields": {k: v.to_dict() for k, v in self.fields.items()},
            "primary_key": self.primary_key,
            "foreign_keys": self.foreign_keys,
            "source_format": self.source_format,
        }

    def hash(self) -> str:
        """Compute schema hash - identifies structure regardless of content.

        Fields are sorted by name, so insertion order does not matter. Only
        the fields and ``primary_key`` are hashed; ``foreign_keys`` and
        ``source_format`` are not included.
        """
        ordered_fields = sorted(self.fields.keys())
        content = json.dumps({
            "fields": [self.fields[k].to_dict() for k in ordered_fields],
            "primary_key": self.primary_key,
        }, sort_keys=True)
        return hashlib.sha256(content.encode()).hexdigest()

    def diff(self, other: "DatasetSchema") -> Dict[str, Any]:
        """Compare this schema (old) with ``other`` (new).

        Returns:
            A dict with ``added`` (names only in ``other``, in ``other``'s
            order), ``removed`` (names only in ``self``, in ``self``'s
            order), ``modified`` (name -> {"old", "new"} field dicts for
            shared names whose field hashes differ), and ``compatible``
            (True when nothing was removed or modified; additions are
            considered compatible).
        """
        # Work on the field dicts directly: this keeps the output order
        # deterministic (unlike list(set(...))) and does not depend on how
        # field_names is exposed.
        added = [name for name in other.fields if name not in self.fields]
        removed = [name for name in self.fields if name not in other.fields]

        modified = {}
        for name, old_field in self.fields.items():
            if name not in other.fields:
                continue
            new_field = other.fields[name]
            if old_field.hash() != new_field.hash():
                modified[name] = {
                    "old": old_field.to_dict(),
                    "new": new_field.to_dict(),
                }

        return {
            "added": added,
            "removed": removed,
            "modified": modified,
            "compatible": not removed and not modified,
        }
class SchemaObserver:
    """
    Observes and extracts schemas from various data sources.

    Supported sources: HuggingFace ``Dataset``/``DatasetDict``, pandas
    ``DataFrame``, columnar dicts of lists, and PyArrow tables. Native type
    names are normalized through ``TYPE_MAP`` when a mapping exists and
    are otherwise passed through unchanged, except in ``observe_dict``,
    which reports unmapped names as "unknown".
    """

    # Type mapping from various sources to normalized types
    TYPE_MAP = {
        # Python types
        "str": "string",
        "int": "int64",
        "float": "float64",
        "bool": "bool",
        "bytes": "binary",

        # NumPy types
        "int8": "int8",
        "int16": "int16",
        "int32": "int32",
        "int64": "int64",
        "uint8": "uint8",
        "uint16": "uint16",
        "uint32": "uint32",
        "uint64": "uint64",
        "float16": "float16",
        "float32": "float32",
        "float64": "float64",

        # Arrow types
        "string": "string",
        "large_string": "string",
        "binary": "binary",
        "large_binary": "binary",

        # HuggingFace special types
        "Image": "image",
        "Audio": "audio",
        "ClassLabel": "categorical",
    }

    # str(pyarrow type) spells floating-point types differently from NumPy:
    # float16 -> "halffloat", float32 -> "float", float64 -> "double".
    # "float" in particular would otherwise hit TYPE_MAP's Python-float
    # entry and be mislabeled "float64".
    ARROW_TYPE_ALIASES = {
        "halffloat": "float16",
        "float": "float32",
        "double": "float64",
    }

    def observe_hf_dataset(self, dataset) -> DatasetSchema:
        """
        Extract schema from HuggingFace Dataset.

        Args:
            dataset: A HuggingFace datasets.Dataset or DatasetDict

        Returns:
            DatasetSchema with all fields

        Raises:
            ValueError: If no features can be obtained (unsupported object,
                empty DatasetDict, or features reported as None).
        """
        schema = DatasetSchema(source_format="arrow")

        if hasattr(dataset, "features"):
            features = dataset.features
        elif hasattr(dataset, "values"):
            # DatasetDict-like mapping: read features from the first split.
            # NOTE(review): assumes all splits share the same features.
            first_split = next(iter(dataset.values()), None)
            if first_split is None:
                raise ValueError(f"Cannot extract features from empty {type(dataset)}")
            features = first_split.features
        else:
            raise ValueError(f"Cannot extract features from {type(dataset)}")

        if features is None:
            raise ValueError(
                f"Cannot extract features from {type(dataset)}: features are None"
            )

        for name, feature in features.items():
            schema.add_field(self._parse_hf_feature(name, feature))

        return schema

    def _parse_hf_feature(self, name: str, feature) -> FieldSchema:
        """Parse a HuggingFace Feature into FieldSchema (recursing into nesting)."""
        # Import here to avoid hard dependency
        try:
            from datasets import (
                Value, ClassLabel, Sequence,
                Array2D, Array3D, Array4D, Array5D,
                Image, Audio
            )
        except ImportError:
            # Fallback for when datasets not installed
            return FieldSchema(name=name, dtype="unknown")

        # Value type (primitives)
        if isinstance(feature, Value):
            return FieldSchema(
                name=name,
                dtype=self.TYPE_MAP.get(feature.dtype, feature.dtype),
            )

        # ClassLabel (categorical)
        if isinstance(feature, ClassLabel):
            return FieldSchema(
                name=name,
                dtype="categorical",
                is_categorical=True,
                categories=feature.names,
                num_categories=feature.num_classes,
            )

        # Sequence (list); only the inner dtype is recorded.
        # NOTE(review): other list-like feature classes in newer `datasets`
        # releases would reach the type-name fallback below - confirm.
        if isinstance(feature, Sequence):
            inner = self._parse_hf_feature(f"{name}_inner", feature.feature)
            return FieldSchema(
                name=name,
                dtype="list",
                is_list=True,
                list_inner_type=inner.dtype,
            )

        # Python-list shorthand [inner_feature], which HF Features accept
        # for a list-valued column; treat it like Sequence.
        if isinstance(feature, list) and len(feature) == 1:
            inner = self._parse_hf_feature(f"{name}_inner", feature[0])
            return FieldSchema(
                name=name,
                dtype="list",
                is_list=True,
                list_inner_type=inner.dtype,
            )

        # Arrays
        if isinstance(feature, (Array2D, Array3D, Array4D, Array5D)):
            return FieldSchema(
                name=name,
                dtype=self.TYPE_MAP.get(feature.dtype, feature.dtype),
                shape=feature.shape,
            )

        # Image
        if isinstance(feature, Image):
            return FieldSchema(
                name=name,
                dtype="image",
            )

        # Audio
        if isinstance(feature, Audio):
            return FieldSchema(
                name=name,
                dtype="audio",
            )

        # Dict/nested structure
        if isinstance(feature, dict):
            nested = {k: self._parse_hf_feature(k, v) for k, v in feature.items()}
            return FieldSchema(
                name=name,
                dtype="struct",
                nested_fields=nested,
            )

        # Fallback: record the feature's class name as its dtype
        return FieldSchema(
            name=name,
            dtype=str(type(feature).__name__),
        )

    def observe_pandas(self, df) -> DatasetSchema:
        """
        Extract schema from Pandas DataFrame.

        Args:
            df: A pandas DataFrame

        Returns:
            DatasetSchema with all fields
        """
        schema = DatasetSchema(source_format="pandas")

        # df.items() yields one Series per column, even when column labels
        # repeat (df[col] would return a DataFrame in that case).
        for col, series in df.items():
            dtype = str(series.dtype)

            if dtype == "category":
                categories = list(series.cat.categories)
                schema.add_field(FieldSchema(
                    name=col,
                    dtype="categorical",
                    is_categorical=True,
                    categories=categories,
                    num_categories=len(categories),
                ))
            else:
                schema.add_field(FieldSchema(
                    name=col,
                    dtype=self.TYPE_MAP.get(dtype, dtype),
                    # .any() returns numpy.bool_, which json.dumps (used by
                    # hash()) cannot serialize - convert to a Python bool.
                    nullable=bool(series.isna().any()),
                ))

        return schema

    def observe_dict(self, data: Dict[str, Any], sample_size: int = 100) -> DatasetSchema:
        """
        Extract schema from a dict of lists (columnar format).

        Args:
            data: Dict mapping column names to lists of values
            sample_size: Number of values to sample for type inference

        Returns:
            DatasetSchema with all fields. A column's dtype is "unknown" when
            the column is empty or its single Python type is unmapped,
            "null" when every sampled value is None, and "mixed" when
            sampled values have more than one type.
        """
        schema = DatasetSchema(source_format="dict")

        for col, values in data.items():
            if not values:
                schema.add_field(FieldSchema(name=col, dtype="unknown"))
                continue

            # Only the first sample_size values are inspected.
            sample = values[:sample_size]
            types = set(type(v).__name__ for v in sample if v is not None)

            if len(types) == 0:
                dtype = "null"
            elif len(types) == 1:
                dtype = self.TYPE_MAP.get(types.pop(), "unknown")
            else:
                dtype = "mixed"

            nullable = any(v is None for v in sample)

            schema.add_field(FieldSchema(
                name=col,
                dtype=dtype,
                nullable=nullable,
            ))

        return schema

    def observe_arrow(self, table) -> DatasetSchema:
        """
        Extract schema from PyArrow Table.

        Args:
            table: A pyarrow.Table

        Returns:
            DatasetSchema with all fields
        """
        schema = DatasetSchema(source_format="arrow")

        # Named arrow_field to avoid shadowing dataclasses.field.
        for arrow_field in table.schema:
            dtype = str(arrow_field.type)
            dtype = self.ARROW_TYPE_ALIASES.get(dtype, dtype)
            schema.add_field(FieldSchema(
                name=arrow_field.name,
                dtype=self.TYPE_MAP.get(dtype, dtype),
                nullable=arrow_field.nullable,
            ))

        return schema
def hash_content(data, sample_size: int = 10000) -> str:
    """
    Compute a SHA-256 content hash of a dataset.

    Each item/row is serialized with ``json.dumps(..., sort_keys=True,
    default=str)`` and fed to the hasher in order. Large inputs are sampled
    for efficiency. Sampling is deterministic, so the same data always
    yields the same hash.

    Args:
        data: One of:
            - a dict, hashed whole;
            - a list, whose first ``sample_size`` items are hashed;
            - a sized, indexable iterable such as a HuggingFace Dataset,
              hashed via an evenly spaced sample when it is larger than
              ``sample_size``;
            - any other iterable, whose first ``sample_size`` items are
              hashed.
        sample_size: Maximum number of items/rows to hash.

    Returns:
        The hex digest.

    Raises:
        TypeError: If ``data`` is neither a dict nor iterable.
    """
    hasher = hashlib.sha256()

    def _feed(obj) -> None:
        # NOTE: default=str stringifies non-JSON objects; values whose str()
        # is not stable (e.g. default object reprs) will hash unstably.
        hasher.update(json.dumps(obj, sort_keys=True, default=str).encode())

    # Handle dict first (dict also has __iter__ and __len__)
    if isinstance(data, dict):
        _feed(data)
    elif isinstance(data, list):
        for item in data[:sample_size]:
            _feed(item)
    elif hasattr(data, "__iter__") and hasattr(data, "__len__"):
        n = len(data)
        if n > sample_size:
            # Evenly spaced, strictly increasing indices in [0, n): a
            # deterministic replacement for unseeded random sampling, which
            # made the hash change on every call.
            rows = (data[i * n // sample_size] for i in range(sample_size))
        else:
            rows = data
        for row in rows:
            _feed(row)
    elif hasattr(data, "__iter__"):
        # Unsized iterables (generators, streaming datasets) cannot be
        # sampled at random; hash their first sample_size items.
        from itertools import islice
        for row in islice(data, sample_size):
            _feed(row)
    else:
        # Previously fell through and returned the empty-input digest, so
        # every unsupported input collided on the same hash.
        raise TypeError(f"Cannot hash content of {type(data).__name__}")

    return hasher.hexdigest()