"""Profiling tools for analyzing dataset statistics and quality."""

import json
import math
import statistics
from collections import Counter
from typing import Optional, Dict, Any, List, Tuple

from utils.hf_client import get_client
from utils.formatting import format_statistics, format_quality_report


# Columns with at most this many distinct values are checked for class imbalance.
_MAX_CATEGORIES = 20
# A column counts as imbalanced when its most common value is more than this
# many times as frequent as its least common value.
_IMBALANCE_RATIO = 10
# A numeric column is flagged for outliers when more than this fraction of its
# values falls outside the 1.5 * IQR fences.
_OUTLIER_FRACTION = 0.05


def _load_samples(
    dataset_id: str, config: Optional[str], split: str, n_rows: int
) -> Tuple[List[Dict[str, Any]], Any]:
    """Fetch up to ``n_rows`` rows through the shared client.

    Returns:
        ``(samples, None)`` on success (``samples`` is non-empty). On failure,
        ``([], error)`` where ``error`` is the client's reported error value,
        or ``"No data"`` when nothing was returned.
    """
    samples = get_client().load_sample(
        dataset_id=dataset_id, config=config, split=split, n_rows=n_rows
    )
    if not samples:
        return [], "No data"
    # NOTE(review): the client appears to report failures as a first row that
    # carries an "error" key. A genuine dataset column named "error" would
    # also match this check -- confirm against utils.hf_client.
    if "error" in samples[0]:
        return [], samples[0].get("error", "Unknown error")
    return samples, None


def _collect_columns(samples: List[Dict[str, Any]]) -> List[Any]:
    """Return every column name seen in any sampled row, in first-seen order."""
    return list(dict.fromkeys(key for row in samples for key in row))


