Spaces:
Sleeping
Sleeping
| """ | |
| Tools for CrewAI agents to interact with NBA data. | |
| """ | |
| import pandas as pd | |
| from crewai.tools import tool | |
| from typing import Optional | |
| from vector_db import get_vector_db | |
def get_agent_tools(data_path: str):
    """
    Build the list of CrewAI tools agents use to interact with NBA data.

    Args:
        data_path: Path to the CSV data file every tool reads from.

    Returns:
        list: Tool-wrapped callables ready to hand to CrewAI agents.
    """
    # Plain helper functions hold the real logic; the @tool wrappers below
    # carry the agent-facing docstrings (which the LLM reads) and clamp
    # user-supplied limits.

    def _read_nba_data(limit: int = 10) -> str:
        """Return dataset shape, column names, and the first `limit` rows."""
        try:
            # Read the whole file but render only a sample to stay within
            # LLM token limits.
            df = pd.read_csv(data_path)
            sample = df.head(limit)
            return f"Dataset: {len(df):,} total records, {len(df.columns)} columns\n\nColumn names: {', '.join(df.columns.tolist())}\n\nSample (first {limit} rows):\n\n{sample.to_string()}"
        except Exception as e:
            return f"Error reading file {data_path}: {str(e)}"

    def _search_nba_data(
        query: Optional[str] = None,
        column: Optional[str] = None,
        value: Optional[str] = None,
        limit: int = 100
    ) -> str:
        """Search and filter the NBA CSV; output is capped for token safety."""
        try:
            df = pd.read_csv(data_path)

            # Column-specific filter: case-insensitive substring match.
            if column and value:
                if column in df.columns:
                    df = df[df[column].astype(str).str.contains(str(value), case=False, na=False)]
                else:
                    return f"Column '{column}' not found. Available columns: {', '.join(df.columns.tolist())}"

            if query:
                # Free-text search across all string (object) columns.
                # BUGFIX: the mask must share df's index. After the column
                # filter above, df's index is no longer 0..n-1, so a
                # RangeIndex mask (pd.Series([False] * len(df))) fails to
                # align in df[mask] and raises an IndexingError.
                mask = pd.Series(False, index=df.index)
                for col in df.columns:
                    if df[col].dtype == 'object':
                        mask |= df[col].astype(str).str.contains(query, case=False, na=False)
                df = df[mask]

            # Hard cap at 50 rows regardless of the requested limit, to
            # prevent token overflow in agent responses.
            limit = min(limit, 50)
            df = df.head(limit)

            if len(df) == 0:
                return "No matching records found."

            # Truncate output if the rendered table is still too large.
            result_str = df.to_string()
            if len(result_str) > 2000:
                result_str = df.head(20).to_string() + f"\n\n... (showing first 20 of {len(df)} matching records)"
            return f"Found {len(df)} matching records:\n\n{result_str}"
        except Exception as e:
            return f"Error searching CSV {data_path}: {str(e)}"

    def _get_nba_data_summary() -> str:
        """Return a concise structural summary of the NBA data file."""
        try:
            df = pd.read_csv(data_path)
            # Keep the summary short: it is meant for LLM context, not humans.
            numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
            # NOTE(review): 'Data' below looks like a typo for 'Date'; kept
            # as-is because the `in df.columns` guard makes a missing column
            # harmless ('N/A') — confirm against the actual CSV schema.
            summary = f"""NBA Dataset Summary:
- Total Records: {len(df):,}
- Columns: {len(df.columns)} ({', '.join(df.columns.tolist()[:10])}{'...' if len(df.columns) > 10 else ''})
- Unique Players: {df['Player'].nunique() if 'Player' in df.columns else 'N/A'}
- Unique Teams: {df['Tm'].nunique() if 'Tm' in df.columns else 'N/A'}
- Date Range: {df['Data'].min() if 'Data' in df.columns else 'N/A'} to {df['Data'].max() if 'Data' in df.columns else 'N/A'}
- Key Numeric Columns: {', '.join(numeric_cols[:10]) if numeric_cols else 'None'}
Sample (first 3 rows):
{df.head(3).to_string()}
"""
            return summary
        except Exception as e:
            return f"Error getting CSV summary for {data_path}: {str(e)}"

    # Agent-facing wrappers. The @tool decorators (imported at the top of
    # the file) register these callables with CrewAI; their docstrings are
    # what the agent sees when deciding which tool to call.
    @tool("Read NBA Data")
    def read_nba_data(limit: int = 10) -> str:
        """
        Read a sample of the NBA data file to understand its structure.
        Use this to see column names and data format, NOT for full analysis.

        Args:
            limit: Number of sample rows to return (default: 10, max: 50)
        """
        limit = min(limit, 50)  # Cap at 50 rows
        return _read_nba_data(limit)

    @tool("Search NBA Data")
    def search_nba_data(
        query: Optional[str] = None,
        column: Optional[str] = None,
        value: Optional[str] = None,
        limit: int = 100
    ) -> str:
        """
        Search and filter NBA data CSV file. Use this to find specific players, teams, or statistics.

        Args:
            query: Optional text query to search for in any column (e.g., player name, team name)
            column: Optional column name to filter by (e.g., 'Player', 'Tm', 'PTS')
            value: Optional value to match in the specified column
            limit: Maximum number of rows to return (default: 100, capped at 50)
        """
        return _search_nba_data(query, column, value, limit)

    @tool("Get NBA Data Summary")
    def get_nba_data_summary() -> str:
        """
        Get a comprehensive summary of the NBA data file including structure, basic statistics,
        and data quality information. Use this first to understand the dataset.
        """
        return _get_nba_data_summary()

    def _semantic_search_nba_data(query: str, n_results: int = 10) -> str:
        """
        Perform semantic search on NBA data using vector embeddings.
        This understands natural language queries and finds semantically similar records.
        """
        try:
            # Get vector database instance (built lazily from the CSV).
            vector_db = get_vector_db(data_path)
            results = vector_db.search(query, n_results=n_results)

            if not results:
                return f"No results found for query: '{query}'"

            # Format results with a header followed by one section per hit.
            output = [f"Semantic search results for: '{query}'\n"]
            output.append(f"Found {len(results)} similar records:\n")
            output.append("=" * 80 + "\n")

            # Re-load the original CSV so each hit can show its full row.
            df = pd.read_csv(data_path)

            for i, result in enumerate(results, 1):
                metadata = result['metadata']
                similarity = result['similarity']
                row_index = metadata.get('row_index', -1)

                output.append(f"\nResult {i} (Similarity: {similarity:.3f}):")
                output.append(f"Document: {result['document']}\n")

                # Attach the full record only when the stored row index is
                # still valid for the current CSV contents.
                if 0 <= row_index < len(df):
                    row = df.iloc[row_index]
                    output.append("Full record:")
                    output.append(row.to_string())
                output.append("\n" + "-" * 80 + "\n")

            return "\n".join(output)
        except Exception as e:
            return f"Error performing semantic search: {str(e)}"

    @tool("Semantic Search NBA Data")
    def semantic_search_nba_data(query: str, n_results: int = 10) -> str:
        """
        Perform semantic search on NBA data using vector embeddings and natural language understanding.
        This tool understands the meaning of your query, not just exact text matches.

        Use this for natural language questions like:
        - "high scoring games"
        - "LeBron James best performances"
        - "games with many assists"
        - "efficient shooters"
        - "close games"

        Args:
            query: Natural language query describing what you're looking for
            n_results: Number of results to return (default: 10, max: 50)

        Examples:
            semantic_search_nba_data("LeBron James high scoring games")
            semantic_search_nba_data("games with triple doubles", n_results=5)
            semantic_search_nba_data("most efficient three point shooters")
        """
        # Limit n_results to prevent overwhelming output
        n_results = min(n_results, 50)
        return _semantic_search_nba_data(query, n_results)

    def _analyze_nba_data(pandas_code: str) -> str:
        """
        Execute pandas operations on NBA data for advanced analysis.
        The pandas code must be an expression over a DataFrame named 'df',
        e.g. df.groupby('Player')['PTS'].sum().nlargest(5).
        """
        try:
            df = pd.read_csv(data_path)

            # SECURITY: exec() of LLM-generated code with full __builtins__
            # is arbitrary code execution. This is tolerable only because
            # the code originates from our own agents; do NOT expose this
            # tool to untrusted end users without real sandboxing.
            namespace = {
                'pd': pd,
                'df': df,
                '__builtins__': __builtins__
            }
            exec(f"result = {pandas_code}", namespace)
            result = namespace.get('result')

            # Render the result with size caps to avoid blowing token limits.
            if isinstance(result, pd.DataFrame):
                if len(result) > 50:
                    result_str = f"{result.head(50).to_string()}\n\n... (showing first 50 of {len(result)} rows)"
                else:
                    result_str = result.to_string()
                return f"Analysis Result ({result.shape[0]} rows, {result.shape[1]} cols):\n\n{result_str}"
            elif isinstance(result, pd.Series):
                if len(result) > 50:
                    result_str = f"{result.head(50).to_string()}\n\n... (showing first 50 of {len(result)} items)"
                else:
                    result_str = result.to_string()
                return f"Analysis Result ({len(result)} items):\n\n{result_str}"
            else:
                # Scalars and other objects: cap the string length instead.
                result_str = str(result)
                if len(result_str) > 2000:
                    result_str = result_str[:2000] + "\n\n... (truncated)"
                return f"Analysis Result:\n\n{result_str}"
        except Exception as e:
            return f"Error executing pandas code: {str(e)}\n\nMake sure your code uses 'df' as the DataFrame variable and returns a result."

    @tool("Analyze NBA Data")
    def analyze_nba_data(pandas_code: str) -> str:
        """
        Execute pandas operations on NBA data for advanced analysis, aggregations, and statistical queries.

        This is the PRIMARY tool for data analysis tasks like:
        - Finding top players by statistics (groupby + aggregation + sorting)
        - Calculating totals, averages, counts per player/team
        - Filtering and aggregating data
        - Statistical analysis

        IMPORTANT: Use this tool for queries that require:
        - Aggregating data (sum, mean, count, etc.)
        - Grouping by player, team, etc.
        - Finding top N results
        - Calculating totals or averages

        Args:
            pandas_code: Valid pandas code that operates on a DataFrame variable named 'df'
                The code should return a result (DataFrame, Series, or value)

        Examples:
            # Top 5 players by total 3-pointers made
            analyze_nba_data("df.groupby('Player')['3P'].sum().sort_values(ascending=False).head(5)")
            # Top 10 players by total points
            analyze_nba_data("df.groupby('Player')['PTS'].sum().sort_values(ascending=False).head(10)")
            # Players with highest 3-point percentage (minimum 100 attempts)
            analyze_nba_data("df[df['3PA'] >= 100].groupby('Player').agg({'3P': 'sum', '3PA': 'sum'}).assign(percentage=lambda x: x['3P']/x['3PA']*100).sort_values('percentage', ascending=False).head(5)")
            # Top 5 players by assists
            analyze_nba_data("df.groupby('Player')['AST'].sum().sort_values(ascending=False).head(5)")
            # Team win rates
            analyze_nba_data("df.groupby('Tm')['Res'].apply(lambda x: (x == 'W').sum() / len(x) * 100).sort_values(ascending=False)")
        """
        return _analyze_nba_data(pandas_code)

    return [read_nba_data, search_nba_data, get_nba_data_summary, semantic_search_nba_data, analyze_nba_data]