Update app.py
Browse files
app.py
CHANGED
|
@@ -14,13 +14,13 @@ DATA_FILE_PATH = "IPL.csv"
|
|
| 14 |
|
| 15 |
# --- Global variable for DataFrame only (can be cached globally as it's simple) ---
|
| 16 |
_df_cache = None
|
| 17 |
-
_load_error_cache = None
|
| 18 |
|
| 19 |
# --- Function to load and prepare the DataFrame (will run once) ---
|
| 20 |
def load_and_prepare_data_singleton():
|
| 21 |
global _df_cache, _load_error_cache
|
| 22 |
if _df_cache is not None:
|
| 23 |
-
return _df_cache
|
| 24 |
try:
|
| 25 |
df = pd.read_csv(DATA_FILE_PATH, low_memory=False)
|
| 26 |
print("IPL.csv loaded successfully.")
|
|
@@ -56,131 +56,69 @@ def load_and_prepare_data_singleton():
|
|
| 56 |
print(_load_error_cache)
|
| 57 |
return None
|
| 58 |
|
| 59 |
-
# --- Function to load the LLM and create the Agent (one call per request) ---
# @spaces.GPU allocates a GPU only for the duration of each call, so the model
# has to be loaded inside this function (in the GPU worker process) rather than
# at module import time.
@spaces.GPU
def get_llm_and_agent(df):
    """Load the 4-bit-quantized LLM and wrap it in a pandas-DataFrame agent.

    Args:
        df: The prepared IPL ball-by-ball DataFrame the agent will query.

    Returns:
        A LangChain pandas DataFrame agent bound to ``df``.

    Raises:
        ValueError: if ``df`` is None (the data was never loaded).
        RuntimeError: if no CUDA device is visible in this worker.
    """
    if df is None:
        raise ValueError("DataFrame not loaded, cannot create agent.")

    # These checks run within the GPU-allocated context.
    if not torch.cuda.is_available():
        raise RuntimeError("Error: CUDA (GPU) is not available. This model requires a GPU.")
    print(f"CUDA available: {torch.cuda.is_available()}")
    print(f"CUDA device count: {torch.cuda.device_count()}")
    print(f"Current CUDA device: {torch.cuda.current_device()}")

    # NF4 4-bit quantization keeps the model within the Space's GPU memory.
    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=False,
    )

    print(f"Loading LLM: {LLM_MODEL_ID}...")
    tok = AutoTokenizer.from_pretrained(LLM_MODEL_ID, trust_remote_code=True)
    if tok.pad_token is None:
        # Some causal-LM tokenizers ship without a pad token; reuse EOS.
        tok.pad_token = tok.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        LLM_MODEL_ID,
        quantization_config=quant_cfg,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )

    text_gen = pipeline(
        "text-generation",
        model=model,
        tokenizer=tok,
        max_new_tokens=1000,
        do_sample=True,
        temperature=0.1,
        top_p=0.9,
        eos_token_id=tok.eos_token_id,
        pad_token_id=tok.pad_token_id,
    )
    lc_llm = HuggingFacePipeline(pipeline=text_gen)
    print("LLM loaded and configured.")

    # System prompt describing the DataFrame schema the agent may query.
    schema_prompt = """
You are an expert cricket analyst. You have access to a pandas DataFrame named `df` containing ball-by-ball IPL match data.
The DataFrame has the following relevant columns for querying:
- 'match_id': Unique ID for each match.
- 'date': Date of the match (datetime object).
- 'match_type': Type of match (e.g., T20).
- 'event_name': Name of the event (e.g., Indian Premier League).
- 'innings': The innings number (1 or 2).
- 'batting_team': The team currently batting.
- 'bowling_team': The team currently bowling.
- 'over', 'ball', 'ball_no': Details about the specific ball.
- 'batter': The batsman on strike.
- 'bat_pos': Batting position.
- 'runs_batter': Runs scored by the batsman on that ball (off the bat).
- 'balls_faced': Balls faced by the batter up to that point in the innings.
- 'bowler': The bowler who bowled that ball.
- 'valid_ball': Whether the ball was a valid delivery.
- 'runs_extras': Runs scored as extras (wides, no-balls, byes, leg-byes, penalty).
- 'runs_total': Total runs scored on that ball (runs_batter + runs_extras).
- 'runs_bowler': Runs conceded by the bowler on that ball.
- 'extra_type': Type of extra (e.g., 'wides', 'noball').
- 'non_striker': The non_striker batsman.
- 'wicket_kind': Type of dismissal (e.g., 'bowled', 'caught').
- 'player_out': The player dismissed.
- 'fielders': Fielders involved in the dismissal.
- 'player_of_match': Player of the match.
- 'match_won_by': The team that won the match.
- 'win_outcome': How the match was won (e.g., 'runs', 'wickets').
- 'toss_winner': The team that won the toss.
- 'toss_decision': What the toss winner decided to do (bat or bowl).
- 'venue': Match venue.
- 'city': City where the match was played.
- 'year', 'season': Year and IPL season.
- 'gender', 'team_type', 'superover_winner', 'result_type', 'method': Other match details.
- 'team_runs', 'team_balls', 'team_wicket': Team's total runs, balls, wickets.
- 'new_batter', 'batter_runs', 'batter_balls', 'bowler_wicket': Aggregated stats.
- 'batting_partners', 'next_batter', 'striker_out': More granular details.
- 'total_runs_this_ball': (NEW COLUMN YOU ADDED) Sum of 'runs_batter' and 'runs_extras' for that specific ball.
Your goal is to answer user questions about IPL cricket statistics by writing and executing pandas code on the `df` DataFrame.
When performing calculations, be precise. For averages, ensure you handle division by zero (e.g., by checking if denominator is zero or using `df.sum() / df.count()` for means).
If the answer is a numerical value, just output the number. If it's a specific player or team name, output just the name.
If you cannot find the answer in the DataFrame, state that you don't know or that the information is not available.
Avoid providing general cricket knowledge not derivable from the DataFrame.
Focus solely on extracting information from the 'df' DataFrame. When answering questions about totals or aggregations, consider all relevant rows unless a specific filter (like season or match) is provided.
Always try to provide a concise answer directly from the data.
"""

    # allow_dangerous_code is required: the agent executes generated pandas code.
    agent = create_pandas_dataframe_agent(
        lc_llm,
        df,
        verbose=True,
        max_iterations=10,
        handle_parsing_errors=True,
        agent_executor_kwargs={"system_message": schema_prompt},
        agent_type="openai-tools",
        allow_dangerous_code=True,
    )
    print("Pandas DataFrame Agent created.")
    return agent
