jay0911 committed on
Commit
67d7ab8
·
verified ·
1 Parent(s): d2ddebe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -46
app.py CHANGED
@@ -1,22 +1,28 @@
1
  import pandas as pd
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
4
- # from langchain.agents import create_pandas_dataframe_agent
5
  from langchain_experimental.agents import create_pandas_dataframe_agent
6
  from langchain_community.llms import HuggingFacePipeline
7
  from langchain_core.messages import SystemMessage
8
  import gradio as gr
9
- import os # For checking file existence
10
 
11
  # --- Configuration ---
12
- # You might want to make these environment variables in a real deployment
13
- # but for a basic Space, hardcoding is fine for small models.
14
  LLM_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"
15
  DATA_FILE_PATH = "IPL.csv"
16
 
17
- # --- 1. Load and Prepare the DataFrame ---
18
- @gr.cache
19
- def load_and_prepare_data():
 
 
 
 
 
 
 
 
 
20
  try:
21
  df = pd.read_csv(DATA_FILE_PATH)
22
  print("IPL.csv loaded successfully.")
@@ -27,19 +33,25 @@ def load_and_prepare_data():
27
  df['date'] = pd.to_datetime(df['date'], errors='coerce')
28
  df['total_runs_this_ball'] = df['runs_off_bat'] + df['extras_run']
29
  print("DataFrame prepared.")
30
- return df
 
31
  except FileNotFoundError:
 
 
32
  return None
33
  except Exception as e:
34
- print(f"Error loading or preparing data: {e}")
 
35
  return None
36
 
37
- # --- 2. Initialize the Code-Generating LLM ---
38
- # This function will be called once when the Gradio app starts
39
- @gr.cache
40
- def load_llm_and_agent(df):
 
 
41
  if df is None:
42
- return None, "Error: Data not loaded."
43
 
