"""
Schema Observer
Observes and hashes dataset schemas/features.
Works with HuggingFace datasets Features, Pandas DataFrames, and raw dicts.
"""
import hashlib
import json
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
@dataclass
class FieldSchema:
"""Schema for a single field/column."""
name: str
dtype: str # Normalized type name
# Type details
nullable: bool = True
is_list: bool = False
list_inner_type: Optional[str] = None
# For ClassLabel
is_categorical: bool = False
categories: Optional[List[str]] = None
num_categories: Optional[int] = None
# For nested structures
nested_fields: Optional[Dict[str, "FieldSchema"]] = None
# For arrays/tensors
shape: Optional[tuple] = None
# Constraints
min_value: Optional[float] = None
max_value: Optional[float] = None
pattern: Optional[str] = None # Regex for strings
# Metadata
description: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
result = {
"name": self.name,
"dtype": self.dtype,
"nullable": self.nullable,
}
if self.is_list:
result["is_list"] = True
result["list_inner_type"] = self.list_inner_type
if self.is_categorical:
result["is_categorical"] = True
result["categories"] = self.categories
result["num_categories"] = self.num_categories
if self.nested_fields:
result["nested_fields"] = {
k: v.to_dict() for k, v in self.nested_fields.items()
}
        if self.shape:
            result["shape"] = list(self.shape)
        if self.min_value is not None:
            result["min_value"] = self.min_value
        if self.max_value is not None:
            result["max_value"] = self.max_value
        if self.pattern:
            result["pattern"] = self.pattern
        if self.description:
            result["description"] = self.description
return result
def hash(self) -> str:
"""Hash this field's structure."""
content = json.dumps(self.to_dict(), sort_keys=True)
return hashlib.sha256(content.encode()).hexdigest()[:16]
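
# Example (illustrative): structurally identical fields hash identically, so
# FieldSchema.hash() works as a stable structural fingerprint:
#
#     a = FieldSchema(name="label", dtype="string")
#     b = FieldSchema(name="label", dtype="string")
#     assert a.hash() == b.hash()  # 16-hex-char prefix of a SHA-256
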
@dataclass
class DatasetSchema:
"""Complete schema for a dataset."""
fields: Dict[str, FieldSchema] = field(default_factory=dict)
# Dataset-level metadata
primary_key: Optional[List[str]] = None
foreign_keys: Dict[str, str] = field(default_factory=dict) # field → target
# Source info
source_format: Optional[str] = None # arrow, parquet, csv, etc.
def add_field(self, field_schema: FieldSchema):
"""Add a field to the schema."""
self.fields[field_schema.name] = field_schema
@property
def field_names(self) -> List[str]:
return list(self.fields.keys())
@property
def num_fields(self) -> int:
return len(self.fields)
def to_dict(self) -> Dict[str, Any]:
return {
"fields": {k: v.to_dict() for k, v in self.fields.items()},
"primary_key": self.primary_key,
"foreign_keys": self.foreign_keys,
"source_format": self.source_format,
}
def hash(self) -> str:
"""Compute schema hash - identifies structure regardless of content."""
# Sort fields for deterministic hashing
ordered_fields = sorted(self.fields.keys())
content = json.dumps({
"fields": [self.fields[k].to_dict() for k in ordered_fields],
"primary_key": self.primary_key,
}, sort_keys=True)
return hashlib.sha256(content.encode()).hexdigest()
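    # Example (illustrative): field insertion order does not affect the hash,
    # because field names are sorted before serialization:
    #
    #     s1 = DatasetSchema()
    #     s1.add_field(FieldSchema(name="a", dtype="int64"))
    #     s1.add_field(FieldSchema(name="b", dtype="string"))
    #     s2 = DatasetSchema()
    #     s2.add_field(FieldSchema(name="b", dtype="string"))
    #     s2.add_field(FieldSchema(name="a", dtype="int64"))
    #     assert s1.hash() == s2.hash()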
def diff(self, other: "DatasetSchema") -> Dict[str, Any]:
"""Compare two schemas and return differences."""
added = set(other.field_names) - set(self.field_names)
removed = set(self.field_names) - set(other.field_names)
modified = {}
for name in set(self.field_names) & set(other.field_names):
if self.fields[name].hash() != other.fields[name].hash():
modified[name] = {
"old": self.fields[name].to_dict(),
"new": other.fields[name].to_dict(),
}
return {
"added": list(added),
"removed": list(removed),
"modified": modified,
"compatible": len(removed) == 0 and len(modified) == 0,
}
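
# Example (illustrative): diff() treats added fields as backward-compatible,
# while removed or modified fields flip "compatible" to False:
#
#     old = DatasetSchema()
#     old.add_field(FieldSchema(name="text", dtype="string"))
#     new = DatasetSchema()
#     new.add_field(FieldSchema(name="text", dtype="string"))
#     new.add_field(FieldSchema(name="label", dtype="int64"))
#     old.diff(new)
#     # {"added": ["label"], "removed": [], "modified": {}, "compatible": True}
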
class SchemaObserver:
"""
Observes and extracts schemas from various data sources.
"""
# Type mapping from various sources to normalized types
TYPE_MAP = {
# Python types
"str": "string",
"int": "int64",
"float": "float64",
"bool": "bool",
"bytes": "binary",
# NumPy types
"int8": "int8",
"int16": "int16",
"int32": "int32",
"int64": "int64",
"uint8": "uint8",
"uint16": "uint16",
"uint32": "uint32",
"uint64": "uint64",
"float16": "float16",
"float32": "float32",
"float64": "float64",
        # Arrow types
        "string": "string",
        "large_string": "string",
        "binary": "binary",
        "large_binary": "binary",
        "double": "float64",     # str(pyarrow.float64()) is "double"
        "halffloat": "float16",  # str(pyarrow.float16()) is "halffloat"
# HuggingFace special types
"Image": "image",
"Audio": "audio",
"ClassLabel": "categorical",
}
def observe_hf_dataset(self, dataset) -> DatasetSchema:
"""
Extract schema from HuggingFace Dataset.
Args:
dataset: A HuggingFace datasets.Dataset or DatasetDict
Returns:
DatasetSchema with all fields
"""
schema = DatasetSchema(source_format="arrow")
# Get features (works for both Dataset and DatasetDict)
if hasattr(dataset, 'features'):
features = dataset.features
        elif hasattr(dataset, 'values'):
            # DatasetDict (a mapping of splits) - take features from the first split
            first_split = next(iter(dataset.values()))
            features = first_split.features
else:
raise ValueError(f"Cannot extract features from {type(dataset)}")
# Parse each feature
for name, feature in features.items():
field_schema = self._parse_hf_feature(name, feature)
schema.add_field(field_schema)
return schema
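    # Example (illustrative, assumes the `datasets` library is installed):
    #
    #     from datasets import load_dataset
    #     ds = load_dataset("imdb", split="train")
    #     schema = SchemaObserver().observe_hf_dataset(ds)
    #     schema.field_names  # e.g. ["text", "label"]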
def _parse_hf_feature(self, name: str, feature) -> FieldSchema:
"""Parse a HuggingFace Feature into FieldSchema."""
# Import here to avoid hard dependency
try:
from datasets import (
Value, ClassLabel, Sequence,
Array2D, Array3D, Array4D, Array5D,
Image, Audio
)
except ImportError:
            # Fallback when the `datasets` library is not installed
return FieldSchema(name=name, dtype="unknown")
# Value type (primitives)
if isinstance(feature, Value):
return FieldSchema(
name=name,
dtype=self.TYPE_MAP.get(feature.dtype, feature.dtype),
)
# ClassLabel (categorical)
if isinstance(feature, ClassLabel):
return FieldSchema(
name=name,
dtype="categorical",
is_categorical=True,
categories=feature.names,
num_categories=feature.num_classes,
)
# Sequence (list)
if isinstance(feature, Sequence):
inner = self._parse_hf_feature(f"{name}_inner", feature.feature)
return FieldSchema(
name=name,
dtype="list",
is_list=True,
list_inner_type=inner.dtype,
)
# Arrays
if isinstance(feature, (Array2D, Array3D, Array4D, Array5D)):
return FieldSchema(
name=name,
dtype=self.TYPE_MAP.get(feature.dtype, feature.dtype),
shape=feature.shape,
)
# Image
if isinstance(feature, Image):
return FieldSchema(
name=name,
dtype="image",
)
# Audio
if isinstance(feature, Audio):
return FieldSchema(
name=name,
dtype="audio",
)
# Dict/nested structure
if isinstance(feature, dict):
nested = {}
for k, v in feature.items():
nested[k] = self._parse_hf_feature(k, v)
return FieldSchema(
name=name,
dtype="struct",
nested_fields=nested,
)
# Fallback
return FieldSchema(
name=name,
dtype=str(type(feature).__name__),
)
def observe_pandas(self, df) -> DatasetSchema:
"""
Extract schema from Pandas DataFrame.
Args:
df: A pandas DataFrame
Returns:
DatasetSchema with all fields
"""
schema = DatasetSchema(source_format="pandas")
for col in df.columns:
dtype = str(df[col].dtype)
normalized = self.TYPE_MAP.get(dtype, dtype)
# Check for categorical
if dtype == "category":
schema.add_field(FieldSchema(
name=col,
dtype="categorical",
is_categorical=True,
categories=list(df[col].cat.categories),
num_categories=len(df[col].cat.categories),
))
else:
schema.add_field(FieldSchema(
name=col,
dtype=normalized,
                    nullable=bool(df[col].isna().any()),  # cast numpy.bool_ so to_dict() stays JSON-serializable
))
return schema
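    # Example (illustrative, assumes pandas is installed):
    #
    #     import pandas as pd
    #     df = pd.DataFrame({"id": [1, 2], "tag": pd.Categorical(["a", "b"])})
    #     schema = SchemaObserver().observe_pandas(df)
    #     schema.fields["tag"].is_categorical  # True
    #     schema.fields["id"].dtype            # "int64"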
def observe_dict(self, data: Dict[str, Any], sample_size: int = 100) -> DatasetSchema:
"""
Extract schema from a dict of lists (columnar format).
Args:
data: Dict mapping column names to lists of values
sample_size: Number of values to sample for type inference
Returns:
DatasetSchema with all fields
"""
schema = DatasetSchema(source_format="dict")
for col, values in data.items():
if not values:
schema.add_field(FieldSchema(name=col, dtype="unknown"))
continue
# Sample values for type inference
sample = values[:sample_size]
types = set(type(v).__name__ for v in sample if v is not None)
# Determine type
if len(types) == 0:
dtype = "null"
            elif len(types) == 1:
                # Fall back to the raw type name (e.g. "list", "dict") when unmapped
                t = types.pop()
                dtype = self.TYPE_MAP.get(t, t)
else:
dtype = "mixed"
# Check for nulls
nullable = any(v is None for v in sample)
schema.add_field(FieldSchema(
name=col,
dtype=dtype,
nullable=nullable,
))
return schema
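    # Example (illustrative): plain Python types are normalized via TYPE_MAP,
    # and None values mark a column as nullable:
    #
    #     data = {"id": [1, 2, 3], "name": ["a", None, "c"]}
    #     schema = SchemaObserver().observe_dict(data)
    #     schema.fields["id"].dtype       # "int64"
    #     schema.fields["name"].nullable  # True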
def observe_arrow(self, table) -> DatasetSchema:
"""
Extract schema from PyArrow Table.
Args:
table: A pyarrow.Table
Returns:
DatasetSchema with all fields
"""
schema = DatasetSchema(source_format="arrow")
        # "col" avoids shadowing dataclasses.field imported at module level
        for col in table.schema:
            dtype = str(col.type)
            normalized = self.TYPE_MAP.get(dtype, dtype)
            schema.add_field(FieldSchema(
                name=col.name,
                dtype=normalized,
                nullable=col.nullable,
            ))
return schema
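
# Example (illustrative, assumes pyarrow is installed):
#
#     import pyarrow as pa
#     table = pa.table({"id": [1, 2], "score": [0.5, 0.9]})
#     schema = SchemaObserver().observe_arrow(table)
#     schema.fields["score"].dtype  # "float64" (Arrow prints float64 as "double")
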
def hash_content(data: Any, sample_size: int = 10000) -> str:
    """
    Compute a content hash of a dataset.

    For large datasets, a deterministic row sample is hashed for efficiency.
    """
hasher = hashlib.sha256()
# Handle dict first (dict also has __iter__ and __len__)
if isinstance(data, dict):
content = json.dumps(data, sort_keys=True, default=str)
hasher.update(content.encode())
# Handle list
elif isinstance(data, list):
for item in data[:sample_size]:
item_str = json.dumps(item, sort_keys=True, default=str)
hasher.update(item_str.encode())
# Handle HuggingFace Dataset or other iterables with __len__
elif hasattr(data, '__iter__') and hasattr(data, '__len__'):
# Sample if large
n = len(data)
        if n > sample_size:
            import random
            # Seed the RNG so the same data always yields the same sample,
            # keeping the content hash deterministic across runs
            indices = sorted(random.Random(0).sample(range(n), sample_size))
sample = [data[i] for i in indices]
else:
sample = list(data)
for row in sample:
row_str = json.dumps(row, sort_keys=True, default=str)
hasher.update(row_str.encode())
    else:
        raise TypeError(f"Cannot hash content of type {type(data).__name__}")
    return hasher.hexdigest()
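

if __name__ == "__main__":
    # Minimal smoke test (illustrative): build two schemas from plain dicts,
    # diff them, and hash their content. Uses only the stdlib, no optional deps.
    observer = SchemaObserver()
    v1 = {"id": [1, 2, 3], "text": ["a", "b", "c"]}
    v2 = {"id": [1, 2, 3], "text": ["a", "b", "c"], "label": [0, 1, 0]}
    s1 = observer.observe_dict(v1)
    s2 = observer.observe_dict(v2)
    print("schema v1 hash:", s1.hash()[:16])
    print("diff v1 -> v2:", s1.diff(s2))
    print("content hash v1:", hash_content(v1)[:16])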