Update app.py
Browse files
app.py
CHANGED
|
@@ -1,22 +1,28 @@
|
|
| 1 |
import pandas as pd
|
| 2 |
import torch
|
| 3 |
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
|
| 4 |
-
# from langchain.agents import create_pandas_dataframe_agent
|
| 5 |
from langchain_experimental.agents import create_pandas_dataframe_agent
|
| 6 |
from langchain_community.llms import HuggingFacePipeline
|
| 7 |
from langchain_core.messages import SystemMessage
|
| 8 |
import gradio as gr
|
| 9 |
-
import os
|
| 10 |
|
| 11 |
# --- Configuration ---
|
| 12 |
-
# You might want to make these environment variables in a real deployment
|
| 13 |
-
# but for a basic Space, hardcoding is fine for small models.
|
| 14 |
LLM_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"
|
| 15 |
DATA_FILE_PATH = "IPL.csv"
|
| 16 |
|
| 17 |
-
# ---
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
try:
|
| 21 |
df = pd.read_csv(DATA_FILE_PATH)
|
| 22 |
print("IPL.csv loaded successfully.")
|
|
@@ -27,19 +33,25 @@ def load_and_prepare_data():
|
|
| 27 |
df['date'] = pd.to_datetime(df['date'], errors='coerce')
|
| 28 |
df['total_runs_this_ball'] = df['runs_off_bat'] + df['extras_run']
|
| 29 |
print("DataFrame prepared.")
|
| 30 |
-
|
|
|
|
| 31 |
except FileNotFoundError:
|
|
|
|
|
|
|
| 32 |
return None
|
| 33 |
except Exception as e:
|
| 34 |
-
|
|
|
|
| 35 |
return None
|
| 36 |
|
| 37 |
-
# ---
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
| 41 |
if df is None:
|
| 42 |
-
return None,
|
| 43 |
|
| 44 |
bnb_config = BitsAndBytesConfig(
|
| 45 |
load_in_4bit=True,
|
|
@@ -76,10 +88,9 @@ def load_llm_and_agent(df):
|
|
| 76 |
llm = HuggingFacePipeline(pipeline=llm_pipeline)
|
| 77 |
print("LLM loaded and configured.")
|
| 78 |
|
| 79 |
-
|
| 80 |
-
system_message_content = f"""
|
| 81 |
You are an expert cricket analyst. You have access to a pandas DataFrame named `df` containing ball-by-ball IPL match data.
|
| 82 |
-
The DataFrame has columns
|
| 83 |
Your goal is to answer user questions about IPL cricket statistics by writing and executing pandas code.
|
| 84 |
When performing calculations, be precise. For averages, ensure you handle division by zero.
|
| 85 |
If the answer is a numerical value, just output the number. If it's a specific player or team name, output just the name.
|
|
@@ -91,73 +102,82 @@ def load_llm_and_agent(df):
|
|
| 91 |
agent = create_pandas_dataframe_agent(
|
| 92 |
llm,
|
| 93 |
df,
|
| 94 |
-
verbose=True,
|
| 95 |
agent_executor_kwargs={"handle_parsing_errors": True, "max_iterations": 10},
|
| 96 |
agent_type="openai-tools",
|
| 97 |
-
# Pass system message as part of agent creation if supported or through the prompt template
|
| 98 |
-
# Note: Depending on LangChain version and agent type, directly injecting system_message might vary.
|
| 99 |
-
# This structure is generally accepted.
|
| 100 |
-
# For agent_type="openai-tools", the system message is typically passed to the LLM directly by the agent executor.
|
| 101 |
)
|
| 102 |
print("Pandas DataFrame Agent created.")
|
| 103 |
-
|
|
|
|
| 104 |
|
| 105 |
-
# --- Gradio Interface Function ---
|
| 106 |
def predict_answer(question):
|
| 107 |
-
|
|
|
|
|
|
|
| 108 |
|
| 109 |
-
if
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
if agent_instance is None:
|
| 115 |
-
return error
|
| 116 |
|
| 117 |
try:
|
| 118 |
-
response =
|
| 119 |
return response['output']
|
| 120 |
except Exception as e:
|
|
|
|
| 121 |
return f"An error occurred while processing your request: {e}\nPlease try rephrasing your question or check the Space logs for more details."
|
| 122 |
|
| 123 |
-
# --- Initial setup
|
| 124 |
-
# This
|
| 125 |
-
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
# --- Gradio UI ---
|
| 129 |
-
if
|
|
|
|
| 130 |
with gr.Blocks() as demo:
|
| 131 |
-
gr.Markdown("# IPL Cricket Data Agent (Error)")
|
| 132 |
-
gr.Markdown(f"###
|
| 133 |
-
gr.Markdown("
|
|
|
|
| 134 |
else:
|
| 135 |
with gr.Blocks() as demo:
|
| 136 |
gr.Markdown("# IPL Cricket Data Agent")
|
| 137 |
gr.Markdown(
|
| 138 |
"Ask me anything about the IPL dataset! "
|
| 139 |
-
"For example: '
|
|
|
|
| 140 |
"'List the top 5 batsmen by total runs scored across all seasons.', "
|
| 141 |
-
"'
|
|
|
|
| 142 |
)
|
| 143 |
|
| 144 |
chatbot = gr.Chatbot(label="Cricket Analyst")
|
| 145 |
-
msg = gr.Textbox(label="Your Question")
|
| 146 |
clear = gr.Button("Clear")
|
| 147 |
|
| 148 |
def user_message(user_message, history):
|
|
|
|
| 149 |
history = history + [[user_message, None]]
|
| 150 |
return "", history
|
| 151 |
|
| 152 |
def bot_response(history):
|
|
|
|
| 153 |
query = history[-1][0]
|
| 154 |
response = predict_answer(query)
|
| 155 |
-
history[-1][1] = response
|
| 156 |
return history
|
| 157 |
|
|
|
|
| 158 |
msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
|
| 159 |
bot_response, chatbot, chatbot
|
| 160 |
)
|
| 161 |
-
clear.click(lambda:
|
| 162 |
|
|
|
|
| 163 |
demo.queue().launch(debug=True)
|
|
|
|
| 1 |
import pandas as pd
|
| 2 |
import torch
|
| 3 |
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
|
|
|
|
| 4 |
from langchain_experimental.agents import create_pandas_dataframe_agent
|
| 5 |
from langchain_community.llms import HuggingFacePipeline
|
| 6 |
from langchain_core.messages import SystemMessage
|
| 7 |
import gradio as gr
|
| 8 |
+
import os
|
| 9 |
|
| 10 |
# --- Configuration ---
# Model id and data path; consider moving these to environment variables
# for a real deployment (hardcoding is fine for a small Space).
LLM_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"
DATA_FILE_PATH = "IPL.csv"

# --- Global variables for manual caching ---
# Populated exactly once by the startup code at the bottom of this file;
# predict_answer() only reads them.
_df_cache = None          # cached pandas DataFrame loaded from DATA_FILE_PATH
_agent_cache = None       # cached LangChain pandas-DataFrame agent
_load_error_cache = None  # human-readable error message from initial load, if any
|
| 19 |
+
|
| 20 |
+
# --- Function to load and prepare the DataFrame (will run once) ---
|
| 21 |
+
def load_and_prepare_data_singleton():
|
| 22 |
+
global _df_cache, _load_error_cache
|
| 23 |
+
if _df_cache is not None:
|
| 24 |
+
return _df_cache # Return cached DataFrame if already loaded
|
| 25 |
+
|
| 26 |
try:
|
| 27 |
df = pd.read_csv(DATA_FILE_PATH)
|
| 28 |
print("IPL.csv loaded successfully.")
|
|
|
|
| 33 |
df['date'] = pd.to_datetime(df['date'], errors='coerce')
|
| 34 |
df['total_runs_this_ball'] = df['runs_off_bat'] + df['extras_run']
|
| 35 |
print("DataFrame prepared.")
|
| 36 |
+
_df_cache = df # Cache the loaded DataFrame for future use
|
| 37 |
+
return _df_cache
|
| 38 |
except FileNotFoundError:
|
| 39 |
+
_load_error_cache = "Error: IPL.csv not found. Make sure it's in the Space."
|
| 40 |
+
print(_load_error_cache)
|
| 41 |
return None
|
| 42 |
except Exception as e:
|
| 43 |
+
_load_error_cache = f"Error loading or preparing data: {e}"
|
| 44 |
+
print(_load_error_cache)
|
| 45 |
return None
|
| 46 |
|
| 47 |
+
# --- Function to load LLM and create Agent (will run once) ---
|
| 48 |
+
def load_llm_and_agent_singleton(df):
|
| 49 |
+
global _agent_cache, _load_error_cache
|
| 50 |
+
if _agent_cache is not None:
|
| 51 |
+
return _agent_cache, None # Return cached agent if already loaded
|
| 52 |
+
|
| 53 |
if df is None:
|
| 54 |
+
return None, _load_error_cache # Propagate error if DataFrame failed to load
|
| 55 |
|
| 56 |
bnb_config = BitsAndBytesConfig(
|
| 57 |
load_in_4bit=True,
|
|
|
|
| 88 |
llm = HuggingFacePipeline(pipeline=llm_pipeline)
|
| 89 |
print("LLM loaded and configured.")
|
| 90 |
|
| 91 |
+
system_message_content = """
|
|
|
|
| 92 |
You are an expert cricket analyst. You have access to a pandas DataFrame named `df` containing ball-by-ball IPL match data.
|
| 93 |
+
The DataFrame has columns like 'id', 'inning', 'overs', 'ballnumber', 'batsman', 'non_striker', 'bowler', 'runs_off_bat', 'extras_run', 'total_runs_this_ball', 'iswicketdelivery', 'player_out', 'kind', 'fielders_involved', 'bowlingteam', 'battingteam', 'striker', 'nonstriker', 'extra_type', 'byes_run', 'legbyes_run', 'noball_run', 'penalty_run', 'out_type', 'matchid', 'team1', 'team2', 'venue', 'date', 'winningteam', 'player_of_match', 'season'.
|
| 94 |
Your goal is to answer user questions about IPL cricket statistics by writing and executing pandas code.
|
| 95 |
When performing calculations, be precise. For averages, ensure you handle division by zero.
|
| 96 |
If the answer is a numerical value, just output the number. If it's a specific player or team name, output just the name.
|
|
|
|
| 102 |
agent = create_pandas_dataframe_agent(
|
| 103 |
llm,
|
| 104 |
df,
|
| 105 |
+
verbose=True,
|
| 106 |
agent_executor_kwargs={"handle_parsing_errors": True, "max_iterations": 10},
|
| 107 |
agent_type="openai-tools",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
)
|
| 109 |
print("Pandas DataFrame Agent created.")
|
| 110 |
+
_agent_cache = agent # Cache the agent instance for future use
|
| 111 |
+
return _agent_cache, None
|
| 112 |
|
| 113 |
+
# --- Gradio Interface Function - this is what the UI calls ---
|
| 114 |
def predict_answer(question):
|
| 115 |
+
# This ensures loading happens only once on app startup (or first request)
|
| 116 |
+
# The global variables _df_cache, _agent_cache, _load_error_cache
|
| 117 |
+
# are populated by the code running outside this function on script startup.
|
| 118 |
|
| 119 |
+
if _load_error_cache: # If there was an error during initial setup
|
| 120 |
+
return _load_error_cache
|
| 121 |
+
|
| 122 |
+
if _agent_cache is None or _df_cache is None: # Should not happen if initial setup worked
|
| 123 |
+
return "Internal error: Model or data not loaded. Please check logs."
|
|
|
|
|
|
|
| 124 |
|
| 125 |
try:
|
| 126 |
+
response = _agent_cache.invoke({"input": question})
|
| 127 |
return response['output']
|
| 128 |
except Exception as e:
|
| 129 |
+
# Log the full traceback if possible in a production setting
|
| 130 |
return f"An error occurred while processing your request: {e}\nPlease try rephrasing your question or check the Space logs for more details."
|
| 131 |
|
| 132 |
+
# --- Initial setup - These lines run ONCE when the app.py script starts ---
# This is where the heavy work happens: reading the CSV and loading the
# 4-bit-quantized LLM, then building the pandas-DataFrame agent.
print("Starting initial setup: Loading data and model...")
_df_cache = load_and_prepare_data_singleton()
# Second element is an error message when setup failed, otherwise None;
# the Gradio UI below branches on it.
_agent_cache, _load_error_cache = load_llm_and_agent_singleton(_df_cache)
print("Initial setup complete.")
|
| 138 |
+
|
| 139 |
|
| 140 |
# --- Gradio UI ---
# Two layouts: a static error page when startup failed, otherwise the chat app.
if _load_error_cache:
    with gr.Blocks() as demo:
        gr.Markdown("# IPL Cricket Data Agent (Initialization Error)")
        gr.Markdown("### An error occurred during startup:")
        gr.Markdown(f"```{_load_error_cache}```")
        gr.Markdown("Please check the Space logs for more details and ensure `IPL.csv` is correctly uploaded.")
else:
    with gr.Blocks() as demo:
        gr.Markdown("# IPL Cricket Data Agent")
        gr.Markdown(
            "Ask me anything about the IPL dataset! "
            "For example: 'How many matches are in the dataset?', "
            "'Who won the match between MI and CSK in 2023 on 2023-05-18?', "
            "'List the top 5 batsmen by total runs scored across all seasons.', "
            "'Which bowler has taken the most wickets in the 2024 season?', "
            "'What is the average number of runs scored per over in the 2023 season?'"
        )

        chatbot = gr.Chatbot(label="Cricket Analyst")
        msg = gr.Textbox(label="Your Question", placeholder="Type your question here...")
        clear = gr.Button("Clear")

        def queue_question(text, history):
            # Echo the question into the chat right away; the answer is
            # filled in by the chained callback below.
            return "", history + [[text, None]]

        def answer_last(history):
            # Compute and attach the bot reply for the newest chat pair.
            history[-1][1] = predict_answer(history[-1][0])
            return history

        # Submit appends the question without queueing (fast), then the
        # chained .then() runs the (potentially slow) agent call.
        msg.submit(queue_question, [msg, chatbot], [msg, chatbot], queue=False).then(
            answer_last, chatbot, chatbot
        )
        # Reset the chat history to an empty list.
        clear.click(lambda: [], None, chatbot, queue=False)

# Launch the Gradio app
demo.queue().launch(debug=True)
|