Spaces:
Running
Running
| """Profiling tools for analyzing dataset statistics and quality.""" | |
import json
import statistics
from typing import Any, Dict, List, Optional

from utils.formatting import format_quality_report, format_statistics
from utils.hf_client import get_client
def get_statistics(
    dataset_id: str,
    config: Optional[str] = None,
    split: str = "train",
    sample_size: int = 1000
) -> str:
    """
    Compute basic statistics for each column in a dataset.

    Use this tool to get a statistical overview of a dataset, including
    counts, means, unique values, and distributions for each column.

    Args:
        dataset_id: The full dataset identifier (e.g., "squad", "imdb")
        config: Optional dataset configuration name. Leave empty for default.
        split: The dataset split to analyze ("train", "test", "validation"). Default: "train"
        sample_size: Number of rows to sample for statistics (100-5000, default: 1000).
            Larger samples are more accurate but slower.

    Returns:
        Formatted statistics including:
        - Total row count (estimated from sample)
        - Per-column statistics:
            - Numeric: min, max, mean, median, std
            - Text: avg length, min/max length, unique count
            - Categorical: value counts, top categories

    Notes:
        - Statistics are computed on a sample for efficiency
        - Very large datasets may show approximate values
        - Binary data columns (images, audio) show type info only
    """
    # Clamp to the documented 100-5000 range.
    sample_size = max(100, min(5000, sample_size))
    client = get_client()

    # Load a sample of rows to compute statistics from.
    samples = client.load_sample(
        dataset_id=dataset_id,
        config=config,
        split=split,
        n_rows=sample_size
    )
    if not samples or "error" in samples[0]:
        error_msg = samples[0].get('error', 'Unknown error') if samples else 'No data'
        return f"Error loading data for statistics: {error_msg}"

    stats: Dict[str, Any] = {
        "total_rows": f"~{len(samples):,}+ (sampled)",
        "column_stats": {}
    }

    # The first row defines the set of columns to analyze.
    columns = samples[0].keys() if samples else []
    for col in columns:
        col_values = [row.get(col) for row in samples if row.get(col) is not None]
        if not col_values:
            stats["column_stats"][col] = {"status": "all null"}
            continue

        # Column type is inferred from the first non-null value.
        sample_val = col_values[0]
        # bool is a subclass of int in Python, so it must be excluded
        # explicitly from the numeric branch.
        if isinstance(sample_val, (int, float)) and not isinstance(sample_val, bool):
            # Numeric column. Exclude stray bools here too so they cannot
            # skew min/max/mean in a mixed-type column.
            numeric_vals = [
                v for v in col_values
                if isinstance(v, (int, float)) and not isinstance(v, bool)
            ]
            if numeric_vals:
                stats["column_stats"][col] = {
                    "type": "numeric",
                    "count": len(numeric_vals),
                    "min": min(numeric_vals),
                    "max": max(numeric_vals),
                    "mean": sum(numeric_vals) / len(numeric_vals),
                    # median/std were promised by the docstring but were
                    # previously never computed.
                    "median": statistics.median(numeric_vals),
                    # Population std; 0.0 for a single value.
                    "std": statistics.pstdev(numeric_vals),
                    "unique": len(set(numeric_vals))
                }
        elif isinstance(sample_val, str):
            # Text column: length distribution plus a few example values.
            lengths = [len(v) for v in col_values if isinstance(v, str)]
            unique_vals = set(col_values)
            stats["column_stats"][col] = {
                "type": "text",
                "count": len(col_values),
                "avg_length": sum(lengths) / len(lengths) if lengths else 0,
                "min_length": min(lengths) if lengths else 0,
                "max_length": max(lengths) if lengths else 0,
                "unique": len(unique_vals),
                "sample_values": list(unique_vals)[:3]
            }
        elif isinstance(sample_val, bool):
            # Boolean column: true/false split.
            true_count = sum(1 for v in col_values if v is True)
            stats["column_stats"][col] = {
                "type": "boolean",
                "count": len(col_values),
                "true_count": true_count,
                "false_count": len(col_values) - true_count,
                "true_pct": (true_count / len(col_values)) * 100
            }
        elif isinstance(sample_val, list):
            # List/sequence column: element-count distribution.
            lengths = [len(v) for v in col_values if isinstance(v, list)]
            stats["column_stats"][col] = {
                "type": "list/sequence",
                "count": len(col_values),
                "avg_length": sum(lengths) / len(lengths) if lengths else 0,
                "min_length": min(lengths) if lengths else 0,
                "max_length": max(lengths) if lengths else 0
            }
        elif isinstance(sample_val, dict):
            # Nested object: show a few keys as a schema hint.
            stats["column_stats"][col] = {
                "type": "object/nested",
                "count": len(col_values),
                "sample_keys": list(sample_val.keys())[:5] if sample_val else []
            }
        else:
            # Binary or other special type: report the type name only.
            stats["column_stats"][col] = {
                "type": str(type(sample_val).__name__),
                "count": len(col_values),
                "note": "Binary/special data type"
            }

    return format_statistics(stats)
