# Scraped file-viewer header (not part of the module): status "Running", file size 10,590 bytes, commit b67578f.
"""Profiling tools for analyzing dataset statistics and quality."""
import json
import logging
from collections import Counter
from typing import Optional, Dict, Any, List

from utils.hf_client import get_client
from utils.formatting import format_statistics, format_quality_report
def _length_summary(lengths: List[int]) -> Dict[str, Any]:
    """Return average/min/max of a list of lengths (all zeros when the list is empty)."""
    if not lengths:
        return {"avg_length": 0, "min_length": 0, "max_length": 0}
    return {
        "avg_length": sum(lengths) / len(lengths),
        "min_length": min(lengths),
        "max_length": max(lengths),
    }


def _summarize_column(values: List[Any]) -> Dict[str, Any]:
    """Build the statistics dict for one column from its non-null values.

    The column type is decided by the first non-null value; values of other
    types found in the same column are ignored for type-specific statistics
    but still contribute to ``count`` where the whole list is counted.
    ``values`` must be non-empty.
    """
    first = values[0]

    # bool is a subclass of int, so it has to be recognised before the numeric branch.
    if isinstance(first, bool):
        true_count = sum(1 for v in values if v is True)
        return {
            "type": "boolean",
            "count": len(values),
            "true_count": true_count,
            "false_count": len(values) - true_count,
            "true_pct": (true_count / len(values)) * 100,
        }

    if isinstance(first, (int, float)):
        # Stray booleans in a numeric column would otherwise be counted as 0/1.
        numbers = [
            v for v in values
            if isinstance(v, (int, float)) and not isinstance(v, bool)
        ]
        n = len(numbers)
        mean = sum(numbers) / n
        ordered = sorted(numbers)
        mid = n // 2
        median = ordered[mid] if n % 2 else (ordered[mid - 1] + ordered[mid]) / 2
        if n > 1:
            # Sample standard deviation. d * d (not d ** 2) so very large floats
            # become inf instead of raising OverflowError.
            variance = sum((v - mean) * (v - mean) for v in numbers) / (n - 1)
            std = variance ** 0.5
        else:
            std = 0.0
        return {
            "type": "numeric",
            "count": n,
            "min": ordered[0] if n == 1 else min(numbers),
            "max": ordered[-1] if n == 1 else max(numbers),
            "mean": mean,
            "median": median,
            "std": std,
            "unique": len(set(numbers)),
        }

    if isinstance(first, str):
        texts = [v for v in values if isinstance(v, str)]
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # sample values are stable across runs (set order is not), and
        # non-string values never reach the hashing step.
        distinct = list(dict.fromkeys(texts))
        summary: Dict[str, Any] = {"type": "text", "count": len(values)}
        summary.update(_length_summary([len(t) for t in texts]))
        summary["unique"] = len(distinct)
        summary["sample_values"] = distinct[:3]
        return summary

    if isinstance(first, list):
        summary = {"type": "list/sequence", "count": len(values)}
        summary.update(_length_summary([len(v) for v in values if isinstance(v, list)]))
        return summary

    if isinstance(first, dict):
        return {
            "type": "object/nested",
            "count": len(values),
            "sample_keys": list(first.keys())[:5],
        }

    # Anything else (bytes, images, audio objects, ...) only gets type info.
    return {
        "type": type(first).__name__,
        "count": len(values),
        "note": "Binary/special data type",
    }


def get_statistics(
    dataset_id: str,
    config: Optional[str] = None,
    split: str = "train",
    sample_size: int = 1000
) -> str:
    """
    Compute basic statistics for each column in a dataset.

    Use this tool to get a statistical overview of a dataset, including
    counts, means, unique values, and length distributions for each column.

    Args:
        dataset_id: The full dataset identifier (e.g., "squad", "imdb")
        config: Optional dataset configuration name. Leave empty for default.
        split: The dataset split to analyze ("train", "test", "validation"). Default: "train"
        sample_size: Number of rows to sample for statistics (100-5000, default: 1000).
            Values outside that range are clamped to it.
            Larger samples are more accurate but slower.

    Returns:
        Formatted statistics including:
        - Number of rows sampled (reported as a lower bound on the total)
        - Per-column statistics, chosen from the first non-null value's type:
            - Numeric: count, min, max, mean, median, std (sample), unique count
            - Text: count, avg/min/max length, unique count, up to 3 sample values
            - Boolean: count, true/false counts, percentage true
            - List/sequence: count, avg/min/max length
            - Nested object: count, up to 5 keys of the first value
            - Anything else: type name only
        If the sample cannot be loaded, an error message string is returned instead.

    Notes:
        - Statistics are computed on a sample for efficiency
        - Columns are taken from the first sampled row
        - Null values are excluded; a column with only nulls is reported as "all null"
    """
    sample_size = max(100, min(5000, sample_size))
    samples = get_client().load_sample(
        dataset_id=dataset_id,
        config=config,
        split=split,
        n_rows=sample_size
    )

    # The client signals failure with an "error" key in the first row.
    if not samples or "error" in samples[0]:
        error_msg = samples[0].get('error', 'Unknown error') if samples else 'No data'
        return f"Error loading data for statistics: {error_msg}"

    column_stats: Dict[str, Any] = {}
    for col in samples[0].keys():
        col_values = [row.get(col) for row in samples if row.get(col) is not None]
        if col_values:
            column_stats[col] = _summarize_column(col_values)
        else:
            column_stats[col] = {"status": "all null"}

    stats = {
        "total_rows": f"~{len(samples):,}+ (sampled)",
        "column_stats": column_stats
    }
    return format_statistics(stats)