|
| 163 |
-
|
| 164 |
# --- Gradio Interface Function - this is what the UI calls ---
|
|
|
|
|
|
|
| 165 |
def predict_answer(question):
|
| 166 |
-
global _df_cache
|
| 167 |
-
|
| 168 |
-
if _load_error_cache:
|
| 169 |
-
return _load_error_cache
|
| 170 |
if _df_cache is None:
|
| 171 |
return "Internal error: DataFrame not loaded. Please check logs."
|
| 172 |
|
| 173 |
try:
|
| 174 |
-
# Load LLM and create agent
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
return response['output']
|
|
|
|
| 179 |
except Exception as e:
|
| 180 |
return f"An error occurred while processing your request: {e}\nPlease try rephrasing your question or check the Space logs for more details."
|
| 181 |
|
| 182 |
# --- Initial setup - These lines run ONCE when the app.py script starts ---
|
| 183 |
-
# Only load the DataFrame initially. LLM and agent are loaded per request.
|
| 184 |
print("Starting initial setup: Loading data...")
|
| 185 |
_df_cache = load_and_prepare_data_singleton()
|
| 186 |
print("Initial data setup complete.")
|
|
@@ -213,13 +151,13 @@ else:
|
|
| 213 |
|
| 214 |
def bot_response(history):
    """Fill in the assistant reply for the newest chat turn, in place.

    Expects ``history`` as a list of mutable ``[user, assistant]`` pairs
    (Gradio tuples-style chat history); returns the same list object.
    """
    latest_turn = history[-1]
    latest_turn[1] = predict_answer(latest_turn[0])
    return history