def profile_quality(
    dataset_id: str,
    config: Optional[str] = None,
    split: str = "train",
    sample_size: int = 500
) -> str:
    """
    Assess the data quality of a dataset and identify potential issues.

    Use this tool to check for common data quality problems like missing values,
    duplicates, imbalanced classes, and outliers before using a dataset.

    Args:
        dataset_id: The full dataset identifier (e.g., "squad", "imdb")
        config: Optional dataset configuration name. Leave empty for default.
        split: The dataset split to analyze ("train", "test", "validation"). Default: "train"
        sample_size: Number of rows to sample for quality check (100-2000, default: 500).

    Returns:
        Data quality report including:
        - Overall quality score (0-100)
        - List of identified issues
        - Per-column quality metrics:
            - Missing value percentage
            - Unique value percentage
            - Specific issues (constant values, high cardinality, etc.)

    Quality checks performed:
        - Missing/null values
        - Duplicate rows
        - Constant columns (single value)
        - High cardinality text columns
        - Class imbalance for categorical columns
        - Outliers for numeric columns
    """
    # Clamp to the documented 100-2000 range.
    sample_size = max(100, min(2000, sample_size))
    client = get_client()

    # Load a sample of rows to inspect.
    samples = client.load_sample(
        dataset_id=dataset_id,
        config=config,
        split=split,
        n_rows=sample_size
    )
    if not samples or "error" in samples[0]:
        error_msg = samples[0].get('error', 'Unknown error') if samples else 'No data'
        return format_quality_report({"error": error_msg})

    # Start from a perfect score and deduct per detected issue.
    report: Dict[str, Any] = {
        "dataset_id": dataset_id,
        "sample_size": len(samples),
        "issues": [],
        "column_quality": {},
        "overall_score": 100
    }

    # Check for duplicate rows by comparing canonical JSON serializations.
    # Best-effort: if any row cannot be serialized, skip this check.
    try:
        row_strings = [json.dumps(row, sort_keys=True, default=str) for row in samples]
        unique_rows = len(set(row_strings))
        duplicate_pct = ((len(samples) - unique_rows) / len(samples)) * 100
        if duplicate_pct > 5:
            report["issues"].append(f"High duplicate rate: {duplicate_pct:.1f}% duplicate rows")
            report["overall_score"] -= min(20, duplicate_pct)
    except Exception:
        pass  # duplicate detection is optional; never fail the whole report

    # The first row defines the set of columns to analyze.
    columns = samples[0].keys() if samples else []
    for col in columns:
        col_values = [row.get(col) for row in samples]
        # Treat both None and empty string as missing.
        non_null_values = [v for v in col_values if v is not None and v != ""]

        col_quality: Dict[str, Any] = {
            "missing_pct": ((len(samples) - len(non_null_values)) / len(samples)) * 100,
            "issues": []
        }

        # Check missing values. The stricter threshold must come first:
        # the previous ordering (>20 before >50) made the >50% branch
        # unreachable, so severe missingness was never reported.
        if col_quality["missing_pct"] > 50:
            col_quality["issues"].append("High missing rate")
            report["issues"].append(f"Column '{col}' has {col_quality['missing_pct']:.0f}% missing values")
            report["overall_score"] -= 10
        elif col_quality["missing_pct"] > 20:
            col_quality["issues"].append("High missing rate")
            report["overall_score"] -= 5

        # Calculate unique percentage (stringified so unhashable values compare).
        if non_null_values:
            unique_count = len(set(str(v) for v in non_null_values))
            col_quality["unique_pct"] = (unique_count / len(non_null_values)) * 100

            # Check for constant column
            if unique_count == 1:
                col_quality["issues"].append("Constant value")
                report["issues"].append(f"Column '{col}' has only one unique value")
                report["overall_score"] -= 5
            # Check for potential ID column (all unique)
            elif col_quality["unique_pct"] > 99 and len(non_null_values) > 10:
                col_quality["issues"].append("Possibly ID column")
            # Check for high cardinality in small dataset
            elif isinstance(non_null_values[0], str) and unique_count > len(samples) * 0.8:
                col_quality["issues"].append("High cardinality")

            # Check class imbalance for categorical-looking columns
            sample_val = non_null_values[0]
            if isinstance(sample_val, (str, int, bool)) and unique_count <= 20:
                value_counts: Dict[str, int] = {}
                for v in non_null_values:
                    key = str(v)
                    value_counts[key] = value_counts.get(key, 0) + 1
                if value_counts:
                    max_count = max(value_counts.values())
                    min_count = min(value_counts.values())
                    # >10x ratio between the largest and smallest class
                    if max_count > min_count * 10:
                        col_quality["issues"].append("Class imbalance")
                        # Only penalize the score for label-like columns.
                        if "label" in col.lower() or "class" in col.lower():
                            report["issues"].append(f"Significant class imbalance in '{col}'")
                            report["overall_score"] -= 10
        else:
            col_quality["unique_pct"] = 0

        # Flatten the per-column issue list into display text.
        col_quality["issues"] = ", ".join(col_quality["issues"]) if col_quality["issues"] else "-"
        report["column_quality"][col] = col_quality

    # Clamp score to the documented 0-100 range.
    report["overall_score"] = max(0, min(100, report["overall_score"]))
    if not report["issues"]:
        report["issues"].append("No major issues detected")

    return format_quality_report(report)