44
  bnb_config = BitsAndBytesConfig(
45
  load_in_4bit=True,
@@ -76,10 +88,9 @@ def load_llm_and_agent(df):
76
  llm = HuggingFacePipeline(pipeline=llm_pipeline)
77
  print("LLM loaded and configured.")
78
 
79
- # --- 3. Create the Pandas DataFrame Agent ---
80
- system_message_content = f"""
81
  You are an expert cricket analyst. You have access to a pandas DataFrame named `df` containing ball-by-ball IPL match data.
82
- The DataFrame has columns : {df.columns}.
83
  Your goal is to answer user questions about IPL cricket statistics by writing and executing pandas code.
84
  When performing calculations, be precise. For averages, ensure you handle division by zero.
85
  If the answer is a numerical value, just output the number. If it's a specific player or team name, output just the name.
@@ -91,73 +102,82 @@ def load_llm_and_agent(df):
91
  agent = create_pandas_dataframe_agent(
92
  llm,
93
  df,
94
- verbose=True, # Will print agent's thought process to the Space logs
95
  agent_executor_kwargs={"handle_parsing_errors": True, "max_iterations": 10},
96
  agent_type="openai-tools",
97
- # Pass system message as part of agent creation if supported or through the prompt template
98
- # Note: Depending on LangChain version and agent type, directly injecting system_message might vary.
99
- # This structure is generally accepted.
100
- # For agent_type="openai-tools", the system message is typically passed to the LLM directly by the agent executor.
101
  )
102
  print("Pandas DataFrame Agent created.")
103
- return agent, None
 
104
 
105
- # --- Gradio Interface Function ---
106
  def predict_answer(question):
107
- global agent_instance # Use global to avoid re-loading agent on every call
 
 
108
 
109
- if agent_instance is None:
110
- df_data = load_and_prepare_data()
111
- if df_data is None:
112
- return "Error: Could not load IPL data. Please check logs."
113
- agent_instance, error = load_llm_and_agent(df_data)
114
- if agent_instance is None:
115
- return error
116
 
117
  try:
118
- response = agent_instance.invoke({"input": question})
119
  return response['output']
120
  except Exception as e:
 
121
  return f"An error occurred while processing your request: {e}\nPlease try rephrasing your question or check the Space logs for more details."
122
 
123
- # --- Initial setup for the global agent_instance ---
124
- # This part runs when the script starts (once on Space boot-up)
125
- df_global = load_and_prepare_data()
126
- agent_instance, initial_load_error = load_llm_and_agent(df_global)
 
 
 
127
 
128
  # --- Gradio UI ---
129
- if initial_load_error:
 
130
  with gr.Blocks() as demo:
131
- gr.Markdown("# IPL Cricket Data Agent (Error)")
132
- gr.Markdown(f"### Initialization Error: {initial_load_error}")
133
- gr.Markdown("Please check the Space logs for more details.")
 
134
  else:
135
  with gr.Blocks() as demo:
136
  gr.Markdown("# IPL Cricket Data Agent")
137
  gr.Markdown(
138
  "Ask me anything about the IPL dataset! "
139
- "For example: 'Who won the match between MI and CSK in 2023 on 2023-05-18?', "
 
140
  "'List the top 5 batsmen by total runs scored across all seasons.', "
141
- "'What is the total number of no-balls bowled in the entire dataset?'"
 
142
  )
143
 
144
  chatbot = gr.Chatbot(label="Cricket Analyst")
145
- msg = gr.Textbox(label="Your Question")
146
  clear = gr.Button("Clear")
147
 
148
  def user_message(user_message, history):
 
149
  history = history + [[user_message, None]]
150
  return "", history
151
 
152
  def bot_response(history):
 
153
  query = history[-1][0]
154
  response = predict_answer(query)
155
- history[-1][1] = response
156
  return history
157
 
 
158
  msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
159
  bot_response, chatbot, chatbot
160
  )
161
- clear.click(lambda: None, None, chatbot, queue=False)
162
 
 
163
  demo.queue().launch(debug=True)
 
1
  import pandas as pd
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
 
4
  from langchain_experimental.agents import create_pandas_dataframe_agent
5
  from langchain_community.llms import HuggingFacePipeline
6
  from langchain_core.messages import SystemMessage
7
  import gradio as gr
8
+ import os
9
 
10
  # --- Configuration ---
 
 
11
  LLM_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"
12
  DATA_FILE_PATH = "IPL.csv"
13
 
14
+ # --- Global variables for manual caching ---
15
+ # These will hold our loaded DataFrame and agent instance
16
+ _df_cache = None
17
+ _agent_cache = None
18
+ _load_error_cache = None # To store any error during initial load
19
+
20
+ # --- Function to load and prepare the DataFrame (will run once) ---
21
+ def load_and_prepare_data_singleton():
22
+ global _df_cache, _load_error_cache
23
+ if _df_cache is not None:
24
+ return _df_cache # Return cached DataFrame if already loaded
25
+
26
  try:
27
  df = pd.read_csv(DATA_FILE_PATH)
28
  print("IPL.csv loaded successfully.")
 
33
  df['date'] = pd.to_datetime(df['date'], errors='coerce')
34
  df['total_runs_this_ball'] = df['runs_off_bat'] + df['extras_run']
35
  print("DataFrame prepared.")
36
+ _df_cache = df # Cache the loaded DataFrame for future use
37
+ return _df_cache
38
  except FileNotFoundError:
39
+ _load_error_cache = "Error: IPL.csv not found. Make sure it's in the Space."
40
+ print(_load_error_cache)
41
  return None
42
  except Exception as e:
43
+ _load_error_cache = f"Error loading or preparing data: {e}"
44
+ print(_load_error_cache)
45
  return None
46
 
47
+ # --- Function to load LLM and create Agent (will run once) ---
48
+ def load_llm_and_agent_singleton(df):
49
+ global _agent_cache, _load_error_cache
50
+ if _agent_cache is not None:
51
+ return _agent_cache, None # Return cached agent if already loaded
52
+
53
  if df is None:
54
+ return None, _load_error_cache # Propagate error if DataFrame failed to load
55
 
56
  bnb_config = BitsAndBytesConfig(
57
  load_in_4bit=True,
 
88
  llm = HuggingFacePipeline(pipeline=llm_pipeline)
89
  print("LLM loaded and configured.")
90
 
91
+ system_message_content = """
 
92
  You are an expert cricket analyst. You have access to a pandas DataFrame named `df` containing ball-by-ball IPL match data.
93
+ The DataFrame has columns like 'id', 'inning', 'overs', 'ballnumber', 'batsman', 'non_striker', 'bowler', 'runs_off_bat', 'extras_run', 'total_runs_this_ball', 'iswicketdelivery', 'player_out', 'kind', 'fielders_involved', 'bowlingteam', 'battingteam', 'striker', 'nonstriker', 'extra_type', 'byes_run', 'legbyes_run', 'noball_run', 'penalty_run', 'out_type', 'matchid', 'team1', 'team2', 'venue', 'date', 'winningteam', 'player_of_match', 'season'.
94
  Your goal is to answer user questions about IPL cricket statistics by writing and executing pandas code.
95
  When performing calculations, be precise. For averages, ensure you handle division by zero.
96
  If the answer is a numerical value, just output the number. If it's a specific player or team name, output just the name.
 
102
  agent = create_pandas_dataframe_agent(
103
  llm,
104
  df,
105
+ verbose=True,
106
  agent_executor_kwargs={"handle_parsing_errors": True, "max_iterations": 10},
107
  agent_type="openai-tools",
 
 
 
 
108
  )
109
  print("Pandas DataFrame Agent created.")
110
+ _agent_cache = agent # Cache the agent instance for future use
111
+ return _agent_cache, None
112
 
113
+ # --- Gradio Interface Function - this is what the UI calls ---
114
  def predict_answer(question):
115
+ # This ensures loading happens only once on app startup (or first request)
116
+ # The global variables _df_cache, _agent_cache, _load_error_cache
117
+ # are populated by the code running outside this function on script startup.
118
 
119
+ if _load_error_cache: # If there was an error during initial setup
120
+ return _load_error_cache
121
+
122
+ if _agent_cache is None or _df_cache is None: # Should not happen if initial setup worked
123
+ return "Internal error: Model or data not loaded. Please check logs."
 
 
124
 
125
  try:
126
+ response = _agent_cache.invoke({"input": question})
127
  return response['output']
128
  except Exception as e:
129
+ # Log the full traceback if possible in a production setting
130
  return f"An error occurred while processing your request: {e}\nPlease try rephrasing your question or check the Space logs for more details."
131
 
132
+ # --- Initial setup - These lines run ONCE when the app.py script starts ---
133
+ # This is where the heavy loading happens.
134
+ print("Starting initial setup: Loading data and model...")
135
+ _df_cache = load_and_prepare_data_singleton()
136
+ _agent_cache, _load_error_cache = load_llm_and_agent_singleton(_df_cache)
137
+ print("Initial setup complete.")
138
+
139
 
140
  # --- Gradio UI ---
141
+ # Check if initial loading encountered an error before launching the UI
142
+ if _load_error_cache:
143
  with gr.Blocks() as demo:
144
+ gr.Markdown("# IPL Cricket Data Agent (Initialization Error)")
145
+ gr.Markdown(f"### An error occurred during startup:")
146
+ gr.Markdown(f"```{_load_error_cache}```")
147
+ gr.Markdown("Please check the Space logs for more details and ensure `IPL.csv` is correctly uploaded.")
148
  else:
149
  with gr.Blocks() as demo:
150
  gr.Markdown("# IPL Cricket Data Agent")
151
  gr.Markdown(
152
  "Ask me anything about the IPL dataset! "
153
+ "For example: 'How many matches are in the dataset?', "
154
+ "'Who won the match between MI and CSK in 2023 on 2023-05-18?', "
155
  "'List the top 5 batsmen by total runs scored across all seasons.', "
156
+ "'Which bowler has taken the most wickets in the 2024 season?', "
157
+ "'What is the average number of runs scored per over in the 2023 season?'"
158
  )
159
 
160
  chatbot = gr.Chatbot(label="Cricket Analyst")
161
+ msg = gr.Textbox(label="Your Question", placeholder="Type your question here...")
162
  clear = gr.Button("Clear")
163
 
164
  def user_message(user_message, history):
165
+ # Append user message immediately for responsiveness
166
  history = history + [[user_message, None]]
167
  return "", history
168
 
169
  def bot_response(history):
170
+ # Get the last user message and call predict_answer
171
  query = history[-1][0]
172
  response = predict_answer(query)
173
+ history[-1][1] = response # Update the bot's response
174
  return history
175
 
176
+ # Event listeners for Gradio components
177
  msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
178
  bot_response, chatbot, chatbot
179
  )
180
+ clear.click(lambda: [], None, chatbot, queue=False) # Clear history
181
 
182
+ # Launch the Gradio app
183
  demo.queue().launch(debug=True)