# dataview-mcp/tools/profiling.py
# Initial release: DataView MCP - HuggingFace Dataset Explorer (commit b67578f, author: efecelik)
"""Profiling tools for analyzing dataset statistics and quality."""
import json
import statistics
from typing import Any, Dict, List, Optional

from utils.formatting import format_quality_report, format_statistics
from utils.hf_client import get_client
def get_statistics(
    dataset_id: str,
    config: Optional[str] = None,
    split: str = "train",
    sample_size: int = 1000
) -> str:
    """
    Compute basic statistics for each column in a dataset.

    Use this tool to get a statistical overview of a dataset, including
    counts, means, unique values, and distributions for each column.

    Args:
        dataset_id: The full dataset identifier (e.g., "squad", "imdb")
        config: Optional dataset configuration name. Leave empty for default.
        split: The dataset split to analyze ("train", "test", "validation"). Default: "train"
        sample_size: Number of rows to sample for statistics (100-5000, default: 1000).
            Larger samples are more accurate but slower.

    Returns:
        Formatted statistics including:
        - Total row count (estimated from sample)
        - Per-column statistics:
            - Numeric: min, max, mean, median, std
            - Text: avg length, min/max length, unique count
            - Categorical: value counts, top categories

    Notes:
        - Statistics are computed on a sample for efficiency
        - Very large datasets may show approximate values
        - Binary data columns (images, audio) show type info only
    """
    # Clamp the requested sample size into the documented 100-5000 range.
    sample_size = max(100, min(5000, sample_size))
    client = get_client()

    # Load sample for statistics
    samples = client.load_sample(
        dataset_id=dataset_id,
        config=config,
        split=split,
        n_rows=sample_size
    )

    if not samples or "error" in samples[0]:
        error_msg = samples[0].get('error', 'Unknown error') if samples else 'No data'
        return f"Error loading data for statistics: {error_msg}"

    # Compute statistics
    stats: Dict[str, Any] = {
        "total_rows": f"~{len(samples):,}+ (sampled)",
        "column_stats": {}
    }

    # Analyze each column; type is inferred from the first non-null value.
    columns = samples[0].keys() if samples else []
    for col in columns:
        col_values = [row.get(col) for row in samples if row.get(col) is not None]

        if not col_values:
            stats["column_stats"][col] = {"status": "all null"}
            continue

        sample_val = col_values[0]

        if isinstance(sample_val, (int, float)) and not isinstance(sample_val, bool):
            # Numeric column. Exclude bools here too: isinstance(True, int)
            # is True, so a mixed int/bool column would otherwise skew stats.
            numeric_vals = [
                v for v in col_values
                if isinstance(v, (int, float)) and not isinstance(v, bool)
            ]
            if numeric_vals:
                stats["column_stats"][col] = {
                    "type": "numeric",
                    "count": len(numeric_vals),
                    "min": min(numeric_vals),
                    "max": max(numeric_vals),
                    "mean": sum(numeric_vals) / len(numeric_vals),
                    # median/std are promised by the docstring; pstdev is
                    # safe even for a single value (returns 0.0).
                    "median": statistics.median(numeric_vals),
                    "std": statistics.pstdev(numeric_vals),
                    "unique": len(set(numeric_vals))
                }
        elif isinstance(sample_val, str):
            # Text column. Restrict to actual strings so a mixed-type column
            # containing unhashable values (lists/dicts) can't crash set().
            str_vals = [v for v in col_values if isinstance(v, str)]
            lengths = [len(v) for v in str_vals]
            unique_vals = set(str_vals)
            stats["column_stats"][col] = {
                "type": "text",
                "count": len(col_values),
                "avg_length": sum(lengths) / len(lengths) if lengths else 0,
                "min_length": min(lengths) if lengths else 0,
                "max_length": max(lengths) if lengths else 0,
                "unique": len(unique_vals),
                "sample_values": list(unique_vals)[:3]
            }
        elif isinstance(sample_val, bool):
            # Boolean column
            true_count = sum(1 for v in col_values if v is True)
            stats["column_stats"][col] = {
                "type": "boolean",
                "count": len(col_values),
                "true_count": true_count,
                "false_count": len(col_values) - true_count,
                "true_pct": (true_count / len(col_values)) * 100
            }
        elif isinstance(sample_val, list):
            # List/sequence column: summarize element-list lengths.
            lengths = [len(v) for v in col_values if isinstance(v, list)]
            stats["column_stats"][col] = {
                "type": "list/sequence",
                "count": len(col_values),
                "avg_length": sum(lengths) / len(lengths) if lengths else 0,
                "min_length": min(lengths) if lengths else 0,
                "max_length": max(lengths) if lengths else 0
            }
        elif isinstance(sample_val, dict):
            # Nested object: report top-level keys of the first value only.
            stats["column_stats"][col] = {
                "type": "object/nested",
                "count": len(col_values),
                "sample_keys": list(sample_val.keys())[:5] if sample_val else []
            }
        else:
            # Binary or other type (images, audio, etc.)
            stats["column_stats"][col] = {
                "type": str(type(sample_val).__name__),
                "count": len(col_values),
                "note": "Binary/special data type"
            }

    return format_statistics(stats)
