Spaces:
Running
Running
| """Discovery tools for finding similar datasets and suggesting ML tasks.""" | |
| from typing import Optional, List, Dict, Any | |
| from utils.hf_client import get_client | |
| from utils.formatting import format_similar_datasets, format_task_suggestions, format_comparison | |
# Heuristic signatures of common ML tasks, consumed by suggest_tasks().
# Each key is a task id; the same string is searched for inside the
# dataset's (lower-cased) tags. Each value holds:
#   columns      - column names whose presence hints at the task
#   name         - human-readable task label shown to the user
#   target_hints - column names likely to hold the prediction target
TASK_PATTERNS = {
    "text-classification": dict(
        columns=["text", "label", "sentence", "review", "comment", "content"],
        name="Text Classification",
        target_hints=["label", "class", "category", "sentiment", "target"],
    ),
    "question-answering": dict(
        columns=["question", "answer", "context", "response"],
        name="Question Answering",
        target_hints=["answer", "response"],
    ),
    "summarization": dict(
        columns=["article", "summary", "document", "highlights", "abstract"],
        name="Text Summarization",
        target_hints=["summary", "highlights", "abstract"],
    ),
    "translation": dict(
        columns=["source", "target", "en", "de", "fr", "es", "translation"],
        name="Machine Translation",
        target_hints=["target", "translation"],
    ),
    "image-classification": dict(
        columns=["image", "label", "img", "photo"],
        name="Image Classification",
        target_hints=["label", "class", "category"],
    ),
    "named-entity-recognition": dict(
        columns=["tokens", "ner_tags", "tags", "entities"],
        name="Named Entity Recognition",
        target_hints=["ner_tags", "tags", "entities", "labels"],
    ),
    "token-classification": dict(
        columns=["tokens", "labels", "tags", "pos_tags"],
        name="Token Classification",
        target_hints=["labels", "tags"],
    ),
    "text-generation": dict(
        columns=["prompt", "completion", "input", "output", "instruction"],
        name="Text Generation / Instruction Following",
        target_hints=["completion", "output", "response"],
    ),
    "tabular-classification": dict(
        columns=["target", "label", "class"],
        name="Tabular Classification",
        target_hints=["target", "label", "class", "y"],
    ),
    "tabular-regression": dict(
        columns=["target", "price", "value", "score", "rating"],
        name="Tabular Regression",
        target_hints=["target", "price", "value", "score", "rating"],
    ),
}
def _search_terms_from_tags(tags) -> List[str]:
    """Derive dataset-search queries from a dataset's tags, most informative first.

    ``task_categories:<x>`` values come first, then ``language:<x>`` values,
    then prefix-less tags longer than two characters. Each group is sorted and
    duplicates are dropped, so the order (and therefore which terms survive
    truncation by the caller) is deterministic instead of depending on set
    iteration order. Any other ``prefix:value`` tag, and any prefixed tag with
    an empty value, is ignored.
    """
    task_terms: List[str] = []
    language_terms: List[str] = []
    plain_terms: List[str] = []
    for tag in tags:
        prefix, sep, value = tag.partition(':')
        if not sep:
            if len(tag) > 2:
                plain_terms.append(tag)
        elif not value:
            continue  # e.g. "language:" - nothing to search for
        elif prefix == 'task_categories':
            task_terms.append(value)
        elif prefix == 'language':
            language_terms.append(value)
    ordered = sorted(task_terms) + sorted(language_terms) + sorted(plain_terms)
    # dict.fromkeys keeps first occurrence order while removing repeats.
    return list(dict.fromkeys(ordered))


