jay0911 committed
Commit 6c5a9f9 · verified · 1 Parent(s): 9506be9

Update app.py

Files changed (1):
  app.py +57 -119
app.py CHANGED
@@ -14,13 +14,13 @@ DATA_FILE_PATH = "IPL.csv"
 
 # --- Global variable for DataFrame only (can be cached globally as it's simple) ---
 _df_cache = None
-_load_error_cache = None # Store error during initial data loading
+_load_error_cache = None
 
 # --- Function to load and prepare the DataFrame (will run once) ---
 def load_and_prepare_data_singleton():
     global _df_cache, _load_error_cache
     if _df_cache is not None:
-        return _df_cache # Return cached DataFrame if already loaded
+        return _df_cache
     try:
         df = pd.read_csv(DATA_FILE_PATH, low_memory=False)
         print("IPL.csv loaded successfully.")
@@ -56,131 +56,69 @@ def load_and_prepare_data_singleton():
     print(_load_error_cache)
     return None
 
-# --- NEW: Function to load LLM and create Agent (per request, decorated with @spaces.GPU) ---
-# This function is now responsible for loading the LLM and creating the agent
-# within the GPU worker process for each prediction.
-@spaces.GPU  # <--- Apply @spaces.GPU here
-def get_llm_and_agent(df):
-    if df is None:
-        raise ValueError("DataFrame not loaded, cannot create agent.")
-
-    # These checks are now within the GPU-allocated context
-    if not torch.cuda.is_available():
-        raise RuntimeError("Error: CUDA (GPU) is not available. This model requires a GPU.")
-    print(f"CUDA available: {torch.cuda.is_available()}")
-    print(f"CUDA device count: {torch.cuda.device_count()}")
-    print(f"Current CUDA device: {torch.cuda.current_device()}")
-
-    bnb_config = BitsAndBytesConfig(
-        load_in_4bit=True,
-        bnb_4bit_quant_type="nf4",
-        bnb_4bit_compute_dtype=torch.float16,
-        bnb_4bit_use_double_quant=False,
-    )
-    print(f"Loading LLM: {LLM_MODEL_ID}...")
-    llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID, trust_remote_code=True)
-    if llm_tokenizer.pad_token is None:
-        llm_tokenizer.pad_token = llm_tokenizer.eos_token
-    llm_model = AutoModelForCausalLM.from_pretrained(
-        LLM_MODEL_ID,
-        quantization_config=bnb_config,
-        torch_dtype=torch.float16,
-        device_map="auto",
-        trust_remote_code=True,
-    )
-    llm_pipeline = pipeline(
-        "text-generation",
-        model=llm_model,
-        tokenizer=llm_tokenizer,
-        max_new_tokens=1000,
-        do_sample=True,
-        temperature=0.1,
-        top_p=0.9,
-        eos_token_id=llm_tokenizer.eos_token_id,
-        pad_token_id=llm_tokenizer.pad_token_id,
-    )
-    llm = HuggingFacePipeline(pipeline=llm_pipeline)
-    print("LLM loaded and configured.")
-
-    system_message_content = """
-    You are an expert cricket analyst. You have access to a pandas DataFrame named `df` containing ball-by-ball IPL match data.
-    The DataFrame has the following relevant columns for querying:
-    - 'match_id': Unique ID for each match.
-    - 'date': Date of the match (datetime object).
-    - 'match_type': Type of match (e.g., T20).
-    - 'event_name': Name of the event (e.g., Indian Premier League).
-    - 'innings': The innings number (1 or 2).
-    - 'batting_team': The team currently batting.
-    - 'bowling_team': The team currently bowling.
-    - 'over', 'ball', 'ball_no': Details about the specific ball.
-    - 'batter': The batsman on strike.
-    - 'bat_pos': Batting position.
-    - 'runs_batter': Runs scored by the batsman on that ball (off the bat).
-    - 'balls_faced': Balls faced by the batter up to that point in the innings.
-    - 'bowler': The bowler who bowled that ball.
-    - 'valid_ball': Whether the ball was a valid delivery.
-    - 'runs_extras': Runs scored as extras (wides, no-balls, byes, leg-byes, penalty).
-    - 'runs_total': Total runs scored on that ball (runs_batter + runs_extras).
-    - 'runs_bowler': Runs conceded by the bowler on that ball.
-    - 'extra_type': Type of extra (e.g., 'wides', 'noball').
-    - 'non_striker': The non-striker batsman.
-    - 'wicket_kind': Type of dismissal (e.g., 'bowled', 'caught').
-    - 'player_out': The player dismissed.
-    - 'fielders': Fielders involved in the dismissal.
-    - 'player_of_match': Player of the match.
-    - 'match_won_by': The team that won the match.
-    - 'win_outcome': How the match was won (e.g., 'runs', 'wickets').
-    - 'toss_winner': The team that won the toss.
-    - 'toss_decision': What the toss winner decided to do (bat or bowl).
-    - 'venue': Match venue.
-    - 'city': City where the match was played.
-    - 'year', 'season': Year and IPL season.
-    - 'gender', 'team_type', 'superover_winner', 'result_type', 'method': Other match details.
-    - 'team_runs', 'team_balls', 'team_wicket': Team's total runs, balls, wickets.
-    - 'new_batter', 'batter_runs', 'batter_balls', 'bowler_wicket': Aggregated stats.
-    - 'batting_partners', 'next_batter', 'striker_out': More granular details.
-    - 'total_runs_this_ball': (NEW COLUMN YOU ADDED) Sum of 'runs_batter' and 'runs_extras' for that specific ball.
-    Your goal is to answer user questions about IPL cricket statistics by writing and executing pandas code on the `df` DataFrame.
-    When performing calculations, be precise. For averages, ensure you handle division by zero (e.g., by checking if the denominator is zero or using `df.sum() / df.count()` for means).
-    If the answer is a numerical value, just output the number. If it's a specific player or team name, output just the name.
-    If you cannot find the answer in the DataFrame, state that you don't know or that the information is not available.
-    Avoid providing general cricket knowledge not derivable from the DataFrame.
-    Focus solely on extracting information from the `df` DataFrame. When answering questions about totals or aggregations, consider all relevant rows unless a specific filter (like season or match) is provided.
-    Always try to provide a concise answer directly from the data.
-    """
-    agent = create_pandas_dataframe_agent(
-        llm,
-        df,
-        verbose=True,
-        max_iterations=10,
-        handle_parsing_errors=True,
-        agent_executor_kwargs={"system_message": system_message_content},
-        agent_type="openai-tools",
-        allow_dangerous_code=True
-    )
-    print("Pandas DataFrame Agent created.")
-    return agent
-
 # --- Gradio Interface Function - this is what the UI calls ---
+# This function is now fully self-contained and decorated with @spaces.GPU
+@spaces.GPU
 def predict_answer(question):
-    global _df_cache, _load_error_cache  # Access the globally cached DataFrame
-
-    if _load_error_cache:
-        return _load_error_cache
+    global _df_cache
     if _df_cache is None:
         return "Internal error: DataFrame not loaded. Please check logs."
 
     try:
-        # Load LLM and create agent for THIS request within the GPU-allocated context
-        # This function call will trigger the @spaces.GPU decorator.
-        current_agent = get_llm_and_agent(_df_cache)
-        response = current_agent.invoke({"input": question})
+        # Load the LLM and create the agent inside this function call
+        print("Loading LLM and creating agent for this request...")
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=False,
+        )
+        llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID, trust_remote_code=True)
+        if llm_tokenizer.pad_token is None:
+            llm_tokenizer.pad_token = llm_tokenizer.eos_token
+        llm_model = AutoModelForCausalLM.from_pretrained(
+            LLM_MODEL_ID,
+            quantization_config=bnb_config,
+            torch_dtype=torch.float16,
+            device_map="auto",
+            trust_remote_code=True,
+        )
+        llm_pipeline = pipeline(
+            "text-generation",
+            model=llm_model,
+            tokenizer=llm_tokenizer,
+            max_new_tokens=1000,
+            do_sample=True,
+            temperature=0.1,
+            top_p=0.9,
+            eos_token_id=llm_tokenizer.eos_token_id,
+            pad_token_id=llm_tokenizer.pad_token_id,
+        )
+        llm = HuggingFacePipeline(pipeline=llm_pipeline)
+
+        system_message_content = """
+        You are an expert cricket analyst. You have access to a pandas DataFrame named `df` containing ball-by-ball IPL match data.
+        ... (rest of your system message content) ...
+        """
+        agent = create_pandas_dataframe_agent(
+            llm,
+            _df_cache,  # Pass the globally cached DataFrame
+            verbose=True,
+            max_iterations=10,
+            handle_parsing_errors=True,
+            agent_executor_kwargs={"system_message": system_message_content},
+            agent_type="openai-tools",
+            allow_dangerous_code=True
+        )
+        print("Pandas DataFrame Agent created.")
+
+        response = agent.invoke({"input": question})
         return response['output']
+
     except Exception as e:
         return f"An error occurred while processing your request: {e}\nPlease try rephrasing your question or check the Space logs for more details."
 
 # --- Initial setup - These lines run ONCE when the app.py script starts ---
-# Only load the DataFrame initially. LLM and agent are loaded per request.
 print("Starting initial setup: Loading data...")
 _df_cache = load_and_prepare_data_singleton()
 print("Initial data setup complete.")
@@ -213,13 +151,13 @@ else:
 
     def bot_response(history):
         query = history[-1][0]
-        response = predict_answer(query) # This will now trigger the GPU load
+        response = predict_answer(query)
         history[-1][1] = response
         return history
 
-    msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
+    msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=True).then( # Changed queue=False to queue=True
        bot_response, chatbot, chatbot
    )
    clear.click(lambda: [], None, chatbot, queue=False)
 
-demo.queue().launch(debug=True)
+demo.queue(max_size=20).launch(debug=True)
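
The queueing changes in this last hunk complement the GPU change: slow GPU-bound events generally need to go through Gradio's queue rather than bypass it. A minimal sketch of the two knobs touched here (component names are illustrative):

import time
import gradio as gr

def slow_reply(message):
    time.sleep(5)  # stand-in for a long GPU-bound call
    return message

with gr.Blocks() as demo:
    box = gr.Textbox()
    out = gr.Textbox()
    # queue=True routes this event through the queue so it waits its
    # turn instead of running eagerly and risking a timeout
    box.submit(slow_reply, box, out, queue=True)

# max_size caps how many requests may wait in line; beyond that, new
# arrivals are rejected immediately instead of queueing indefinitely.
demo.queue(max_size=20).launch()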
 
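
For reference, a self-contained sketch of the create_pandas_dataframe_agent call that the new predict_answer makes per request (FakeListLLM stands in for the quantized model so the sketch runs offline; the tiny DataFrame is illustrative):

import pandas as pd
from langchain_community.llms.fake import FakeListLLM
from langchain_experimental.agents import create_pandas_dataframe_agent

df = pd.DataFrame({"batter": ["V Kohli", "MS Dhoni"], "runs_batter": [54, 32]})

# Canned ReAct-style output so the default agent type can finish offline.
llm = FakeListLLM(responses=["Final Answer: V Kohli"])

agent = create_pandas_dataframe_agent(
    llm,
    df,
    verbose=True,
    handle_parsing_errors=True,
    allow_dangerous_code=True,  # opt-in: the agent executes generated Python
)
print(agent.invoke({"input": "Who scored the most runs?"})["output"])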