def profile_quality(
    dataset_id: str,
    config: Optional[str] = None,
    split: str = "train",
    sample_size: int = 500
) -> str:
    """
    Assess the data quality of a dataset and identify potential issues.

    Use this tool to check for common data quality problems like missing values,
    duplicates, imbalanced classes, and outliers before using a dataset.

    Args:
        dataset_id: The full dataset identifier (e.g., "squad", "imdb")
        config: Optional dataset configuration name. Leave empty for default.
        split: The dataset split to analyze ("train", "test", "validation"). Default: "train"
        sample_size: Number of rows to sample for quality check (100-2000, default: 500).

    Returns:
        Data quality report including:
        - Overall quality score (0-100)
        - List of identified issues
        - Per-column quality metrics:
            - Missing value percentage
            - Unique value percentage
            - Specific issues (constant values, high cardinality, etc.)

    Quality checks performed:
        - Missing/null values
        - Duplicate rows
        - Constant columns (single value)
        - High cardinality text columns
        - Class imbalance for categorical columns
        - Outliers for numeric columns
    """
    # Clamp the requested sample size into the documented 100-2000 range.
    sample_size = max(100, min(2000, sample_size))
    client = get_client()

    # Load sample
    samples = client.load_sample(
        dataset_id=dataset_id,
        config=config,
        split=split,
        n_rows=sample_size
    )

    if not samples or "error" in samples[0]:
        error_msg = samples[0].get('error', 'Unknown error') if samples else 'No data'
        return format_quality_report({"error": error_msg})

    # Initialize report; score starts at 100 and is decremented per issue.
    report: Dict[str, Any] = {
        "dataset_id": dataset_id,
        "sample_size": len(samples),
        "issues": [],
        "column_quality": {},
        "overall_score": 100
    }

    # Check for duplicate rows by comparing canonical JSON serializations.
    try:
        row_strings = [json.dumps(row, sort_keys=True, default=str) for row in samples]
        unique_rows = len(set(row_strings))
        duplicate_pct = ((len(samples) - unique_rows) / len(samples)) * 100
        if duplicate_pct > 5:
            report["issues"].append(f"High duplicate rate: {duplicate_pct:.1f}% duplicate rows")
            report["overall_score"] -= min(20, duplicate_pct)
    except Exception:
        # Best-effort: rows with values json can't handle just skip the
        # duplicate check instead of failing the whole report.
        pass

    # Analyze each column
    columns = samples[0].keys() if samples else []
    for col in columns:
        col_values = [row.get(col) for row in samples]
        # Treat both None and empty string as missing.
        non_null_values = [v for v in col_values if v is not None and v != ""]

        col_quality: Dict[str, Any] = {
            "missing_pct": ((len(samples) - len(non_null_values)) / len(samples)) * 100,
            "issues": []
        }

        # Check missing values. BUG FIX: the severe (>50%) case must be
        # tested first -- the original checked >20 first, which made the
        # >50 branch unreachable (any value >50 is also >20).
        if col_quality["missing_pct"] > 50:
            col_quality["issues"].append("High missing rate")
            report["issues"].append(f"Column '{col}' has {col_quality['missing_pct']:.0f}% missing values")
            report["overall_score"] -= 10
        elif col_quality["missing_pct"] > 20:
            col_quality["issues"].append("High missing rate")
            report["overall_score"] -= 5

        if non_null_values:
            # Stringify values so unhashable types (lists/dicts) still count.
            unique_count = len(set(str(v) for v in non_null_values))
            col_quality["unique_pct"] = (unique_count / len(non_null_values)) * 100

            # Check for constant column
            if unique_count == 1:
                col_quality["issues"].append("Constant value")
                report["issues"].append(f"Column '{col}' has only one unique value")
                report["overall_score"] -= 5
            # Check for potential ID column (nearly all values unique)
            elif col_quality["unique_pct"] > 99 and len(non_null_values) > 10:
                col_quality["issues"].append("Possibly ID column")
            # Check for high cardinality text relative to sample size
            elif isinstance(non_null_values[0], str) and unique_count > len(samples) * 0.8:
                col_quality["issues"].append("High cardinality")

            # Check class imbalance for low-cardinality categorical columns.
            sample_val = non_null_values[0]
            if isinstance(sample_val, (str, int, bool)) and unique_count <= 20:
                value_counts: Dict[str, int] = {}
                for v in non_null_values:
                    key = str(v)
                    value_counts[key] = value_counts.get(key, 0) + 1
                if value_counts:
                    max_count = max(value_counts.values())
                    min_count = min(value_counts.values())
                    # Majority class >10x the minority class => imbalanced.
                    if max_count > min_count * 10:
                        col_quality["issues"].append("Class imbalance")
                        # Only penalize the score for likely label columns.
                        if "label" in col.lower() or "class" in col.lower():
                            report["issues"].append(f"Significant class imbalance in '{col}'")
                            report["overall_score"] -= 10
        else:
            col_quality["unique_pct"] = 0

        # Flatten the per-column issue list into a display string.
        col_quality["issues"] = ", ".join(col_quality["issues"]) if col_quality["issues"] else "-"
        report["column_quality"][col] = col_quality

    # Clamp score to the documented 0-100 range.
    report["overall_score"] = max(0, min(100, report["overall_score"]))

    if not report["issues"]:
        report["issues"].append("No major issues detected")

    return format_quality_report(report)