def find_similar(
    dataset_id: str,
    top_k: int = 5
) -> str:
    """
    Find datasets similar to a given dataset based on tags, domain, and structure.

    Use this tool to discover alternative or complementary datasets for your task.
    Similarity is based on shared tags, similar column structures, and domain overlap.

    Args:
        dataset_id: The dataset to find similar datasets for (e.g., "imdb", "squad")
        top_k: Number of similar datasets to return (clamped to 1-10, default: 5)

    Returns:
        List of similar datasets with:
        - Dataset ID and download count
        - Similarity score (0-1)
        - Reason for similarity (shared tags, similar structure, etc.)
        or an "Error: ..." string if the source dataset's info cannot be fetched.

    How similarity is computed:
        - Candidates come from dataset searches on up to three terms derived
          from the source's tags (task categories first, then languages,
          then plain tags)
        - 0.6 x the fraction of the source's tags the candidate shares
        - 0.4 x the fraction of the source's columns the candidate shares
        - +0.1 when both datasets have the same, non-empty author
        - Capped at 1.0; candidates scoring 0.1 or less are dropped
    """
    top_k = max(1, min(10, top_k))
    client = get_client()

    # Source metadata is required: without it there is nothing to compare against.
    source_info = client.get_dataset_info(dataset_id)
    if "error" in source_info:
        return f"Error: Could not fetch info for dataset '{dataset_id}': {source_info.get('error')}"
    source_tags = set(source_info.get('tags') or [])
    source_author = source_info.get('author')

    # Schema is optional: a failed lookup just disables the column signal.
    source_schema = client.get_schema(dataset_id)
    source_columns = set(source_schema.get('columns') or []) if "error" not in source_schema else set()

    # Gather candidates from the three most informative search terms.
    candidates: List[Dict[str, Any]] = []
    for term in _search_terms_from_tags(source_tags)[:3]:
        candidates.extend(client.search_datasets(term, limit=20))

    # Drop duplicates across searches, entries without an id, and the source itself.
    seen = {dataset_id}
    unique_candidates: List[Dict[str, Any]] = []
    for ds in candidates:
        ds_id = ds.get('id')
        if ds_id and ds_id not in seen:
            seen.add(ds_id)
            unique_candidates.append(ds)

    scored: List[Dict[str, Any]] = []
    for ds in unique_candidates[:30]:  # cap the number of per-candidate API calls
        try:
            ds_info = client.get_dataset_info(ds['id'])
            ds_schema = client.get_schema(ds['id'])
        except Exception:
            # Best effort: one unreachable candidate must not abort the whole search.
            continue

        ds_tags = set(ds_info.get('tags') or [])
        ds_columns = set(ds_schema.get('columns') or []) if "error" not in ds_schema else set()

        shared_tags = source_tags & ds_tags
        shared_columns = source_columns & ds_columns
        tag_score = len(shared_tags) / max(len(source_tags), 1)
        col_score = len(shared_columns) / len(source_columns) if source_columns else 0.0
        similarity = (tag_score * 0.6) + (col_score * 0.4)

        # Samples are sorted so the reported reason is stable across runs.
        reasons = []
        if shared_tags:
            reasons.append(f"Shared tags: {', '.join(sorted(shared_tags)[:3])}")
        if shared_columns:
            reasons.append(f"Similar columns: {', '.join(sorted(shared_columns)[:3])}")
        # Only a known author counts: two datasets that both lack an author
        # (None == None) must not be reported as "Same author".
        if source_author and ds_info.get('author') == source_author:
            reasons.append("Same author")
            similarity += 0.1

        if similarity > 0.1:
            scored.append({
                "id": ds['id'],
                "downloads": ds.get('downloads', 0),
                "similarity_score": min(1.0, similarity),
                "reason": "; ".join(reasons) if reasons else "Related domain"
            })

    # Highest similarity first; sort is stable so ties keep discovery order.
    scored.sort(key=lambda x: x['similarity_score'], reverse=True)
    return format_similar_datasets(scored[:top_k])
def _column_matches(column: str, keyword: str) -> bool:
    """Return True if *keyword* occurs in *column* as a whole word.

    Both arguments are expected to be lower-case. A word is delimited by the
    start/end of the string or by any non-letter (``_``, ``-``, digits, ...),
    and a trailing plural ``s``/``es`` is tolerated. So ``"label"`` matches
    ``"labels"`` and ``"review_label"``, and ``"sentence"`` matches
    ``"sentence1"``. Short keywords no longer fire inside unrelated words:
    ``"es"`` does not match ``"question"``, ``"en"`` does not match
    ``"tokens"``, ``"text"`` does not match ``"context"``, and ``"y"`` does
    not match ``"category"``.
    """
    import re  # local import keeps this helper self-contained

    pattern = rf"(?<![a-z]){re.escape(keyword)}(?:e?s)?(?![a-z])"
    return re.search(pattern, column) is not None