|
| 219 |
|
| 220 |
-
msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=
|
| 221 |
bot_response, chatbot, chatbot
|
| 222 |
)
|
| 223 |
clear.click(lambda: [], None, chatbot, queue=False)
|
| 224 |
|
| 225 |
-
demo.queue().launch(debug=True)
|
|
|
|
| 14 |
|
| 15 |
# --- Global variable for DataFrame only (can be cached globally as it's simple) ---
|
| 16 |
_df_cache = None
|
| 17 |
+
_load_error_cache = None
|
| 18 |
|
| 19 |
# --- Function to load and prepare the DataFrame (will run once) ---
|
| 20 |
def load_and_prepare_data_singleton():
|
| 21 |
global _df_cache, _load_error_cache
|
| 22 |
if _df_cache is not None:
|
| 23 |
+
return _df_cache
|
| 24 |
try:
|
| 25 |
df = pd.read_csv(DATA_FILE_PATH, low_memory=False)
|
| 26 |
print("IPL.csv loaded successfully.")
|
|
|
|
| 56 |
print(_load_error_cache)
|
| 57 |
return None
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
# --- Gradio Interface Function - this is what the UI calls ---
|
| 60 |
+
# --- Gradio Interface Function - this is what the UI calls ---
# Fully self-contained: the LLM is loaded and the agent built inside this
# function so everything runs within the @spaces.GPU worker context.
@spaces.GPU
def predict_answer(question):
    """Answer one IPL-statistics question using a pandas DataFrame agent.

    Args:
        question: The user's natural-language question.

    Returns:
        The agent's answer string, or a human-readable error message
        (data-load failures and runtime exceptions are reported, not raised,
        so the Gradio UI always gets a displayable string).
    """
    # Surface the real data-loading error if the CSV failed to load
    # (this check existed in the previous revision and was dropped by
    # mistake, leaving only the generic message below).
    if _load_error_cache:
        return _load_error_cache
    if _df_cache is None:
        return "Internal error: DataFrame not loaded. Please check logs."

    try:
        # Load the LLM and create the agent inside this request.
        # NOTE(review): this reloads the model on every call — acceptable for
        # per-request GPU allocation on Spaces, but a worker-local cache would
        # cut latency if the worker process is reused; confirm worker lifetime.
        print("Loading LLM and creating agent for this request...")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=False,
        )
        llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID, trust_remote_code=True)
        if llm_tokenizer.pad_token is None:
            # Some causal-LM tokenizers ship without a pad token; reuse EOS.
            llm_tokenizer.pad_token = llm_tokenizer.eos_token
        llm_model = AutoModelForCausalLM.from_pretrained(
            LLM_MODEL_ID,
            quantization_config=bnb_config,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
        llm_pipeline = pipeline(
            "text-generation",
            model=llm_model,
            tokenizer=llm_tokenizer,
            max_new_tokens=1000,
            do_sample=True,
            temperature=0.1,
            top_p=0.9,
            eos_token_id=llm_tokenizer.eos_token_id,
            pad_token_id=llm_tokenizer.pad_token_id,
        )
        llm = HuggingFacePipeline(pipeline=llm_pipeline)

        # Full schema prompt restored: the previous commit left a literal
        # "... (rest of your system message content) ..." placeholder here,
        # so the agent never saw the column descriptions.
        system_message_content = """
You are an expert cricket analyst. You have access to a pandas DataFrame named `df` containing ball-by-ball IPL match data.
The DataFrame has the following relevant columns for querying:
- 'match_id': Unique ID for each match.
- 'date': Date of the match (datetime object).
- 'match_type': Type of match (e.g., T20).
- 'event_name': Name of the event (e.g., Indian Premier League).
- 'innings': The innings number (1 or 2).
- 'batting_team': The team currently batting.
- 'bowling_team': The team currently bowling.
- 'over', 'ball', 'ball_no': Details about the specific ball.
- 'batter': The batsman on strike.
- 'bat_pos': Batting position.
- 'runs_batter': Runs scored by the batsman on that ball (off the bat).
- 'balls_faced': Balls faced by the batter up to that point in the innings.
- 'bowler': The bowler who bowled that ball.
- 'valid_ball': Whether the ball was a valid delivery.
- 'runs_extras': Runs scored as extras (wides, no-balls, byes, leg-byes, penalty).
- 'runs_total': Total runs scored on that ball (runs_batter + runs_extras).
- 'runs_bowler': Runs conceded by the bowler on that ball.
- 'extra_type': Type of extra (e.g., 'wides', 'noball').
- 'non_striker': The non_striker batsman.
- 'wicket_kind': Type of dismissal (e.g., 'bowled', 'caught').
- 'player_out': The player dismissed.
- 'fielders': Fielders involved in the dismissal.
- 'player_of_match': Player of the match.
- 'match_won_by': The team that won the match.
- 'win_outcome': How the match was won (e.g., 'runs', 'wickets').
- 'toss_winner': The team that won the toss.
- 'toss_decision': What the toss winner decided to do (bat or bowl).
- 'venue': Match venue.
- 'city': City where the match was played.
- 'year', 'season': Year and IPL season.
- 'gender', 'team_type', 'superover_winner', 'result_type', 'method': Other match details.
- 'team_runs', 'team_balls', 'team_wicket': Team's total runs, balls, wickets.
- 'new_batter', 'batter_runs', 'batter_balls', 'bowler_wicket': Aggregated stats.
- 'batting_partners', 'next_batter', 'striker_out': More granular details.
- 'total_runs_this_ball': (NEW COLUMN YOU ADDED) Sum of 'runs_batter' and 'runs_extras' for that specific ball.
Your goal is to answer user questions about IPL cricket statistics by writing and executing pandas code on the `df` DataFrame.
When performing calculations, be precise. For averages, ensure you handle division by zero (e.g., by checking if denominator is zero or using `df.sum() / df.count()` for means).
If the answer is a numerical value, just output the number. If it's a specific player or team name, output just the name.
If you cannot find the answer in the DataFrame, state that you don't know or that the information is not available.
Avoid providing general cricket knowledge not derivable from the DataFrame.
Focus solely on extracting information from the 'df' DataFrame. When answering questions about totals or aggregations, consider all relevant rows unless a specific filter (like season or match) is provided.
Always try to provide a concise answer directly from the data.
"""
        # allow_dangerous_code is required: the agent executes generated pandas code.
        agent = create_pandas_dataframe_agent(
            llm,
            _df_cache,  # Pass the globally cached DataFrame
            verbose=True,
            max_iterations=10,
            handle_parsing_errors=True,
            agent_executor_kwargs={"system_message": system_message_content},
            agent_type="openai-tools",
            allow_dangerous_code=True
        )
        print("Pandas DataFrame Agent created.")

        response = agent.invoke({"input": question})
        return response['output']

    except Exception as e:
        return f"An error occurred while processing your request: {e}\nPlease try rephrasing your question or check the Space logs for more details."
|
| 120 |
|
| 121 |
# --- Initial setup - These lines run ONCE when the app.py script starts ---
|
|
|
|
| 122 |
print("Starting initial setup: Loading data...")
|
| 123 |
_df_cache = load_and_prepare_data_singleton()
|
| 124 |
print("Initial data setup complete.")
|
|
|
|
| 151 |
|
| 152 |
def bot_response(history):
    """Answer the most recent user message and write it into the chat history.

    Mutates the last ``[user, assistant]`` pair of the tuples-style history
    and returns the (same) history list for Gradio to re-render.
    """
    last = history[-1]
    last[1] = predict_answer(last[0])
    return history
|
| 157 |
|
| 158 |
+
msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=True).then( # Changed queue=False to queue=True
|
| 159 |
bot_response, chatbot, chatbot
|
| 160 |
)
|
| 161 |
clear.click(lambda: [], None, chatbot, queue=False)
|
| 162 |
|
| 163 |
+
demo.queue(max_size=20).launch(debug=True)
|