Spaces:
Running
Running
Commit
·
ddabbe4
1
Parent(s):
2796eed
Add NBA analysis project files
Browse files- .DS_Store +0 -0
- .gitattributes +1 -0
- __pycache__/agents.cpython-311.pyc +0 -0
- __pycache__/config.cpython-311.pyc +0 -0
- __pycache__/crew.cpython-311.pyc +0 -0
- __pycache__/data_service.cpython-311.pyc +0 -0
- __pycache__/gradio_app.cpython-311.pyc +0 -0
- __pycache__/pandas_query_generator.cpython-311.pyc +0 -0
- __pycache__/tasks.cpython-311.pyc +0 -0
- __pycache__/tools.cpython-311.pyc +0 -0
- __pycache__/vector_db.cpython-311.pyc +0 -0
- agents.py +94 -0
- chroma_db/cd14ed18-2502-4d85-a6d6-0038801d4f09/data_level0.bin +3 -0
- chroma_db/cd14ed18-2502-4d85-a6d6-0038801d4f09/header.bin +3 -0
- chroma_db/cd14ed18-2502-4d85-a6d6-0038801d4f09/index_metadata.pickle +3 -0
- chroma_db/cd14ed18-2502-4d85-a6d6-0038801d4f09/length.bin +3 -0
- chroma_db/cd14ed18-2502-4d85-a6d6-0038801d4f09/link_lists.bin +3 -0
- chroma_db/chroma.sqlite3 +3 -0
- chroma_db_60989696/26d64802-2814-432f-b752-5aea2fb05a32/data_level0.bin +3 -0
- chroma_db_60989696/26d64802-2814-432f-b752-5aea2fb05a32/header.bin +3 -0
- chroma_db_60989696/26d64802-2814-432f-b752-5aea2fb05a32/index_metadata.pickle +3 -0
- chroma_db_60989696/26d64802-2814-432f-b752-5aea2fb05a32/length.bin +3 -0
- chroma_db_60989696/26d64802-2814-432f-b752-5aea2fb05a32/link_lists.bin +3 -0
- chroma_db_60989696/chroma.sqlite3 +3 -0
- config.py +37 -0
- crew.py +161 -0
- crew_gradio_app.py +340 -0
- main.py +54 -0
- nba24-25.csv +0 -0
- pyproject.toml +18 -0
- tasks.py +177 -0
- tools.py +294 -0
- uv.lock +0 -0
- vector_db.py +233 -0
.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.sqlite3 filter=lfs diff=lfs merge=lfs -text
|
__pycache__/agents.cpython-311.pyc
ADDED
|
Binary file (4.44 kB). View file
|
|
|
__pycache__/config.cpython-311.pyc
ADDED
|
Binary file (1.18 kB). View file
|
|
|
__pycache__/crew.cpython-311.pyc
ADDED
|
Binary file (5.2 kB). View file
|
|
|
__pycache__/data_service.cpython-311.pyc
ADDED
|
Binary file (8.65 kB). View file
|
|
|
__pycache__/gradio_app.cpython-311.pyc
ADDED
|
Binary file (9.81 kB). View file
|
|
|
__pycache__/pandas_query_generator.cpython-311.pyc
ADDED
|
Binary file (6.64 kB). View file
|
|
|
__pycache__/tasks.cpython-311.pyc
ADDED
|
Binary file (8.25 kB). View file
|
|
|
__pycache__/tools.cpython-311.pyc
ADDED
|
Binary file (17.6 kB). View file
|
|
|
__pycache__/vector_db.cpython-311.pyc
ADDED
|
Binary file (11.4 kB). View file
|
|
|
agents.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Agent definitions for NBA data analysis.
|
| 3 |
+
"""
|
| 4 |
+
from crewai import Agent
|
| 5 |
+
from config import get_llm, NBA_DATA_PATH
|
| 6 |
+
from tools import get_agent_tools
|
| 7 |
+
|
| 8 |
+
# Get LLM
|
| 9 |
+
llm = get_llm()
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def create_engineer_agent(csv_path: str = None) -> Agent:
|
| 13 |
+
"""
|
| 14 |
+
Create the Engineer Agent for data processing and engineering tasks.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
csv_path: Path to CSV file (defaults to NBA_DATA_PATH from config)
|
| 18 |
+
|
| 19 |
+
Returns:
|
| 20 |
+
Agent: Configured Engineer Agent
|
| 21 |
+
"""
|
| 22 |
+
data_path = csv_path or NBA_DATA_PATH
|
| 23 |
+
agent_tools = get_agent_tools(data_path)
|
| 24 |
+
|
| 25 |
+
return Agent(
|
| 26 |
+
role="Data Engineer",
|
| 27 |
+
goal="Process, clean, and prepare data for analysis. Ensure data quality and create structured datasets.",
|
| 28 |
+
backstory="""You are an expert data engineer with years of experience in sports analytics.
|
| 29 |
+
You specialize in processing large datasets, handling missing values, data validation,
|
| 30 |
+
and creating clean, analysis-ready datasets. You understand statistics deeply and
|
| 31 |
+
know how to structure data for optimal analysis.""",
|
| 32 |
+
verbose=True,
|
| 33 |
+
allow_delegation=False,
|
| 34 |
+
llm=llm,
|
| 35 |
+
tools=agent_tools,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def create_analyst_agent(csv_path: str = None) -> Agent:
|
| 40 |
+
"""
|
| 41 |
+
Create the Analyst Agent for data analysis and insights.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
csv_path: Path to CSV file (defaults to NBA_DATA_PATH from config)
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
Agent: Configured Analyst Agent
|
| 48 |
+
"""
|
| 49 |
+
data_path = csv_path or NBA_DATA_PATH
|
| 50 |
+
agent_tools = get_agent_tools(data_path)
|
| 51 |
+
|
| 52 |
+
return Agent(
|
| 53 |
+
role="Data Analyst",
|
| 54 |
+
goal="Analyze data to extract meaningful insights, identify patterns, and provide actionable recommendations.",
|
| 55 |
+
backstory="""You are a seasoned data analyst with a passion for analytics.
|
| 56 |
+
You excel at finding patterns in data, identifying trends, performing statistical analysis,
|
| 57 |
+
and translating complex data into clear, actionable insights. You understand performance
|
| 58 |
+
metrics and can provide strategic recommendations based on data.
|
| 59 |
+
|
| 60 |
+
CRITICAL: When asked for aggregations, top N lists, totals, or statistical summaries:
|
| 61 |
+
- ALWAYS use the 'analyze_nba_data' tool with pandas groupby operations
|
| 62 |
+
- NEVER use semantic_search_nba_data for aggregation queries (it only returns individual records)
|
| 63 |
+
- For "top 5 three-point shooters": use analyze_nba_data with groupby('Player')['3P'].sum()
|
| 64 |
+
- Plan your analysis: understand what aggregation is needed, then write the appropriate pandas code""",
|
| 65 |
+
verbose=True,
|
| 66 |
+
allow_delegation=False,
|
| 67 |
+
llm=llm,
|
| 68 |
+
tools=agent_tools,
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def create_storyteller_agent() -> Agent:
|
| 73 |
+
"""
|
| 74 |
+
Create the Storyteller Agent for creating engaging headlines and storylines from analysis results.
|
| 75 |
+
|
| 76 |
+
Returns:
|
| 77 |
+
Agent: Configured Storyteller Agent
|
| 78 |
+
"""
|
| 79 |
+
return Agent(
|
| 80 |
+
role="Sports Storyteller",
|
| 81 |
+
goal="Transform data analysis results into engaging headlines and compelling storylines that bring statistics to life with narrative and context.",
|
| 82 |
+
backstory="""You are a creative sports journalist and storyteller with a talent for turning
|
| 83 |
+
statistical analysis into captivating headlines and engaging storylines. You know how to make data come alive,
|
| 84 |
+
creating headlines that grab attention and writing compelling content that tells the story behind the numbers.
|
| 85 |
+
You understand what makes a great sports story and can transform complex analytics into memorable narratives
|
| 86 |
+
that connect with readers. You write with flair, using vivid language and storytelling techniques to make
|
| 87 |
+
statistics human and relatable. Your stories provide context, explain why the data matters, and bring the
|
| 88 |
+
performance metrics to life with engaging prose.""",
|
| 89 |
+
verbose=True,
|
| 90 |
+
allow_delegation=False,
|
| 91 |
+
llm=llm,
|
| 92 |
+
tools=[], # Storyteller doesn't need data tools, just creates headlines and content from analysis
|
| 93 |
+
)
|
| 94 |
+
|
chroma_db/cd14ed18-2502-4d85-a6d6-0038801d4f09/data_level0.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2a112f6302c6253d29bf56fb3c6ec8bce06a42c1280c68cd188b0f20fd844ca0
|
| 3 |
+
size 26816000
|
chroma_db/cd14ed18-2502-4d85-a6d6-0038801d4f09/header.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ab7363962786093545f9f11a09f4f1be05bb332e1c6be866de187f06c2c1e1ee
|
| 3 |
+
size 100
|
chroma_db/cd14ed18-2502-4d85-a6d6-0038801d4f09/index_metadata.pickle
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:20c208f1ed5df94e05538452d26434e5a132a0b5fa3a48676a2b0366fabacc2e
|
| 3 |
+
size 585972
|
chroma_db/cd14ed18-2502-4d85-a6d6-0038801d4f09/length.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ae88008596e5479e720bbb6df80965d87d90d63d7069a181cca5f1845a259c5e
|
| 3 |
+
size 64000
|
chroma_db/cd14ed18-2502-4d85-a6d6-0038801d4f09/link_lists.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:43183fb18a5e080f6047847af4607dc735e205e5e1d4bb3a4604fe8a1902bc06
|
| 3 |
+
size 137236
|
chroma_db/chroma.sqlite3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fa397f322217a7936050de2a8516db57cc0b5af827310b0a9afaabdb0970e57d
|
| 3 |
+
size 32964608
|
chroma_db_60989696/26d64802-2814-432f-b752-5aea2fb05a32/data_level0.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ff1c6206f9e1d892c5e34c2ab74be32e697abe510cb46aa610def9c0878eb72d
|
| 3 |
+
size 26816000
|
chroma_db_60989696/26d64802-2814-432f-b752-5aea2fb05a32/header.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:eec2601f0d7ccd7fab235b9ff5a7406658d2f5cbdd92c9f5bbc9dc15ee413488
|
| 3 |
+
size 100
|
chroma_db_60989696/26d64802-2814-432f-b752-5aea2fb05a32/index_metadata.pickle
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:18945c3f6187592d88b56bc58e2d139778b1ae5b2de54f10c402fef73c0e5cc3
|
| 3 |
+
size 585972
|
chroma_db_60989696/26d64802-2814-432f-b752-5aea2fb05a32/length.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6dfbb34b09a5f8e3aee64166d912438913ab132a3f5bf30c48153364ecf22f1f
|
| 3 |
+
size 64000
|
chroma_db_60989696/26d64802-2814-432f-b752-5aea2fb05a32/link_lists.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:aa6d71cbedff22c85700a09a0344d7101fc3984ec3a7a4eba026e63c4ab02740
|
| 3 |
+
size 137236
|
chroma_db_60989696/chroma.sqlite3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8bf62bf70948fe9d94a394e1ea5c728ca7cb88cecc9aef0f210c9ee0d1495a07
|
| 3 |
+
size 33001472
|
config.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration settings for the NBA data analysis project.
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
from crewai import LLM
|
| 6 |
+
|
| 7 |
+
# NBA Data Configuration
|
| 8 |
+
NBA_DATA_PATH = "nba24-25.csv"
|
| 9 |
+
|
| 10 |
+
# OpenAI Configuration
|
| 11 |
+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
| 12 |
+
OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o")
|
| 13 |
+
|
| 14 |
+
# Validate OpenAI API key (only raise error when actually trying to use LLM, not on import)
|
| 15 |
+
# This allows the app to load even if API key isn't set yet
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_llm() -> LLM:
|
| 19 |
+
"""
|
| 20 |
+
Create and return a CrewAI LLM instance configured for OpenAI.
|
| 21 |
+
|
| 22 |
+
Returns:
|
| 23 |
+
LLM: Configured CrewAI LLM instance for OpenAI
|
| 24 |
+
|
| 25 |
+
Raises:
|
| 26 |
+
ValueError: If OPENAI_API_KEY is not set
|
| 27 |
+
"""
|
| 28 |
+
if not OPENAI_API_KEY:
|
| 29 |
+
raise ValueError(
|
| 30 |
+
"OPENAI_API_KEY environment variable is not set. "
|
| 31 |
+
"Please set it using: export OPENAI_API_KEY='your-api-key'"
|
| 32 |
+
)
|
| 33 |
+
return LLM(
|
| 34 |
+
model=OPENAI_MODEL,
|
| 35 |
+
api_key=OPENAI_API_KEY
|
| 36 |
+
)
|
| 37 |
+
|
crew.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Crew setup for NBA data analysis workflow.
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import shutil
|
| 6 |
+
from crewai import Crew, Process
|
| 7 |
+
from agents import create_engineer_agent, create_analyst_agent, create_storyteller_agent
|
| 8 |
+
from tasks import create_data_engineering_task, create_data_analysis_task, create_custom_analysis_task, create_storyteller_task
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def create_crew() -> Crew:
|
| 12 |
+
"""
|
| 13 |
+
Create and configure the CrewAI crew with agents and tasks.
|
| 14 |
+
|
| 15 |
+
Returns:
|
| 16 |
+
Crew: Configured CrewAI crew ready for execution
|
| 17 |
+
"""
|
| 18 |
+
# Create agents
|
| 19 |
+
engineer_agent = create_engineer_agent()
|
| 20 |
+
analyst_agent = create_analyst_agent()
|
| 21 |
+
|
| 22 |
+
# Create tasks
|
| 23 |
+
data_engineering_task = create_data_engineering_task(engineer_agent)
|
| 24 |
+
data_analysis_task = create_data_analysis_task(analyst_agent, data_engineering_task)
|
| 25 |
+
|
| 26 |
+
# Create and return the crew
|
| 27 |
+
return Crew(
|
| 28 |
+
agents=[engineer_agent, analyst_agent],
|
| 29 |
+
tasks=[data_engineering_task, data_analysis_task],
|
| 30 |
+
process=Process.sequential,
|
| 31 |
+
verbose=True,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def create_crew_with_custom_task(user_query: str, csv_path: str = None) -> Crew:
|
| 36 |
+
"""
|
| 37 |
+
Create a CrewAI crew with engineering task, custom analyst task, and storyteller task.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
user_query: The user's custom analysis query/task
|
| 41 |
+
csv_path: Optional path to CSV file (if None, uses default from config)
|
| 42 |
+
|
| 43 |
+
Returns:
|
| 44 |
+
Crew: Configured CrewAI crew ready for execution
|
| 45 |
+
"""
|
| 46 |
+
# Create agents (they will use the csv_path from tools)
|
| 47 |
+
engineer_agent = create_engineer_agent(csv_path)
|
| 48 |
+
analyst_agent = create_analyst_agent(csv_path)
|
| 49 |
+
storyteller_agent = create_storyteller_agent()
|
| 50 |
+
|
| 51 |
+
# Create engineering task (fixed)
|
| 52 |
+
data_engineering_task = create_data_engineering_task(engineer_agent, csv_path)
|
| 53 |
+
|
| 54 |
+
# Create custom analyst task from user input (no dependency on engineer task for parallel execution)
|
| 55 |
+
custom_analysis_task = create_custom_analysis_task(analyst_agent, user_query, None, csv_path)
|
| 56 |
+
|
| 57 |
+
# Create storyteller task that uses the analyst's output
|
| 58 |
+
storyteller_task = create_storyteller_task(storyteller_agent, custom_analysis_task)
|
| 59 |
+
|
| 60 |
+
# Create and return the crew
|
| 61 |
+
return Crew(
|
| 62 |
+
agents=[engineer_agent, analyst_agent, storyteller_agent],
|
| 63 |
+
tasks=[data_engineering_task, custom_analysis_task, storyteller_task],
|
| 64 |
+
process=Process.sequential,
|
| 65 |
+
verbose=True,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def create_flow_crew(user_query: str, csv_path: str) -> Crew:
|
| 70 |
+
"""
|
| 71 |
+
Create a single crew with parallel tasks (Engineer and Analyst) that merge results at the end.
|
| 72 |
+
This satisfies the assignment requirement: "Parallelize tasks via a Flow; merge results at the end."
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
user_query: The user's custom analysis query/task
|
| 76 |
+
csv_path: Path to the uploaded CSV file
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
Crew: Single crew with parallel tasks that will merge results
|
| 80 |
+
"""
|
| 81 |
+
# Create all agents
|
| 82 |
+
engineer_agent = create_engineer_agent(csv_path)
|
| 83 |
+
analyst_agent = create_analyst_agent(csv_path)
|
| 84 |
+
storyteller_agent = create_storyteller_agent()
|
| 85 |
+
|
| 86 |
+
# Create tasks WITHOUT dependencies so they can run in parallel
|
| 87 |
+
# Engineer task - independent
|
| 88 |
+
data_engineering_task = create_data_engineering_task(engineer_agent, csv_path)
|
| 89 |
+
|
| 90 |
+
# Analyst task - independent (no dependency on engineer for parallel execution)
|
| 91 |
+
custom_analysis_task = create_custom_analysis_task(analyst_agent, user_query, None, csv_path)
|
| 92 |
+
|
| 93 |
+
# Storyteller task - depends on analyst (runs after analyst completes)
|
| 94 |
+
storyteller_task = create_storyteller_task(storyteller_agent, custom_analysis_task)
|
| 95 |
+
|
| 96 |
+
# Create a single crew with all tasks
|
| 97 |
+
# Tasks without dependencies will run in parallel
|
| 98 |
+
# Storyteller will run after analyst completes
|
| 99 |
+
return Crew(
|
| 100 |
+
agents=[engineer_agent, analyst_agent, storyteller_agent],
|
| 101 |
+
tasks=[data_engineering_task, custom_analysis_task, storyteller_task],
|
| 102 |
+
process=Process.sequential, # CrewAI will parallelize independent tasks automatically
|
| 103 |
+
verbose=True,
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def create_analysis_only_crew(user_query: str, csv_path: str) -> Crew:
|
| 108 |
+
"""
|
| 109 |
+
Create a crew with only Analyst and Storyteller agents (no Engineer).
|
| 110 |
+
Used when engineer results are already available and user asks a new question.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
user_query: The user's custom analysis query/task
|
| 114 |
+
csv_path: Path to the uploaded CSV file
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
Crew: Crew with only analyst and storyteller tasks
|
| 118 |
+
"""
|
| 119 |
+
# Create only analyst and storyteller agents
|
| 120 |
+
analyst_agent = create_analyst_agent(csv_path)
|
| 121 |
+
storyteller_agent = create_storyteller_agent()
|
| 122 |
+
|
| 123 |
+
# Create analyst task with user query
|
| 124 |
+
custom_analysis_task = create_custom_analysis_task(analyst_agent, user_query, None, csv_path)
|
| 125 |
+
|
| 126 |
+
# Storyteller task depends on analyst
|
| 127 |
+
storyteller_task = create_storyteller_task(storyteller_agent, custom_analysis_task)
|
| 128 |
+
|
| 129 |
+
return Crew(
|
| 130 |
+
agents=[analyst_agent, storyteller_agent],
|
| 131 |
+
tasks=[custom_analysis_task, storyteller_task],
|
| 132 |
+
process=Process.sequential,
|
| 133 |
+
verbose=True,
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def create_analyst_only_crew(user_query: str, csv_path: str) -> Crew:
|
| 138 |
+
"""
|
| 139 |
+
Create a crew with only Analyst agent (no Engineer, no Storyteller).
|
| 140 |
+
Used for specific user questions where only analysis is needed.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
user_query: The user's custom analysis query/task
|
| 144 |
+
csv_path: Path to the uploaded CSV file
|
| 145 |
+
|
| 146 |
+
Returns:
|
| 147 |
+
Crew: Crew with only analyst task
|
| 148 |
+
"""
|
| 149 |
+
# Create only analyst agent
|
| 150 |
+
analyst_agent = create_analyst_agent(csv_path)
|
| 151 |
+
|
| 152 |
+
# Create analyst task with user query
|
| 153 |
+
custom_analysis_task = create_custom_analysis_task(analyst_agent, user_query, None, csv_path)
|
| 154 |
+
|
| 155 |
+
return Crew(
|
| 156 |
+
agents=[analyst_agent],
|
| 157 |
+
tasks=[custom_analysis_task],
|
| 158 |
+
process=Process.sequential,
|
| 159 |
+
verbose=True,
|
| 160 |
+
)
|
| 161 |
+
|
crew_gradio_app.py
ADDED
|
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Minimal Gradio app for CrewAI data analysis with file upload and parallel agent execution.
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import gradio as gr
|
| 6 |
+
import traceback
|
| 7 |
+
from crew import create_flow_crew, create_analyst_only_crew
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def process_file_and_analyze(file, user_query: str = "", engineer_result: str = None) -> tuple[str, str]:
|
| 11 |
+
"""
|
| 12 |
+
Process uploaded file and run all agents (Engineer, Analyst, Storyteller), then merge results.
|
| 13 |
+
Used for the "Analyze Dataset" button.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
file: Uploaded file object
|
| 17 |
+
user_query: The user's analysis query/task (empty for general analysis)
|
| 18 |
+
engineer_result: Previously computed engineer result (if available)
|
| 19 |
+
|
| 20 |
+
Returns:
|
| 21 |
+
tuple: (merged_results, engineer_result) - engineer_result is stored for reuse
|
| 22 |
+
"""
|
| 23 |
+
if file is None:
|
| 24 |
+
return "Please upload a CSV file.", engineer_result or ""
|
| 25 |
+
|
| 26 |
+
# Use default analysis if no query provided
|
| 27 |
+
if not user_query or not user_query.strip():
|
| 28 |
+
user_query = "Provide a comprehensive analysis of the dataset including: top performers, key statistics, interesting patterns, and notable insights."
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
# Get file path
|
| 32 |
+
file_path = file.name if hasattr(file, 'name') else str(file)
|
| 33 |
+
csv_path = file_path
|
| 34 |
+
|
| 35 |
+
# Full analysis: run all agents
|
| 36 |
+
crew = create_flow_crew(user_query.strip(), csv_path)
|
| 37 |
+
result = crew.kickoff()
|
| 38 |
+
|
| 39 |
+
merged_output = []
|
| 40 |
+
stored_engineer_result = ""
|
| 41 |
+
|
| 42 |
+
# Get engineer result (first task)
|
| 43 |
+
if hasattr(result, 'tasks_output') and result.tasks_output:
|
| 44 |
+
if len(result.tasks_output) >= 1:
|
| 45 |
+
engineer_output = str(result.tasks_output[0])
|
| 46 |
+
stored_engineer_result = engineer_output
|
| 47 |
+
merged_output.append("## Engineer Agent Results")
|
| 48 |
+
merged_output.append("")
|
| 49 |
+
merged_output.append(engineer_output)
|
| 50 |
+
merged_output.append("")
|
| 51 |
+
merged_output.append("---")
|
| 52 |
+
merged_output.append("")
|
| 53 |
+
|
| 54 |
+
# Get analyst result (second task)
|
| 55 |
+
if hasattr(result, 'tasks_output') and result.tasks_output:
|
| 56 |
+
if len(result.tasks_output) >= 2:
|
| 57 |
+
analyst_output = str(result.tasks_output[1])
|
| 58 |
+
merged_output.append("## Analyst Agent Results")
|
| 59 |
+
merged_output.append("")
|
| 60 |
+
merged_output.append(analyst_output)
|
| 61 |
+
merged_output.append("")
|
| 62 |
+
merged_output.append("---")
|
| 63 |
+
merged_output.append("")
|
| 64 |
+
|
| 65 |
+
# Get storyteller result (third task)
|
| 66 |
+
if hasattr(result, 'tasks_output') and result.tasks_output:
|
| 67 |
+
if len(result.tasks_output) >= 3:
|
| 68 |
+
storyteller_output = str(result.tasks_output[2])
|
| 69 |
+
merged_output.append("## Storyteller Agent Results")
|
| 70 |
+
merged_output.append("")
|
| 71 |
+
merged_output.append(storyteller_output)
|
| 72 |
+
merged_output.append("")
|
| 73 |
+
|
| 74 |
+
# If we couldn't extract from tasks_output, use the full result
|
| 75 |
+
if not merged_output:
|
| 76 |
+
merged_output.append("## Complete Analysis Results")
|
| 77 |
+
merged_output.append("")
|
| 78 |
+
merged_output.append(str(result))
|
| 79 |
+
|
| 80 |
+
return "\n".join(merged_output), stored_engineer_result
|
| 81 |
+
|
| 82 |
+
except Exception as e:
|
| 83 |
+
error_trace = traceback.format_exc()
|
| 84 |
+
error_msg = f"Error: {str(e)}\n\nTraceback:\n{error_trace}"
|
| 85 |
+
print(error_msg)
|
| 86 |
+
return error_msg, engineer_result or ""
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def process_question_only(file, user_query: str) -> str:
|
| 90 |
+
"""
|
| 91 |
+
Process a specific user question using only the Analyst agent (no Engineer, no Storyteller).
|
| 92 |
+
Used for the "Analyze with Question" button.
|
| 93 |
+
|
| 94 |
+
Args:
|
| 95 |
+
file: Uploaded file object
|
| 96 |
+
user_query: The user's specific analysis question
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
str: Analyst results only
|
| 100 |
+
"""
|
| 101 |
+
if file is None:
|
| 102 |
+
return "Please upload a CSV file."
|
| 103 |
+
|
| 104 |
+
if not user_query or not user_query.strip():
|
| 105 |
+
return "Please enter a question."
|
| 106 |
+
|
| 107 |
+
try:
|
| 108 |
+
# Get file path
|
| 109 |
+
file_path = file.name if hasattr(file, 'name') else str(file)
|
| 110 |
+
csv_path = file_path
|
| 111 |
+
|
| 112 |
+
# Run only analyst
|
| 113 |
+
crew = create_analyst_only_crew(user_query.strip(), csv_path)
|
| 114 |
+
result = crew.kickoff()
|
| 115 |
+
|
| 116 |
+
# Get analyst result
|
| 117 |
+
if hasattr(result, 'tasks_output') and result.tasks_output:
|
| 118 |
+
if len(result.tasks_output) >= 1:
|
| 119 |
+
analyst_output = str(result.tasks_output[0])
|
| 120 |
+
return analyst_output
|
| 121 |
+
|
| 122 |
+
# Fallback to full result
|
| 123 |
+
return str(result)
|
| 124 |
+
|
| 125 |
+
except Exception as e:
|
| 126 |
+
error_trace = traceback.format_exc()
|
| 127 |
+
error_msg = f"Error: {str(e)}\n\nTraceback:\n{error_trace}"
|
| 128 |
+
print(error_msg)
|
| 129 |
+
return error_msg
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def create_app():
|
| 133 |
+
"""Create and return the Gradio interface."""
|
| 134 |
+
with gr.Blocks(title="NBA Stats Analysis with CrewAI", theme=gr.themes.Soft()) as app:
|
| 135 |
+
gr.Markdown("""
|
| 136 |
+
# NBA Stats Analysis with CrewAI
|
| 137 |
+
|
| 138 |
+
Upload your NBA statistics CSV file to get comprehensive analysis with engaging storylines.
|
| 139 |
+
|
| 140 |
+
**How it works:**
|
| 141 |
+
- **Engineer Agent**: Examines and validates your dataset
|
| 142 |
+
- **Analyst Agent**: Performs deep analysis (general or based on your question)
|
| 143 |
+
- **Storyteller Agent**: Creates headlines and compelling storylines
|
| 144 |
+
|
| 145 |
+
All agents work in parallel and results are merged for you!
|
| 146 |
+
""")
|
| 147 |
+
|
| 148 |
+
# Store engineer result in state
|
| 149 |
+
engineer_state = gr.State(value="")
|
| 150 |
+
|
| 151 |
+
with gr.Row():
|
| 152 |
+
with gr.Column(scale=1):
|
| 153 |
+
file_input = gr.File(
|
| 154 |
+
label="Upload CSV File",
|
| 155 |
+
file_types=[".csv"],
|
| 156 |
+
type="filepath"
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
analyze_btn = gr.Button(
|
| 160 |
+
"Analyze Dataset",
|
| 161 |
+
variant="primary",
|
| 162 |
+
size="lg",
|
| 163 |
+
visible=False
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
gr.Markdown("### Ask a Specific Question")
|
| 167 |
+
|
| 168 |
+
query_input = gr.Textbox(
|
| 169 |
+
label="Your Analysis Question",
|
| 170 |
+
placeholder="e.g., 'Who are the top 5 three-point shooters?' or 'Analyze the best players by assists'",
|
| 171 |
+
lines=2
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
question_output = gr.Markdown(
|
| 175 |
+
value="",
|
| 176 |
+
label="Answer",
|
| 177 |
+
visible=False
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
query_btn = gr.Button(
|
| 181 |
+
"Analyze with Question",
|
| 182 |
+
variant="secondary",
|
| 183 |
+
size="lg"
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
with gr.Row():
|
| 187 |
+
with gr.Column():
|
| 188 |
+
status_output = gr.Markdown(
|
| 189 |
+
value="",
|
| 190 |
+
label="Agent Status",
|
| 191 |
+
visible=False
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
with gr.Row():
|
| 195 |
+
with gr.Column():
|
| 196 |
+
merged_output = gr.Markdown(
|
| 197 |
+
value="**Ready to analyze!** Upload a CSV file above, then click 'Analyze Dataset' to get started.",
|
| 198 |
+
label="Full Analysis Results"
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
def show_loading_animation(is_question: bool = False):
|
| 202 |
+
"""Show loading animation while processing."""
|
| 203 |
+
if is_question:
|
| 204 |
+
return """## Analysis in Progress...
|
| 205 |
+
|
| 206 |
+
<div style="text-align: center; padding: 20px;">
|
| 207 |
+
<div style="font-size: 18px; margin-bottom: 15px;">
|
| 208 |
+
<strong>Analyzing your question...</strong>
|
| 209 |
+
</div>
|
| 210 |
+
<div style="display: flex; justify-content: center; max-width: 600px; margin: 0 auto;">
|
| 211 |
+
<div style="text-align: center; margin: 10px;">
|
| 212 |
+
<div style="font-size: 14px; font-weight: bold;">Analyst Agent</div>
|
| 213 |
+
<div style="font-size: 12px; color: #666; margin-top: 5px;">Processing query...</div>
|
| 214 |
+
</div>
|
| 215 |
+
</div>
|
| 216 |
+
<div style="margin-top: 25px; font-size: 14px; color: #888;">
|
| 217 |
+
This may take a moment... Please wait while the agent processes your question.
|
| 218 |
+
</div>
|
| 219 |
+
</div>"""
|
| 220 |
+
else:
|
| 221 |
+
return """## Analysis in Progress...
|
| 222 |
+
|
| 223 |
+
<div style="text-align: center; padding: 20px;">
|
| 224 |
+
<div style="font-size: 18px; margin-bottom: 15px;">
|
| 225 |
+
<strong>Agents are working in parallel...</strong>
|
| 226 |
+
</div>
|
| 227 |
+
<div style="display: flex; justify-content: space-around; max-width: 600px; margin: 0 auto; flex-wrap: wrap;">
|
| 228 |
+
<div style="text-align: center; margin: 10px;">
|
| 229 |
+
<div style="font-size: 14px; font-weight: bold;">Engineer Agent</div>
|
| 230 |
+
<div style="font-size: 12px; color: #666; margin-top: 5px;">Examining dataset...</div>
|
| 231 |
+
</div>
|
| 232 |
+
<div style="text-align: center; margin: 10px;">
|
| 233 |
+
<div style="font-size: 14px; font-weight: bold;">Analyst Agent</div>
|
| 234 |
+
<div style="font-size: 12px; color: #666; margin-top: 5px;">Analyzing data...</div>
|
| 235 |
+
</div>
|
| 236 |
+
<div style="text-align: center; margin: 10px;">
|
| 237 |
+
<div style="font-size: 14px; font-weight: bold;">Storyteller Agent</div>
|
| 238 |
+
<div style="font-size: 12px; color: #666; margin-top: 5px;">Creating storylines...</div>
|
| 239 |
+
</div>
|
| 240 |
+
</div>
|
| 241 |
+
<div style="margin-top: 25px; font-size: 14px; color: #888;">
|
| 242 |
+
This may take a moment... Please wait while the agents process your data.
|
| 243 |
+
</div>
|
| 244 |
+
</div>"""
|
| 245 |
+
|
| 246 |
+
def on_file_upload(file):
    """Handle a file upload: reveal the analyze button and clear cached engineer state.

    Returns a pair of updates for (analyze_btn, engineer_state).
    """
    # When the upload is cleared (file is None) the button is hidden again;
    # either way the cached engineer result is reset to an empty string.
    visible = file is not None
    return gr.update(visible=visible), ""
|
| 251 |
+
|
| 252 |
+
def start_full_analysis(file, engineer_result: str = ""):
    """Begin the full-analysis flow by showing the multi-agent loading banner.

    `file` and `engineer_result` are accepted (but unused here) so this step
    can share the same Gradio inputs as the completion step that follows it.
    Returns updates for (status_output, merged_output).
    """
    banner = show_loading_animation(is_question=False)
    return gr.update(visible=True, value=banner), gr.update(value="")
|
| 256 |
+
|
| 257 |
+
def complete_full_analysis(file, engineer_result: str = ""):
    """Run the full crew analysis and return (result_md, status_update, engineer_state)."""
    output, updated_engineer = process_file_and_analyze(file, "", engineer_result)
    # Error / validation messages are promoted to a markdown heading so they
    # stand out in the results pane.
    if output.startswith(("Error:", "Please upload")):
        output = f"### {output}"
    return output, gr.update(visible=False), updated_engineer
|
| 263 |
+
|
| 264 |
+
def start_question_analysis(file, user_query: str = ""):
    """Begin the question-only flow by showing the single-agent loading banner.

    Inputs mirror the completion step's signature for Gradio wiring; only the
    loading banner is produced here. Returns updates for
    (status_output, question_output).
    """
    banner = show_loading_animation(is_question=True)
    return gr.update(visible=True, value=banner), gr.update(visible=True, value="")
|
| 268 |
+
|
| 269 |
+
def complete_question_analysis(file, user_query: str = ""):
    """Run the question-only analysis and format the answer for display.

    Returns (formatted_answer, status_update).
    """
    answer = process_question_only(file, user_query)
    if answer.startswith(("Error:", "Please")):
        # Surface errors and validation prompts as a markdown heading.
        answer = f"### {answer}"
    else:
        # Wrap successful answers in a highlighted callout box.
        answer = f"""<div style="background-color: #f0f7ff; border: 2px solid #4a90e2; border-radius: 8px; padding: 15px; margin: 10px 0;">
{answer}
</div>"""
    return answer, gr.update(visible=False)
|
| 280 |
+
|
| 281 |
+
# When file is uploaded, show analyze button and reset engineer state
file_input.change(
    fn=on_file_upload,
    inputs=[file_input],
    outputs=[analyze_btn, engineer_state]
)

# Analyze button - runs general analysis (no query needed).
# Two-step chain: the first step only renders the loading banner so the UI
# updates immediately; the second step does the slow crew work.
analyze_btn.click(
    fn=start_full_analysis,
    inputs=[file_input, engineer_state],
    outputs=[status_output, merged_output]
).then(
    fn=complete_full_analysis,
    inputs=[file_input, engineer_state],
    outputs=[merged_output, status_output, engineer_state]
)

# Query button - runs analysis with user's question (only Analyst)
query_btn.click(
    fn=start_question_analysis,
    inputs=[file_input, query_input],
    outputs=[status_output, question_output]
).then(
    fn=complete_question_analysis,
    inputs=[file_input, query_input],
    outputs=[question_output, status_output]
)

# Allow Enter key to submit query.
# NOTE(review): intentionally mirrors query_btn.click above so pressing Enter
# and clicking the button behave identically.
query_input.submit(
    fn=start_question_analysis,
    inputs=[file_input, query_input],
    outputs=[status_output, question_output]
).then(
    fn=complete_question_analysis,
    inputs=[file_input, query_input],
    outputs=[question_output, status_output]
)

return app
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
if __name__ == "__main__":
    try:
        print("Creating Gradio app...")
        app = create_app()
        print("Launching Gradio app...")
        # Bind on all interfaces on port 7860 (the standard HF Spaces port).
        app.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
            show_error=True
        )
    except Exception as e:
        # Print the full traceback before re-raising so startup failures
        # are visible in the container logs.
        print(f"Error launching app: {e}")
        traceback.print_exc()
        raise
|
| 339 |
+
|
| 340 |
+
|
main.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Main entry point for NBA 2024-25 data analysis using CrewAI.
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from config import NBA_DATA_PATH
|
| 7 |
+
from crew import create_crew
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def main():
    """Run the NBA 2024-25 data analysis crew end-to-end from the command line.

    Prints a banner, validates and previews the dataset, then kicks off the
    CrewAI workflow and prints its final result.
    """
    banner = "=" * 60
    print(banner)
    print("NBA 2024-25 Data Analysis with CrewAI")
    print("Using OpenAI")
    print(banner)
    print()

    # Bail out early if the dataset file is missing.
    if not os.path.exists(NBA_DATA_PATH):
        print(f"Error: {NBA_DATA_PATH} not found!")
        return

    print(f"Loading data from {NBA_DATA_PATH}...")
    try:
        # Quick data preview
        df = pd.read_csv(NBA_DATA_PATH)
        print(f"Dataset loaded: {len(df)} records, {len(df.columns)} columns")
        print(f"Columns: {', '.join(df.columns.tolist())}")
        print()
    except Exception as e:
        print(f"Error loading data: {e}")
        return

    print("Starting CrewAI agents...")
    print("Engineer Agent will process and clean the data...")
    print("Analyst Agent will analyze the data for insights...")
    print()
    print("-" * 60)
    print()

    # Create and execute the crew.
    result = create_crew().kickoff()

    print()
    print(banner)
    print("ANALYSIS COMPLETE")
    print(banner)
    print()
    print(result)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
if __name__ == "__main__":
    # Allow running directly as a script: `python main.py`.
    main()
|
nba24-25.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
pyproject.toml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "msml610project"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Add your description here"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.11"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"crewai>=1.4.1",
|
| 9 |
+
"crewai-flow>=0.1.0",
|
| 10 |
+
"ipykernel>=7.1.0",
|
| 11 |
+
"openai>=1.0.0",
|
| 12 |
+
"pandas>=2.0.0",
|
| 13 |
+
"litellm>=1.0.0",
|
| 14 |
+
"sentence-transformers>=2.2.0",
|
| 15 |
+
"chromadb>=0.4.0",
|
| 16 |
+
"gradio>=4.0.0",
|
| 17 |
+
"typer>=0.20.0",
|
| 18 |
+
]
|
tasks.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Task definitions for NBA data analysis workflow.
|
| 3 |
+
"""
|
| 4 |
+
from crewai import Task
|
| 5 |
+
from config import NBA_DATA_PATH
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def create_data_engineering_task(engineer_agent, csv_path: str = None) -> Task:
    """
    Create the data engineering task for processing and cleaning data.

    Args:
        engineer_agent: The Engineer Agent to assign this task to
        csv_path: Path to CSV file (defaults to NBA_DATA_PATH from config)

    Returns:
        Task: Configured data engineering task
    """
    # Fall back to the configured dataset when no explicit path is given.
    # NOTE: `or` also treats an empty-string path as "use the default".
    data_path = csv_path or NBA_DATA_PATH

    # The description doubles as the agent prompt; the "ONCE" constraints exist
    # to limit redundant tool calls (token/latency cost).
    return Task(
        description=f"""
        Quickly examine the dataset located at {data_path}.

        Your tasks (BE EFFICIENT - use tools only once):
        1. Get a brief summary of the dataset structure (use get_nba_data_summary ONCE)
        2. Note the key columns available
        3. Verify the data is ready for analysis

        IMPORTANT:
        - Use get_nba_data_summary ONCE only - it provides all needed info
        - Do NOT call read_nba_data or analyze_nba_data multiple times
        - Keep your report concise (2-3 sentences)
        - The data is already clean and ready for analysis

        Provide a brief confirmation that the dataset is loaded and ready for analysis.
        """,
        agent=engineer_agent,
        expected_output="A brief confirmation (2-3 sentences) that the dataset is loaded and ready for analysis, including key column names."
    )
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def create_data_analysis_task(analyst_agent, data_engineering_task: Task) -> Task:
    """
    Create the data analysis task for extracting insights from NBA data.

    Args:
        analyst_agent: The Analyst Agent to assign this task to
        data_engineering_task: The data engineering task for context

    Returns:
        Task: Configured data analysis task
    """
    # `context=[data_engineering_task]` makes this task consume the engineer
    # task's output, enforcing the engineer -> analyst ordering.
    return Task(
        description=f"""
        Using the cleaned NBA 2024-25 dataset, perform comprehensive analysis:

        Your tasks:
        1. Analyze player performance metrics:
           - Top performers by points, assists, rebounds
           - Shooting efficiency analysis (FG%, 3P%, FT%)
           - Player efficiency ratings
        2. Team performance analysis:
           - Win/loss records by team
           - Team offensive and defensive statistics
           - Team performance trends
        3. Game insights:
           - High-scoring games
           - Close games vs blowouts
           - Performance by date/period
        4. Identify key patterns and trends:
           - Best performing players
           - Most efficient teams
           - Statistical outliers
        5. Provide actionable insights and recommendations

        Create a comprehensive analysis report with key findings and insights.
        """,
        agent=analyst_agent,
        expected_output="A detailed analysis report with key insights, statistical findings, top performers, team analysis, and actionable recommendations based on the NBA 2024-25 data.",
        context=[data_engineering_task]
    )
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def create_custom_analysis_task(analyst_agent, user_query: str, data_engineering_task: Task = None, csv_path: str = None) -> Task:
    """
    Create a custom data analysis task based on user input.

    Args:
        analyst_agent: The Analyst Agent to assign this task to
        user_query: The user's custom analysis query/task
        data_engineering_task: The data engineering task for context (optional for parallel execution)
        csv_path: Path to CSV file (for reference in description)

    Returns:
        Task: Configured custom analysis task
    """
    data_path = csv_path or NBA_DATA_PATH
    # Context is optional so this task can also run standalone (e.g. the
    # question-only flow in the Gradio UI) without a prior engineering step.
    context = [data_engineering_task] if data_engineering_task else []

    # NOTE(review): `user_query` is interpolated verbatim into the agent
    # prompt; it is assumed to be plain text from the UI textbox.
    return Task(
        description=f"""
        Using the dataset located at {data_path}, perform the following analysis as requested by the user:

        {user_query}

        IMPORTANT INSTRUCTIONS:
        1. For queries requiring aggregations (sum, count, average, top N, etc.), you MUST use the 'analyze_nba_data' tool.
        2. The 'analyze_nba_data' tool allows you to execute pandas code for grouping, aggregating, sorting, and filtering.
        3. Examples of when to use 'analyze_nba_data':
           - Finding top players by statistics (e.g., "top 5 three-point shooters")
           - Calculating totals or averages per player/team
           - Grouping and aggregating data
           - Statistical analysis requiring groupby operations
        4. Use 'semantic_search_nba_data' only for finding specific game records or examples, NOT for aggregations.
        5. Plan your analysis: First understand what data you need, then use the appropriate tool to get aggregated results.

        Steps to follow:
        1. If the query asks for "top N" or aggregations, use analyze_nba_data with pandas groupby operations
        2. For "top 5 three-point shooters": group by Player, sum the '3P' column, sort descending, take top 5
        3. Present the results clearly with player names and their statistics

        Provide a clear, comprehensive answer with relevant statistics, insights, and any supporting data from the dataset.
        """,
        agent=analyst_agent,
        expected_output="A detailed analysis report addressing the user's query with relevant insights, statistics, and findings from the data.",
        context=context
    )
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def create_storyteller_task(storyteller_agent, analysis_task: Task) -> Task:
    """
    Create a storyteller task that creates headlines and storylines from the analysis results.

    Args:
        storyteller_agent: The Storyteller Agent to assign this task to
        analysis_task: The analysis task whose output will be used to create headlines and content

    Returns:
        Task: Configured storyteller task
    """
    # The analysis task is passed as context so its report is injected into
    # this agent's prompt; the description below defines the output format.
    return Task(
        description="""
        Review the data analysis results and create engaging headlines and compelling storylines that bring the data to life.

        Your tasks:
        1. Read and understand the analysis results thoroughly
        2. Identify the most important and interesting findings
        3. Create 3-5 compelling headlines that:
           - Are catchy and attention-grabbing
           - Accurately reflect the key insights
           - Use engaging sports journalism language
           - Are suitable for display to users

        4. Write engaging storylines/content for each headline that:
           - Tells a story about the findings
           - Provides context and narrative around the statistics
           - Makes the data come alive with compelling prose
           - Explains why these insights matter
           - Uses vivid language and storytelling techniques
           - Is 2-3 paragraphs per storyline (enough to be engaging but concise)

        5. Format your output as follows:
           HEADLINES:
           [List of 3-5 headlines, one per line]

           STORYLINES:
           [For each headline, write 2-3 paragraphs of engaging content that tells the story behind the data]

        Make both the headlines and storylines exciting, memorable, and true to the data insights.
        Write like a sports journalist who knows how to make statistics compelling and human.
        """,
        agent=storyteller_agent,
        expected_output="A formatted output with 3-5 engaging headlines followed by detailed storylines (2-3 paragraphs each) that bring the data analysis to life with compelling narrative and context.",
        context=[analysis_task]
    )