def suggest_tasks(dataset_id: str) -> str:
    """
    Analyze a dataset and suggest suitable machine learning tasks.

    Use this tool to discover what ML tasks a dataset could be used for,
    based on its column structure, data types, and metadata.

    Args:
        dataset_id: The dataset to analyze (e.g., "imdb", "squad", "cnn_dailymail")

    Returns:
        List of suggested ML tasks (at most 5) with:
        - Task name and confidence level (high/medium/low)
        - Reasoning for the suggestion
        - Recommended target column
        - Recommended feature columns
        Column names are reported exactly as they appear in the dataset.

    Task types detected:
        - Text Classification (sentiment, topic, intent)
        - Question Answering
        - Summarization
        - Translation
        - Image Classification
        - Named Entity Recognition
        - Token Classification
        - Text Generation
        - Tabular Classification/Regression

    Columns are matched case-insensitively against TASK_PATTERNS keywords as
    whole words, so short keywords such as "en" or "y" only match columns
    like "en" or "text_en", not "sentence" or "category".
    """
    client = get_client()

    schema = client.get_schema(dataset_id)
    if "error" in schema:
        return format_task_suggestions({"error": f"Could not load schema: {schema['error']}"})

    # Keep the real column names for reporting; match on their lower-case form.
    raw_columns: List[str] = list(schema.get('columns') or [])
    column_pairs = [(col, col.lower()) for col in raw_columns]
    features = schema.get('features', {})

    # Tags are optional; a failed info lookup just disables tag matching.
    info = client.get_dataset_info(dataset_id)
    tags = [t.lower() for t in (info.get('tags') or [])] if "error" not in info else []

    suggestions: List[Dict[str, Any]] = []
    for task_id, pattern in TASK_PATTERNS.items():
        pattern_cols = [p.lower() for p in pattern['columns']]
        matching_cols = [
            raw for raw, low in column_pairs
            if any(_column_matches(low, p) for p in pattern_cols)
        ]
        tag_match = any(task_id in t for t in tags)
        if not (matching_cols or tag_match):
            continue

        # Both signals -> high; one strong signal -> medium; a single column hit -> low.
        if tag_match and len(matching_cols) >= 2:
            confidence = "high"
        elif tag_match or len(matching_cols) >= 2:
            confidence = "medium"
        else:
            confidence = "low"

        # Target: first column (in dataset order) that matches any target hint.
        target_hints = [h.lower() for h in pattern['target_hints']]
        target_col = next(
            (raw for raw, low in column_pairs
             if any(_column_matches(low, h) for h in target_hints)),
            None,
        )
        feature_cols = [c for c in raw_columns if c != target_col][:5]

        reasons = []
        if matching_cols:
            reasons.append(f"Found columns: {', '.join(matching_cols[:3])}")
        if tag_match:
            reasons.append("Dataset tags indicate this task")

        suggestions.append({
            "name": pattern['name'],
            "confidence": confidence,
            "reason": "; ".join(reasons),
            "target_column": target_col,
            "feature_columns": feature_cols
        })

    # Most confident first (sort is stable, so TASK_PATTERNS order breaks ties).
    confidence_order = {"high": 0, "medium": 1, "low": 2}
    suggestions.sort(key=lambda x: confidence_order.get(x['confidence'], 3))

    if not suggestions:
        # Fallback: generic suggestions based only on the declared column types.
        has_text = any('string' in str(features.get(c, '')).lower() for c in raw_columns)
        has_numeric = any(
            'int' in str(features.get(c, '')).lower() or 'float' in str(features.get(c, '')).lower()
            for c in raw_columns
        )
        if has_text:
            suggestions.append({
                "name": "Text Analysis (Generic)",
                "confidence": "low",
                "reason": "Dataset contains text columns",
                "target_column": None,
                "feature_columns": raw_columns[:5]
            })
        if has_numeric:
            suggestions.append({
                "name": "Regression/Classification (Generic)",
                "confidence": "low",
                "reason": "Dataset contains numeric columns",
                "target_column": raw_columns[-1] if raw_columns else None,
                "feature_columns": raw_columns[:-1] if len(raw_columns) > 1 else raw_columns
            })

    return format_task_suggestions({
        "dataset_id": dataset_id,
        "tasks": suggestions[:5]  # Return top 5 suggestions
    })