def _is_number(value: Any) -> bool:
    """True for int/float values, excluding bools (which subclass int)."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def _finite_numbers(values: List[Any]) -> List[Any]:
    """Keep numeric, non-bool values, dropping NaN and +/-inf floats."""
    return [
        v for v in values
        if _is_number(v) and not (isinstance(v, float) and not math.isfinite(v))
    ]


def _column_statistics(col_values: List[Any]) -> Dict[str, Any]:
    """Summarize a non-empty list of non-null values from one column.

    The column type is inferred from the first value. Length/numeric metrics
    only consider values of that inferred type; ``count`` is the number of
    non-null values (for numeric columns: the number of finite numbers).
    """
    sample_val = col_values[0]

    if _is_number(sample_val):
        numeric_vals = _finite_numbers(col_values)
        if not numeric_vals:
            return {"type": "numeric", "count": 0, "note": "No finite numeric values"}
        return {
            "type": "numeric",
            "count": len(numeric_vals),
            "min": min(numeric_vals),
            "max": max(numeric_vals),
            "mean": sum(numeric_vals) / len(numeric_vals),
            "median": statistics.median(numeric_vals),
            # Sample standard deviation; undefined for a single value, so 0.0.
            "std": statistics.stdev(numeric_vals) if len(numeric_vals) > 1 else 0.0,
            "unique": len(set(numeric_vals)),
        }

    if isinstance(sample_val, str):
        # Restrict to strings: mixed columns may hold unhashable lists/dicts.
        text_vals = [v for v in col_values if isinstance(v, str)]
        lengths = [len(v) for v in text_vals]
        # dict.fromkeys keeps first-seen order, so sample_values are stable
        # across runs (set order of str varies with hash randomization).
        unique_vals = list(dict.fromkeys(text_vals))
        return {
            "type": "text",
            "count": len(col_values),
            "avg_length": sum(lengths) / len(lengths),
            "min_length": min(lengths),
            "max_length": max(lengths),
            "unique": len(unique_vals),
            "sample_values": unique_vals[:3],
        }

    if isinstance(sample_val, bool):
        true_count = sum(1 for v in col_values if v is True)
        return {
            "type": "boolean",
            "count": len(col_values),
            "true_count": true_count,
            "false_count": len(col_values) - true_count,
            "true_pct": (true_count / len(col_values)) * 100,
        }

    if isinstance(sample_val, list):
        lengths = [len(v) for v in col_values if isinstance(v, list)]
        return {
            "type": "list/sequence",
            "count": len(col_values),
            "avg_length": sum(lengths) / len(lengths),
            "min_length": min(lengths),
            "max_length": max(lengths),
        }

    if isinstance(sample_val, dict):
        return {
            "type": "object/nested",
            "count": len(col_values),
            "sample_keys": list(sample_val.keys())[:5],
        }

    # Binary or other type: report the type name only.
    return {
        "type": str(type(sample_val).__name__),
        "count": len(col_values),
        "note": "Binary/special data type",
    }


def get_statistics(
    dataset_id: str,
    config: Optional[str] = None,
    split: str = "train",
    sample_size: int = 1000
) -> str:
    """
    Compute basic statistics for each column in a dataset.

    Use this tool to get a statistical overview of a dataset, including
    counts, means, unique values, and distributions for each column.

    Args:
        dataset_id: The full dataset identifier (e.g., "squad", "imdb")
        config: Optional dataset configuration name. Leave empty for default.
        split: The dataset split to analyze ("train", "test", "validation"). Default: "train"
        sample_size: Number of rows to sample for statistics (100-5000, default: 1000).
                    Larger samples are more accurate but slower.

    Returns:
        Formatted statistics including:
        - Number of rows sampled (a lower bound on the split's size)
        - Per-column statistics:
          - Numeric: min, max, mean, median, std, unique count
          - Text: avg length, min/max length, unique count, sample values
          - Boolean: true/false counts and true percentage
          - List/sequence: avg/min/max length
          - Nested object: sample keys
        On failure, an "Error loading data for statistics: ..." message.

    Notes:
        - Statistics are computed on a sample for efficiency
        - Very large datasets may show approximate values
        - A column's type is inferred from its first non-null value
        - NaN/infinite floats are excluded from numeric statistics
        - Binary data columns (images, audio) show type info only
    """
    sample_size = max(100, min(5000, sample_size))

    samples, error_msg = _load_samples(dataset_id, config, split, sample_size)
    if not samples:
        return f"Error loading data for statistics: {error_msg}"

    stats: Dict[str, Any] = {
        "total_rows": f"~{len(samples):,}+ (sampled)",
        "column_stats": {},
    }

    for col in _collect_columns(samples):
        col_values = [row.get(col) for row in samples if row.get(col) is not None]
        stats["column_stats"][col] = (
            _column_statistics(col_values) if col_values else {"status": "all null"}
        )

    return format_statistics(stats)


def _duplicate_row_pct(samples: List[Dict[str, Any]]) -> Optional[float]:
    """Percentage of rows that exactly duplicate an earlier row.

    Rows are compared by a canonical JSON encoding. Returns None when a row
    cannot be encoded (e.g. nested dicts whose keys are not mutually
    comparable under ``sort_keys``), so the check is skipped rather than
    failing the whole report.
    """
    try:
        distinct = {json.dumps(row, sort_keys=True, default=str) for row in samples}
    except (TypeError, ValueError, RecursionError):
        return None
    return ((len(samples) - len(distinct)) / len(samples)) * 100


def _has_outliers(values: List[Any]) -> bool:
    """True when more than _OUTLIER_FRACTION of values lie outside the 1.5 * IQR fences."""
    if len(values) < 4:
        return False
    q1, _, q3 = statistics.quantiles(values, n=4)
    iqr = q3 - q1
    if iqr <= 0:
        # The middle half of the values is identical; the IQR rule is degenerate.
        return False
    low, high = q1 - 1.5 * iqr, q3 + 1.5 * iqr
    outliers = sum(1 for v in values if v < low or v > high)
    return outliers / len(values) > _OUTLIER_FRACTION


def _assess_column(col: Any, col_values: List[Any]) -> Tuple[Dict[str, Any], List[str], int]:
    """Run the per-column quality checks.

    Args:
        col: Column name (used in messages and the label/class heuristic).
        col_values: The column's value from every sampled row (None if absent).

    Returns:
        ``(column_quality, report_issues, penalty)``: the per-column metrics
        dict, dataset-level issue messages, and points to deduct from the
        overall score.
    """
    n_rows = len(col_values)
    # Empty strings are treated as missing alongside None.
    non_null = [v for v in col_values if v is not None and v != ""]
    missing_pct = ((n_rows - len(non_null)) / n_rows) * 100
    issues: List[str] = []
    report_issues: List[str] = []
    penalty = 0

    # Test the stricter threshold first; otherwise the >50% case never fires.
    if missing_pct > 50:
        issues.append("High missing rate")
        report_issues.append(f"Column '{col}' has {missing_pct:.0f}% missing values")
        penalty += 10
    elif missing_pct > 20:
        issues.append("High missing rate")
        penalty += 5

    unique_pct: float = 0
    if non_null:
        unique_count = len({str(v) for v in non_null})
        unique_pct = (unique_count / len(non_null)) * 100
        first = non_null[0]

        if unique_count == 1:
            issues.append("Constant value")
            report_issues.append(f"Column '{col}' has only one unique value")
            penalty += 5
        elif unique_pct > 99 and len(non_null) > 10:
            issues.append("Possibly ID column")
        elif isinstance(first, str) and unique_count > n_rows * 0.8:
            issues.append("High cardinality")

        # Low-cardinality columns are treated as categorical.
        if isinstance(first, (str, int, bool)) and unique_count <= _MAX_CATEGORIES:
            counts = Counter(str(v) for v in non_null)
            if max(counts.values()) > min(counts.values()) * _IMBALANCE_RATIO:
                issues.append("Class imbalance")
                lowered = str(col).lower()
                # Imbalance only affects the score for likely target columns.
                if "label" in lowered or "class" in lowered:
                    report_issues.append(f"Significant class imbalance in '{col}'")
                    penalty += 10

        if _is_number(first) and _has_outliers(_finite_numbers(non_null)):
            issues.append("Outliers")

    column_quality = {
        "missing_pct": missing_pct,
        "issues": ", ".join(issues) if issues else "-",
        "unique_pct": unique_pct,
    }
    return column_quality, report_issues, penalty


def profile_quality(
    dataset_id: str,
    config: Optional[str] = None,
    split: str = "train",
    sample_size: int = 500
) -> str:
    """
    Assess the data quality of a dataset and identify potential issues.

    Use this tool to check for common data quality problems like missing values,
    duplicates, imbalanced classes, and outliers before using a dataset.

    Args:
        dataset_id: The full dataset identifier (e.g., "squad", "imdb")
        config: Optional dataset configuration name. Leave empty for default.
        split: The dataset split to analyze ("train", "test", "validation"). Default: "train"
        sample_size: Number of rows to sample for quality check (100-2000, default: 500).

    Returns:
        Data quality report including:
        - Overall quality score (0-100)
        - List of identified issues
        - Per-column quality metrics:
          - Missing value percentage (None and "" count as missing)
          - Unique value percentage
          - Specific issues (constant values, high cardinality, etc.)

    Quality checks performed:
        - Missing/null values (>20%: -5 points; >50%: -10 points and a report issue)
        - Duplicate rows (>5% duplicates lowers the score by up to 20 points)
        - Constant columns (single value)
        - High cardinality text columns
        - Class imbalance for categorical columns (<= 20 distinct values)
        - Outliers for numeric columns (1.5 * IQR rule; flagged per column only)
    """
    sample_size = max(100, min(2000, sample_size))

    samples, error_msg = _load_samples(dataset_id, config, split, sample_size)
    if not samples:
        return format_quality_report({"error": error_msg})

    report: Dict[str, Any] = {
        "dataset_id": dataset_id,
        "sample_size": len(samples),
        "issues": [],
        "column_quality": {},
        "overall_score": 100,
    }

    duplicate_pct = _duplicate_row_pct(samples)
    if duplicate_pct is not None and duplicate_pct > 5:
        report["issues"].append(f"High duplicate rate: {duplicate_pct:.1f}% duplicate rows")
        report["overall_score"] -= min(20, duplicate_pct)

    for col in _collect_columns(samples):
        column_quality, column_issues, penalty = _assess_column(
            col, [row.get(col) for row in samples]
        )
        report["column_quality"][col] = column_quality
        report["issues"].extend(column_issues)
        report["overall_score"] -= penalty

    report["overall_score"] = max(0, min(100, report["overall_score"]))

    if not report["issues"]:
        report["issues"].append("No major issues detected")

    return format_quality_report(report)