|
| 177 |
+
|
tools.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tools for CrewAI agents to interact with NBA data.
|
| 3 |
+
"""
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from crewai.tools import tool
|
| 6 |
+
from typing import Optional
|
| 7 |
+
from vector_db import get_vector_db
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def get_agent_tools(data_path: str):
    """
    Get the list of tools available for agents.

    Each tool closes over `data_path`, so a fresh tool set is built per
    dataset (e.g. per uploaded CSV in the Gradio app).

    Args:
        data_path: Path to the CSV data file

    Returns:
        list: List of tools for agents to use
    """
    # Define helper functions first, then wrap them with @tool

    def _read_nba_data(limit: int = 10) -> str:
        """Read a sample of the NBA data file to understand its structure."""
        try:
            # Read only a sample to avoid token limits
            df = pd.read_csv(data_path)
            sample = df.head(limit)
            return f"Dataset: {len(df):,} total records, {len(df.columns)} columns\n\nColumn names: {', '.join(df.columns.tolist())}\n\nSample (first {limit} rows):\n\n{sample.to_string()}"
        except Exception as e:
            return f"Error reading file {data_path}: {str(e)}"

    def _search_nba_data(
        query: Optional[str] = None,
        column: Optional[str] = None,
        value: Optional[str] = None,
        limit: int = 100
    ) -> str:
        """Search and filter NBA data CSV file."""
        try:
            df = pd.read_csv(data_path)

            # Apply filters if provided
            if column and value:
                if column in df.columns:
                    df = df[df[column].astype(str).str.contains(str(value), case=False, na=False)]
                else:
                    return f"Column '{column}' not found. Available columns: {', '.join(df.columns.tolist())}"

            if query:
                # Search across all string columns.
                # FIX: build the mask on df's *current* index. The previous
                # pd.Series([False] * len(df)) carried a fresh RangeIndex,
                # which no longer aligns with df after the column filter
                # above, corrupting the |= alignment and breaking df[mask].
                mask = pd.Series(False, index=df.index)
                for col in df.columns:
                    if df[col].dtype == 'object':
                        mask |= df[col].astype(str).str.contains(query, case=False, na=False)
                df = df[mask]

            # Limit results to prevent token overflow
            limit = min(limit, 50)  # Cap at 50 rows
            df = df.head(limit)

            if len(df) == 0:
                return "No matching records found."

            # Truncate output if too large
            result_str = df.to_string()
            if len(result_str) > 2000:
                result_str = df.head(20).to_string() + f"\n\n... (showing first 20 of {len(df)} matching records)"

            return f"Found {len(df)} matching records:\n\n{result_str}"
        except Exception as e:
            return f"Error searching CSV {data_path}: {str(e)}"

    def _get_nba_data_summary() -> str:
        """Get a concise summary of the NBA data file."""
        try:
            df = pd.read_csv(data_path)

            # Calculate basic stats - keep it concise
            numeric_cols = df.select_dtypes(include=['number']).columns.tolist()

            # NOTE(review): the date column appears to be literally named
            # 'Data' in nba24-25.csv (not 'Date') - confirm against the CSV.
            summary = f"""NBA Dataset Summary:
- Total Records: {len(df):,}
- Columns: {len(df.columns)} ({', '.join(df.columns.tolist()[:10])}{'...' if len(df.columns) > 10 else ''})
- Unique Players: {df['Player'].nunique() if 'Player' in df.columns else 'N/A'}
- Unique Teams: {df['Tm'].nunique() if 'Tm' in df.columns else 'N/A'}
- Date Range: {df['Data'].min() if 'Data' in df.columns else 'N/A'} to {df['Data'].max() if 'Data' in df.columns else 'N/A'}
- Key Numeric Columns: {', '.join(numeric_cols[:10]) if numeric_cols else 'None'}

Sample (first 3 rows):
{df.head(3).to_string()}
"""
            return summary
        except Exception as e:
            return f"Error getting CSV summary for {data_path}: {str(e)}"

    # Now wrap them with @tool decorator.
    # The docstrings below are runtime-significant: CrewAI surfaces them to
    # the LLM as the tool descriptions.
    @tool("read_nba_data")
    def read_nba_data(limit: int = 10) -> str:
        """
        Read a sample of the NBA data file to understand its structure.
        Use this to see column names and data format, NOT for full analysis.

        Args:
            limit: Number of sample rows to return (default: 10, max: 50)
        """
        limit = min(limit, 50)  # Cap at 50 rows
        return _read_nba_data(limit)

    @tool("search_nba_data")
    def search_nba_data(
        query: Optional[str] = None,
        column: Optional[str] = None,
        value: Optional[str] = None,
        limit: int = 100
    ) -> str:
        """
        Search and filter NBA data CSV file. Use this to find specific players, teams, or statistics.

        Args:
            query: Optional text query to search for in any column (e.g., player name, team name)
            column: Optional column name to filter by (e.g., 'Player', 'Tm', 'PTS')
            value: Optional value to match in the specified column
            limit: Maximum number of rows to return (default: 100)
        """
        return _search_nba_data(query, column, value, limit)

    @tool("get_nba_data_summary")
    def get_nba_data_summary() -> str:
        """
        Get a comprehensive summary of the NBA data file including structure, basic statistics,
        and data quality information. Use this first to understand the dataset.
        """
        return _get_nba_data_summary()

    def _semantic_search_nba_data(query: str, n_results: int = 10) -> str:
        """
        Perform semantic search on NBA data using vector embeddings.
        This understands natural language queries and finds semantically similar records.
        """
        try:
            # Get vector database instance
            vector_db = get_vector_db(data_path)

            # Perform semantic search
            results = vector_db.search(query, n_results=n_results)

            if not results:
                return f"No results found for query: '{query}'"

            # Format results
            output = [f"Semantic search results for: '{query}'\n"]
            output.append(f"Found {len(results)} similar records:\n")
            output.append("=" * 80 + "\n")

            # Load original CSV to get full row data
            df = pd.read_csv(data_path)

            for i, result in enumerate(results, 1):
                metadata = result['metadata']
                similarity = result['similarity']
                row_index = metadata.get('row_index', -1)

                output.append(f"\nResult {i} (Similarity: {similarity:.3f}):")
                output.append(f"Document: {result['document']}\n")

                # Get full row data if available
                if row_index >= 0 and row_index < len(df):
                    row = df.iloc[row_index]
                    output.append("Full record:")
                    output.append(row.to_string())
                    output.append("\n" + "-" * 80 + "\n")

            return "\n".join(output)
        except Exception as e:
            return f"Error performing semantic search: {str(e)}"

    @tool("semantic_search_nba_data")
    def semantic_search_nba_data(query: str, n_results: int = 10) -> str:
        """
        Perform semantic search on NBA data using vector embeddings and natural language understanding.
        This tool understands the meaning of your query, not just exact text matches.

        Use this for natural language questions like:
        - "high scoring games"
        - "LeBron James best performances"
        - "games with many assists"
        - "efficient shooters"
        - "close games"

        Args:
            query: Natural language query describing what you're looking for
            n_results: Number of results to return (default: 10, max: 50)

        Examples:
            semantic_search_nba_data("LeBron James high scoring games")
            semantic_search_nba_data("games with triple doubles", n_results=5)
            semantic_search_nba_data("most efficient three point shooters")
        """
        # Limit n_results to prevent overwhelming output
        n_results = min(n_results, 50)
        return _semantic_search_nba_data(query, n_results)

    def _analyze_nba_data(pandas_code: str) -> str:
        """
        Execute pandas operations on NBA data for advanced analysis.
        This tool allows you to perform aggregations, groupby, sorting, filtering, etc.

        The pandas code should work with a DataFrame variable named 'df'.
        You can use any pandas operations like:
        - df.groupby('Player')['3P'].sum().sort_values(ascending=False).head(5)
        - df.groupby('Player').agg({'PTS': 'sum', 'AST': 'sum'}).sort_values('PTS', ascending=False)
        - df[df['3P'] > 5].groupby('Player')['3P'].sum().nlargest(5)
        """
        try:
            # Load the CSV data
            df = pd.read_csv(data_path)

            # Execute the pandas code in a safe environment
            # Create a namespace with only pandas and the dataframe
            # SECURITY NOTE(review): exec() of LLM-generated code with full
            # __builtins__ is arbitrary code execution on this host. This is
            # accepted here because the code comes from our own agent, but it
            # should be sandboxed/restricted before exposure to untrusted users.
            namespace = {
                'pd': pd,
                'df': df,
                '__builtins__': __builtins__
            }

            # Execute the code (expects a single pandas *expression*).
            exec(f"result = {pandas_code}", namespace)
            result = namespace.get('result')

            # Convert result to string representation - limit size to avoid token limits
            if isinstance(result, pd.DataFrame):
                # Limit DataFrame output to prevent token overflow
                if len(result) > 50:
                    result_str = f"{result.head(50).to_string()}\n\n... (showing first 50 of {len(result)} rows)"
                else:
                    result_str = result.to_string()
                return f"Analysis Result ({result.shape[0]} rows, {result.shape[1]} cols):\n\n{result_str}"
            elif isinstance(result, pd.Series):
                # Limit Series output
                if len(result) > 50:
                    result_str = f"{result.head(50).to_string()}\n\n... (showing first 50 of {len(result)} items)"
                else:
                    result_str = result.to_string()
                return f"Analysis Result ({len(result)} items):\n\n{result_str}"
            else:
                # For other types, limit string length
                result_str = str(result)
                if len(result_str) > 2000:
                    result_str = result_str[:2000] + "\n\n... (truncated)"
                return f"Analysis Result:\n\n{result_str}"

        except Exception as e:
            return f"Error executing pandas code: {str(e)}\n\nMake sure your code uses 'df' as the DataFrame variable and returns a result."

    @tool("analyze_nba_data")
    def analyze_nba_data(pandas_code: str) -> str:
        """
        Execute pandas operations on NBA data for advanced analysis, aggregations, and statistical queries.

        This is the PRIMARY tool for data analysis tasks like:
        - Finding top players by statistics (groupby + aggregation + sorting)
        - Calculating totals, averages, counts per player/team
        - Filtering and aggregating data
        - Statistical analysis

        IMPORTANT: Use this tool for queries that require:
        - Aggregating data (sum, mean, count, etc.)
        - Grouping by player, team, etc.
        - Finding top N results
        - Calculating totals or averages

        Args:
            pandas_code: Valid pandas code that operates on a DataFrame variable named 'df'
                        The code should return a result (DataFrame, Series, or value)

        Examples:
            # Top 5 players by total 3-pointers made
            analyze_nba_data("df.groupby('Player')['3P'].sum().sort_values(ascending=False).head(5)")

            # Top 10 players by total points
            analyze_nba_data("df.groupby('Player')['PTS'].sum().sort_values(ascending=False).head(10)")

            # Players with highest 3-point percentage (minimum 100 attempts)
            analyze_nba_data("df[df['3PA'] >= 100].groupby('Player').agg({'3P': 'sum', '3PA': 'sum'}).assign(percentage=lambda x: x['3P']/x['3PA']*100).sort_values('percentage', ascending=False).head(5)")

            # Top 5 players by assists
            analyze_nba_data("df.groupby('Player')['AST'].sum().sort_values(ascending=False).head(5)")

            # Team win rates
            analyze_nba_data("df.groupby('Tm')['Res'].apply(lambda x: (x == 'W').sum() / len(x) * 100).sort_values(ascending=False)")
        """
        return _analyze_nba_data(pandas_code)

    return [read_nba_data, search_nba_data, get_nba_data_summary, semantic_search_nba_data, analyze_nba_data]
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
vector_db.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Vector database manager for NBA data using ChromaDB and sentence-transformers.
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import chromadb
|
| 7 |
+
from chromadb.config import Settings
|
| 8 |
+
from sentence_transformers import SentenceTransformer
|
| 9 |
+
from typing import List, Dict, Optional
|
| 10 |
+
import json
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class NBAVectorDB:
    """
    Manages vector embeddings and semantic search for NBA data.

    Uses a locally-run sentence-transformers model for embeddings and a
    persistent ChromaDB collection for storage and retrieval. The CSV is
    indexed once (on first construction against an empty collection).
    """

    def __init__(self, csv_path: str, collection_name: str = "nba_data", persist_directory: str = "./chroma_db"):
        """
        Initialize the vector database, indexing the CSV on first run.

        Args:
            csv_path: Path to the NBA CSV file
            collection_name: Name of the ChromaDB collection
            persist_directory: Directory to persist the vector database
        """
        self.csv_path = csv_path
        self.collection_name = collection_name
        self.persist_directory = persist_directory

        # all-MiniLM-L6-v2: open-source, 384 dimensions, fast, runs locally
        # with no API key required.
        print("Loading embedding model...")
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        print("Embedding model loaded!")

        # Initialize ChromaDB client backed by an on-disk directory.
        os.makedirs(persist_directory, exist_ok=True)
        self.client = chromadb.PersistentClient(
            path=persist_directory,
            settings=Settings(anonymized_telemetry=False)
        )

        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            metadata={"description": "NBA 2024-25 season data"}
        )

        # Index only when the collection is empty. NOTE(review): a non-empty
        # collection is assumed to already hold this CSV's rows — it will be
        # stale if the CSV changes on disk; confirm whether re-indexing is needed.
        if self.collection.count() == 0:
            print("Vector database is empty. Indexing CSV data...")
            self._index_csv()
        else:
            print(f"Vector database loaded with {self.collection.count()} records")

    def _create_text_representation(self, row: pd.Series) -> str:
        """
        Convert a DataFrame row to a natural-language string for embedding.

        Only columns present in the row (and non-null, for numeric stats)
        are included, so partial rows still yield a usable description.

        Args:
            row: A pandas Series representing a row

        Returns:
            str: Period-joined description, e.g. "Player: X. Team: Y. ..."
        """
        parts = []

        if 'Player' in row:
            parts.append(f"Player: {row['Player']}")
        if 'Tm' in row:
            parts.append(f"Team: {row['Tm']}")
        if 'Opp' in row:
            parts.append(f"Opponent: {row['Opp']}")
        if 'Res' in row:
            # Anything other than 'W' (including NaN) is reported as a loss.
            parts.append(f"Result: {'Win' if row['Res'] == 'W' else 'Loss'}")
        if 'PTS' in row and pd.notna(row['PTS']):
            parts.append(f"Points: {row['PTS']}")
        if 'AST' in row and pd.notna(row['AST']):
            parts.append(f"Assists: {row['AST']}")
        if 'TRB' in row and pd.notna(row['TRB']):
            parts.append(f"Rebounds: {row['TRB']}")
        if 'FG%' in row and pd.notna(row['FG%']):
            parts.append(f"Field Goal Percentage: {row['FG%']:.3f}")
        if '3P%' in row and pd.notna(row['3P%']):
            parts.append(f"Three Point Percentage: {row['3P%']:.3f}")
        if 'Data' in row:
            # 'Data' is the CSV's date column name (presumably a typo for
            # 'Date' in the source data — kept as-is to match the CSV schema).
            parts.append(f"Date: {row['Data']}")

        return ". ".join(parts)

    def _index_csv(self):
        """
        Read the CSV, embed each row's text representation in batches,
        and add embeddings plus filterable metadata to ChromaDB.

        Fix vs. previous version: the module-wide `texts`/`metadatas`/`ids`
        accumulator lists were built but never used after the loop; they have
        been removed to avoid holding every record in memory twice.
        """
        print(f"Reading CSV from {self.csv_path}...")
        df = pd.read_csv(self.csv_path)

        print(f"Creating embeddings for {len(df)} records...")

        # Batch to keep encode() calls efficient without large peak memory.
        batch_size = 100
        total_batches = (len(df) + batch_size - 1) // batch_size

        for batch_idx in range(0, len(df), batch_size):
            batch_df = df.iloc[batch_idx:batch_idx + batch_size]
            batch_num = (batch_idx // batch_size) + 1

            batch_texts = []
            batch_metadatas = []
            batch_ids = []

            for idx, row in batch_df.iterrows():
                batch_texts.append(self._create_text_representation(row))

                # Flat scalar metadata only (ChromaDB rejects nested values);
                # row_index links a search hit back to the original CSV row.
                batch_metadatas.append({
                    'row_index': int(idx),
                    'player': str(row.get('Player', '')),
                    'team': str(row.get('Tm', '')),
                    'opponent': str(row.get('Opp', '')),
                    'result': str(row.get('Res', '')),
                    'points': float(row.get('PTS', 0)) if pd.notna(row.get('PTS')) else 0.0,
                    'date': str(row.get('Data', '')),
                })
                batch_ids.append(f"row_{idx}")

            # Generate embeddings for this batch.
            print(f"Processing batch {batch_num}/{total_batches} ({len(batch_texts)} records)...")
            embeddings = self.embedding_model.encode(
                batch_texts,
                show_progress_bar=False,
                convert_to_numpy=True
            ).tolist()

            self.collection.add(
                embeddings=embeddings,
                documents=batch_texts,
                metadatas=batch_metadatas,
                ids=batch_ids
            )

        print(f"Successfully indexed {len(df)} records in vector database!")

    def search(self, query: str, n_results: int = 10) -> List[Dict]:
        """
        Perform semantic search on the NBA data.

        Args:
            query: Natural language query
            n_results: Number of results to return

        Returns:
            List of dictionaries with keys 'id', 'document', 'metadata',
            'distance', and 'similarity'.
        """
        # Embed the query with the same model used for indexing.
        query_embedding = self.embedding_model.encode(
            query,
            convert_to_numpy=True
        ).tolist()

        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            include=['documents', 'metadatas', 'distances']
        )

        # Flatten ChromaDB's per-query nested lists into flat dicts.
        formatted_results = []
        if results['ids'] and len(results['ids'][0]) > 0:
            for i in range(len(results['ids'][0])):
                formatted_results.append({
                    'id': results['ids'][0][i],
                    'document': results['documents'][0][i],
                    'metadata': results['metadatas'][0][i],
                    'distance': results['distances'][0][i],
                    # NOTE(review): 1 - distance is a true similarity only for
                    # cosine/normalized spaces; ChromaDB's default space is L2,
                    # where distances can exceed 1 (yielding negative values).
                    # Confirm the collection's distance metric.
                    'similarity': 1 - results['distances'][0][i]
                })

        return formatted_results

    def get_original_row(self, row_index: int) -> Optional[pd.Series]:
        """
        Retrieve the original CSV row by index.

        Re-reads the CSV on every call so the returned row always reflects
        the file currently on disk.

        Args:
            row_index: Index of the row in the original CSV

        Returns:
            pandas Series, or None if the index is out of range or the
            read fails (the error is printed, not raised).
        """
        try:
            df = pd.read_csv(self.csv_path)
            if 0 <= row_index < len(df):
                return df.iloc[row_index]
        except Exception as e:
            print(f"Error retrieving row {row_index}: {e}")
        return None
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
# Global instance (will be initialized when needed)
|
| 216 |
+
# Process-wide cached instance; created lazily on first request.
_vector_db_instance: Optional[NBAVectorDB] = None


def get_vector_db(csv_path: str) -> NBAVectorDB:
    """
    Return the shared NBAVectorDB singleton for the given CSV.

    A new instance is built on the first call, and whenever the requested
    ``csv_path`` differs from the cached instance's path (which replaces
    the cached instance).

    Args:
        csv_path: Path to the CSV file

    Returns:
        NBAVectorDB instance
    """
    global _vector_db_instance
    cached = _vector_db_instance
    if cached is None or cached.csv_path != csv_path:
        cached = NBAVectorDB(csv_path)
        _vector_db_instance = cached
    return cached
|
| 233 |
+
|