def compare_datasets(
    dataset_a: str,
    dataset_b: str
) -> str:
    """
    Compare two datasets side-by-side to understand their differences.

    Use this tool when deciding between similar datasets or understanding
    how datasets differ in structure, size, and content.

    Args:
        dataset_a: First dataset ID to compare (e.g., "imdb")
        dataset_b: Second dataset ID to compare (e.g., "rotten_tomatoes")

    Returns:
        Comparison table showing:
        - Download and like counts
        - Number of columns ("N/A" when a schema could not be loaded)
        - License and author
        - Column names (common and unique to each dataset, sorted)
        - Common tags (count and a sample), when any are shared
        - Common columns whose data types differ (count and a sample), when any
        or an "Error loading dataset ..." string if either dataset's info
        cannot be fetched.

    Use cases:
        - Choosing between similar datasets for a task
        - Understanding dataset versions or variants
        - Finding complementary datasets
    """
    client = get_client()

    info_a = client.get_dataset_info(dataset_a)
    info_b = client.get_dataset_info(dataset_b)
    if "error" in info_a:
        return f"Error loading dataset A ({dataset_a}): {info_a.get('error')}"
    if "error" in info_b:
        return f"Error loading dataset B ({dataset_b}): {info_b.get('error')}"

    # Schemas are optional; a failed lookup is shown as "N/A" rather than "0" columns.
    schema_a = client.get_schema(dataset_a)
    schema_b = client.get_schema(dataset_b)
    has_schema_a = "error" not in schema_a
    has_schema_b = "error" not in schema_b
    cols_a = set(schema_a.get('columns') or []) if has_schema_a else set()
    cols_b = set(schema_b.get('columns') or []) if has_schema_b else set()
    common_cols = cols_a & cols_b

    # `or 0` also covers counts that are present but None, which `:,` cannot format.
    downloads_a = info_a.get('downloads') or 0
    downloads_b = info_b.get('downloads') or 0

    comparison = {
        "dataset_a": dataset_a,
        "dataset_b": dataset_b,
        "comparison": {
            "Downloads": {
                "a": f"{downloads_a:,}",
                "b": f"{downloads_b:,}"
            },
            "Likes": {
                "a": str(info_a.get('likes') or 0),
                "b": str(info_b.get('likes') or 0)
            },
            "License": {
                "a": info_a.get('license') or "N/A",
                "b": info_b.get('license') or "N/A"
            },
            "Columns": {
                "a": str(len(cols_a)) if has_schema_a else "N/A",
                "b": str(len(cols_b)) if has_schema_b else "N/A"
            },
            "Author": {
                "a": info_a.get('author') or "N/A",
                "b": info_b.get('author') or "N/A"
            }
        },
        # Sorted so the output is stable across runs (set iteration order is not).
        "common_columns": sorted(common_cols),
        "unique_to_a": sorted(cols_a - cols_b),
        "unique_to_b": sorted(cols_b - cols_a)
    }

    # Shared tags: column "a" carries the count, column "b" a sorted sample.
    tags_a = set(info_a.get('tags') or [])
    tags_b = set(info_b.get('tags') or [])
    common_tags = tags_a & tags_b
    if common_tags:
        comparison["comparison"]["Common Tags"] = {
            "a": str(len(common_tags)),
            "b": ", ".join(sorted(common_tags)[:3])
        }

    # Data types of columns present in both datasets.
    # NOTE(review): assumes schema['features'] maps column name -> feature type;
    # anything that is not a dict is treated as "no type information".
    def _features(schema: Dict[str, Any], available: bool) -> Dict[str, Any]:
        feats = schema.get('features') if available else None
        return feats if isinstance(feats, dict) else {}

    features_a = _features(schema_a, has_schema_a)
    features_b = _features(schema_b, has_schema_b)
    differing_types = sorted(
        col for col in common_cols
        if col in features_a and col in features_b
        and str(features_a[col]) != str(features_b[col])
    )
    if differing_types:
        comparison["comparison"]["Differing Column Types"] = {
            "a": str(len(differing_types)),
            "b": ", ".join(differing_types[:3])
        }

    return format_comparison(comparison)