# NBA_Analysis/tools.py
"""
Tools for CrewAI agents to interact with NBA data.
"""
import pandas as pd
from crewai.tools import tool
from typing import Optional
from vector_db import get_vector_db
def get_agent_tools(data_path: str):
    """
    Build the list of CrewAI tools available to agents.

    Args:
        data_path: Path to the CSV data file the tools operate on.

    Returns:
        list: Tool-decorated callables for agents to use.
    """
    # Define helper functions first, then wrap them with @tool. Keeping the
    # logic in plain functions makes it callable/testable without the decorator.

    def _read_nba_data(limit: int = 10) -> str:
        """Read a sample of the NBA data file to understand its structure."""
        try:
            # Read the full CSV but return only a sample to avoid token limits.
            df = pd.read_csv(data_path)
            sample = df.head(limit)
            return (
                f"Dataset: {len(df):,} total records, {len(df.columns)} columns\n\n"
                f"Column names: {', '.join(df.columns.tolist())}\n\n"
                f"Sample (first {limit} rows):\n\n{sample.to_string()}"
            )
        except Exception as e:
            return f"Error reading file {data_path}: {str(e)}"

    def _search_nba_data(
        query: Optional[str] = None,
        column: Optional[str] = None,
        value: Optional[str] = None,
        limit: int = 100
    ) -> str:
        """Search and filter NBA data CSV file."""
        try:
            df = pd.read_csv(data_path)
            # Apply a column/value substring filter if provided.
            if column and value:
                if column in df.columns:
                    df = df[df[column].astype(str).str.contains(str(value), case=False, na=False)]
                else:
                    return f"Column '{column}' not found. Available columns: {', '.join(df.columns.tolist())}"
            if query:
                # Free-text search across all string (object-dtype) columns.
                # BUG FIX: build the mask on df's *current* index. The original
                # pd.Series([False] * len(df)) carried a fresh RangeIndex, which
                # misaligns (and raises on df[mask]) after the column filter
                # above has already subset the rows.
                mask = pd.Series(False, index=df.index)
                for col in df.columns:
                    if df[col].dtype == 'object':
                        mask |= df[col].astype(str).str.contains(query, case=False, na=False)
                df = df[mask]
            # Capture the true match count BEFORE truncating, so the report is
            # accurate even when results are capped.
            total_matches = len(df)
            # Limit results to prevent token overflow.
            limit = min(limit, 50)  # Cap at 50 rows
            df = df.head(limit)
            if len(df) == 0:
                return "No matching records found."
            # Truncate output if too large.
            result_str = df.to_string()
            if len(result_str) > 2000:
                result_str = df.head(20).to_string() + f"\n\n... (showing first 20 of {total_matches} matching records)"
            return f"Found {total_matches} matching records:\n\n{result_str}"
        except Exception as e:
            return f"Error searching CSV {data_path}: {str(e)}"

    def _get_nba_data_summary() -> str:
        """Get a concise summary of the NBA data file."""
        try:
            df = pd.read_csv(data_path)
            # Calculate basic stats - keep it concise to stay within token limits.
            numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
            # NOTE(review): 'Data' below looks like it may be a typo for 'Date'
            # in the dataset's column naming — confirm against the CSV. The
            # membership checks keep this safe either way.
            summary = f"""NBA Dataset Summary:
- Total Records: {len(df):,}
- Columns: {len(df.columns)} ({', '.join(df.columns.tolist()[:10])}{'...' if len(df.columns) > 10 else ''})
- Unique Players: {df['Player'].nunique() if 'Player' in df.columns else 'N/A'}
- Unique Teams: {df['Tm'].nunique() if 'Tm' in df.columns else 'N/A'}
- Date Range: {df['Data'].min() if 'Data' in df.columns else 'N/A'} to {df['Data'].max() if 'Data' in df.columns else 'N/A'}
- Key Numeric Columns: {', '.join(numeric_cols[:10]) if numeric_cols else 'None'}

Sample (first 3 rows):
{df.head(3).to_string()}
"""
            return summary
        except Exception as e:
            return f"Error getting CSV summary for {data_path}: {str(e)}"

    def _semantic_search_nba_data(query: str, n_results: int = 10) -> str:
        """
        Perform semantic search on NBA data using vector embeddings.
        This understands natural language queries and finds semantically similar records.
        """
        try:
            # Get vector database instance (backed by the same CSV).
            vector_db = get_vector_db(data_path)
            # Perform semantic search.
            results = vector_db.search(query, n_results=n_results)
            if not results:
                return f"No results found for query: '{query}'"
            # Format results.
            output = [f"Semantic search results for: '{query}'\n"]
            output.append(f"Found {len(results)} similar records:\n")
            output.append("=" * 80 + "\n")
            # Load original CSV to enrich each hit with its full row data.
            df = pd.read_csv(data_path)
            for i, result in enumerate(results, 1):
                metadata = result['metadata']
                similarity = result['similarity']
                row_index = metadata.get('row_index', -1)
                output.append(f"\nResult {i} (Similarity: {similarity:.3f}):")
                output.append(f"Document: {result['document']}\n")
                # Get full row data if the stored index is valid.
                if row_index >= 0 and row_index < len(df):
                    row = df.iloc[row_index]
                    output.append("Full record:")
                    output.append(row.to_string())
                output.append("\n" + "-" * 80 + "\n")
            return "\n".join(output)
        except Exception as e:
            return f"Error performing semantic search: {str(e)}"

    def _analyze_nba_data(pandas_code: str) -> str:
        """
        Execute pandas operations on NBA data for advanced analysis.

        The supplied code must be a single expression operating on a DataFrame
        named 'df', e.g.:
        - df.groupby('Player')['3P'].sum().sort_values(ascending=False).head(5)
        - df.groupby('Player').agg({'PTS': 'sum', 'AST': 'sum'}).sort_values('PTS', ascending=False)
        - df[df['3P'] > 5].groupby('Player')['3P'].sum().nlargest(5)
        """
        try:
            # Load the CSV data.
            df = pd.read_csv(data_path)
            # SECURITY(review): exec() here runs LLM-generated code with full
            # __builtins__ — it is NOT a sandbox. Anything the process can do,
            # this code can do. Acceptable only if agents are trusted; consider
            # a restricted evaluator (e.g. pandas.eval or an allow-list) before
            # exposing this to untrusted input.
            namespace = {
                'pd': pd,
                'df': df,
                '__builtins__': __builtins__
            }
            # Execute the expression and capture its value.
            exec(f"result = {pandas_code}", namespace)
            result = namespace.get('result')
            # Convert result to string representation - limit size to avoid token limits.
            if isinstance(result, pd.DataFrame):
                # Limit DataFrame output to prevent token overflow.
                if len(result) > 50:
                    result_str = f"{result.head(50).to_string()}\n\n... (showing first 50 of {len(result)} rows)"
                else:
                    result_str = result.to_string()
                return f"Analysis Result ({result.shape[0]} rows, {result.shape[1]} cols):\n\n{result_str}"
            elif isinstance(result, pd.Series):
                # Limit Series output.
                if len(result) > 50:
                    result_str = f"{result.head(50).to_string()}\n\n... (showing first 50 of {len(result)} items)"
                else:
                    result_str = result.to_string()
                return f"Analysis Result ({len(result)} items):\n\n{result_str}"
            else:
                # For scalars and other types, limit string length.
                result_str = str(result)
                if len(result_str) > 2000:
                    result_str = result_str[:2000] + "\n\n... (truncated)"
                return f"Analysis Result:\n\n{result_str}"
        except Exception as e:
            return f"Error executing pandas code: {str(e)}\n\nMake sure your code uses 'df' as the DataFrame variable and returns a result."

    # Now wrap the helpers with the @tool decorator.

    @tool("read_nba_data")
    def read_nba_data(limit: int = 10) -> str:
        """
        Read a sample of the NBA data file to understand its structure.
        Use this to see column names and data format, NOT for full analysis.

        Args:
            limit: Number of sample rows to return (default: 10, max: 50)
        """
        limit = min(limit, 50)  # Cap at 50 rows
        return _read_nba_data(limit)

    @tool("search_nba_data")
    def search_nba_data(
        query: Optional[str] = None,
        column: Optional[str] = None,
        value: Optional[str] = None,
        limit: int = 100
    ) -> str:
        """
        Search and filter NBA data CSV file. Use this to find specific players, teams, or statistics.

        Args:
            query: Optional text query to search for in any column (e.g., player name, team name)
            column: Optional column name to filter by (e.g., 'Player', 'Tm', 'PTS')
            value: Optional value to match in the specified column
            limit: Maximum number of rows to return (default: 100)
        """
        return _search_nba_data(query, column, value, limit)

    @tool("get_nba_data_summary")
    def get_nba_data_summary() -> str:
        """
        Get a comprehensive summary of the NBA data file including structure, basic statistics,
        and data quality information. Use this first to understand the dataset.
        """
        return _get_nba_data_summary()

    @tool("semantic_search_nba_data")
    def semantic_search_nba_data(query: str, n_results: int = 10) -> str:
        """
        Perform semantic search on NBA data using vector embeddings and natural language understanding.
        This tool understands the meaning of your query, not just exact text matches.

        Use this for natural language questions like:
        - "high scoring games"
        - "LeBron James best performances"
        - "games with many assists"
        - "efficient shooters"
        - "close games"

        Args:
            query: Natural language query describing what you're looking for
            n_results: Number of results to return (default: 10, max: 50)

        Examples:
            semantic_search_nba_data("LeBron James high scoring games")
            semantic_search_nba_data("games with triple doubles", n_results=5)
            semantic_search_nba_data("most efficient three point shooters")
        """
        # Limit n_results to prevent overwhelming output.
        n_results = min(n_results, 50)
        return _semantic_search_nba_data(query, n_results)

    @tool("analyze_nba_data")
    def analyze_nba_data(pandas_code: str) -> str:
        """
        Execute pandas operations on NBA data for advanced analysis, aggregations, and statistical queries.

        This is the PRIMARY tool for data analysis tasks like:
        - Finding top players by statistics (groupby + aggregation + sorting)
        - Calculating totals, averages, counts per player/team
        - Filtering and aggregating data
        - Statistical analysis

        IMPORTANT: Use this tool for queries that require:
        - Aggregating data (sum, mean, count, etc.)
        - Grouping by player, team, etc.
        - Finding top N results
        - Calculating totals or averages

        Args:
            pandas_code: Valid pandas code that operates on a DataFrame variable named 'df'
                The code should return a result (DataFrame, Series, or value)

        Examples:
            # Top 5 players by total 3-pointers made
            analyze_nba_data("df.groupby('Player')['3P'].sum().sort_values(ascending=False).head(5)")

            # Top 10 players by total points
            analyze_nba_data("df.groupby('Player')['PTS'].sum().sort_values(ascending=False).head(10)")

            # Players with highest 3-point percentage (minimum 100 attempts)
            analyze_nba_data("df[df['3PA'] >= 100].groupby('Player').agg({'3P': 'sum', '3PA': 'sum'}).assign(percentage=lambda x: x['3P']/x['3PA']*100).sort_values('percentage', ascending=False).head(5)")

            # Top 5 players by assists
            analyze_nba_data("df.groupby('Player')['AST'].sum().sort_values(ascending=False).head(5)")

            # Team win rates
            analyze_nba_data("df.groupby('Tm')['Res'].apply(lambda x: (x == 'W').sum() / len(x) * 100).sort_values(ascending=False)")
        """
        return _analyze_nba_data(pandas_code)

    return [read_nba_data, search_nba_data, get_nba_data_summary, semantic_search_nba_data, analyze_nba_data]