def profile_quality(
    dataset_id: str,
    config: Optional[str] = None,
    split: str = "train",
    sample_size: int = 500
) -> str:
    """
    Assess the data quality of a dataset and identify potential issues.

    Use this tool to check for common data quality problems like missing values,
    duplicates, constant columns, and imbalanced classes before using a dataset.

    Args:
        dataset_id: The full dataset identifier (e.g., "squad", "imdb")
        config: Optional dataset configuration name. Leave empty for default.
        split: The dataset split to analyze ("train", "test", "validation"). Default: "train"
        sample_size: Number of rows to sample for quality check (100-2000, default: 500).
            Values outside that range are clamped to it.

    Returns:
        Data quality report including:
        - Overall quality score (0-100; starts at 100 and loses points per issue)
        - List of identified issues
        - Per-column quality metrics:
            - Missing value percentage (None and empty strings count as missing)
            - Unique value percentage
            - Specific issues (constant values, high cardinality, etc.)
        If the sample cannot be loaded, the report contains only an error.

    Quality checks performed:
        - Duplicate rows (flagged above 5%, up to -20 points)
        - Missing/null values (above 20%: -5 points; above 50%: -10 points
          and a report-level issue)
        - Constant columns (single value, -5 points)
        - Possible ID columns (more than 99% unique values)
        - High cardinality text columns
        - Class imbalance for columns with at most 20 distinct values
          (-10 points when the column name contains "label" or "class")

    Notes:
        - Columns are taken from the first sampled row
        - Outlier detection for numeric columns is not performed
    """
    sample_size = max(100, min(2000, sample_size))
    client = get_client()
    samples = client.load_sample(
        dataset_id=dataset_id,
        config=config,
        split=split,
        n_rows=sample_size
    )

    # The client signals failure with an "error" key in the first row.
    if not samples or "error" in samples[0]:
        error_msg = samples[0].get('error', 'Unknown error') if samples else 'No data'
        return format_quality_report({"error": error_msg})

    total = len(samples)
    report: Dict[str, Any] = {
        "dataset_id": dataset_id,
        "sample_size": total,
        "issues": [],
        "column_quality": {},
        "overall_score": 100
    }

    # Duplicate rows: compare canonical JSON renderings of each row.
    # This check is best-effort; a row that cannot be serialised must not
    # abort the whole report, but the skip is logged instead of hidden.
    try:
        row_strings = [json.dumps(row, sort_keys=True, default=str) for row in samples]
    except Exception:
        logging.getLogger(__name__).warning(
            "Duplicate-row check skipped for %s", dataset_id, exc_info=True
        )
    else:
        duplicate_pct = ((total - len(set(row_strings))) / total) * 100
        if duplicate_pct > 5:
            report["issues"].append(f"High duplicate rate: {duplicate_pct:.1f}% duplicate rows")
            report["overall_score"] -= min(20, duplicate_pct)

    for col in samples[0].keys():
        col_values = [row.get(col) for row in samples]
        # Only None and the empty string count as missing. The isinstance guard
        # avoids comparing arbitrary objects against "".
        non_null_values = [
            v for v in col_values
            if v is not None and not (isinstance(v, str) and v == "")
        ]
        missing_pct = ((total - len(non_null_values)) / total) * 100
        col_quality: Dict[str, Any] = {
            "missing_pct": missing_pct,
            "issues": []
        }

        # The stricter threshold must be tested first; with the looser one
        # first, the >50% branch could never run.
        if missing_pct > 50:
            col_quality["issues"].append("High missing rate")
            report["issues"].append(f"Column '{col}' has {missing_pct:.0f}% missing values")
            report["overall_score"] -= 10
        elif missing_pct > 20:
            col_quality["issues"].append("High missing rate")
            report["overall_score"] -= 5

        if non_null_values:
            # Values are compared by their string form so unhashable values
            # (lists, dicts) can be counted too.
            value_counts = Counter(str(v) for v in non_null_values)
            unique_count = len(value_counts)
            col_quality["unique_pct"] = (unique_count / len(non_null_values)) * 100
            first_value = non_null_values[0]

            if unique_count == 1:
                col_quality["issues"].append("Constant value")
                report["issues"].append(f"Column '{col}' has only one unique value")
                report["overall_score"] -= 5
            elif col_quality["unique_pct"] > 99 and len(non_null_values) > 10:
                col_quality["issues"].append("Possibly ID column")
            elif isinstance(first_value, str) and unique_count > total * 0.8:
                col_quality["issues"].append("High cardinality")

            # Low-cardinality str/int/bool columns are treated as categorical
            # (bool is covered by int).
            if isinstance(first_value, (str, int)) and unique_count <= 20:
                max_count = max(value_counts.values())
                min_count = min(value_counts.values())
                if max_count > min_count * 10:
                    col_quality["issues"].append("Class imbalance")
                    # Only penalise the score for columns that look like targets.
                    if "label" in col.lower() or "class" in col.lower():
                        report["issues"].append(f"Significant class imbalance in '{col}'")
                        report["overall_score"] -= 10
        else:
            col_quality["unique_pct"] = 0

        # The per-column issue list is flattened to a display string.
        col_quality["issues"] = ", ".join(col_quality["issues"]) if col_quality["issues"] else "-"
        report["column_quality"][col] = col_quality

    # Clamp to 0-100; rounding keeps the fractional duplicate penalty from
    # producing a long float.
    report["overall_score"] = round(max(0, min(100, report["overall_score"])), 1)

    if not report["issues"]:
        report["issues"].append("No major issues detected")
    return format_quality_report(report)
|