jay0911 committed
Commit e64fc5c · verified · Parent: 065403e

Creating a custom wrapper to move inputs to the GPU
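The gist of the change: with device_map="auto", accelerate may place the model's weights on the GPU, while tokenizing a prompt string produces CPU tensors, so the encoded inputs have to be moved to model.device before generate() runs. A minimal sketch of that pattern, using the tokenizer and model objects defined in the diff below (the prompt is illustrative only):

# Tokenize on CPU, then move the tensors to whichever device
# accelerate placed the model on; mixing devices raises a RuntimeError.
inputs = tokenizer("Who won the most IPL matches?", return_tensors="pt")
inputs = inputs.to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))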

Files changed (1): app.py (+86, -65)
app.py CHANGED
@@ -3,35 +3,37 @@ import torch
 from transformers import (
     AutoTokenizer,
     AutoModelForCausalLM,
-    pipeline,
     BitsAndBytesConfig,
+    pipeline
 )
 from langchain_experimental.agents import create_pandas_dataframe_agent
 from langchain_community.llms import HuggingFacePipeline
 import gradio as gr
-import spaces  # required for ZeroGPU
+import spaces
 
-# --- Constants ---
-LLM_MODEL_ID = "HuggingFaceH4/zephyr-7b-beta"
-DATA_FILE = "IPL.csv"
-MAX_NEW_TOKENS = 256
-GPU_DURATION = 120  # seconds for @spaces.GPU
-
-# --- 1) Load & prepare DataFrame once ---
-def load_data():
+
+# --- Config ---
+LLM_MODEL_ID = "HuggingFaceH4/zephyr-7b-beta"
+DATA_FILE = "IPL.csv"
+
+
+# --- Load IPL Data ---
+def load_df():
     df = pd.read_csv(DATA_FILE, low_memory=False)
     df.columns = df.columns.str.replace(" ", "_").str.lower()
     if "date" in df.columns:
         df["date"] = pd.to_datetime(df["date"], errors="coerce")
-    if all(c in df.columns for c in ("runs_batter", "runs_extras")):
+    if {"runs_batter", "runs_extras"}.issubset(df.columns):
         df["runs_batter"] = pd.to_numeric(df["runs_batter"], errors="coerce").fillna(0)
         df["runs_extras"] = pd.to_numeric(df["runs_extras"], errors="coerce").fillna(0)
         df["total_runs_this_ball"] = df["runs_batter"] + df["runs_extras"]
     return df
 
-_df = load_data()
-
-# --- 2) Instantiate tokenizer, model, pipeline, and agent globally ---
+
+_df = load_df()
+
+
+# --- Load Quantized Model ---
 bnb_config = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_quant_type="nf4",
@@ -40,80 +42,99 @@ bnb_config = BitsAndBytesConfig(
 )
 
 tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID, trust_remote_code=True)
-if tokenizer.pad_token is None:
-    tokenizer.pad_token = tokenizer.eos_token
+tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
 
 model = AutoModelForCausalLM.from_pretrained(
     LLM_MODEL_ID,
-    quantization_config=bnb_config,
     torch_dtype=torch.float16,
+    device_map="auto",
+    quantization_config=bnb_config,
     trust_remote_code=True,
 )
-# model.to("cuda")
-
-pipe = pipeline(
-    "text-generation",
-    model=model,
-    tokenizer=tokenizer,
-    # device=0,  # <- ensure GPU inference
-    max_new_tokens=MAX_NEW_TOKENS,
-    do_sample=True,
-    temperature=0.1,
-    top_p=0.9,
-    eos_token_id=tokenizer.eos_token_id,
-    pad_token_id=tokenizer.pad_token_id,
-)
 
-hf_llm = HuggingFacePipeline(pipeline=pipe)
-# (NO hf_llm.to("cuda"); the pipeline already handles device)
 
+# --- LLM Wrapper for LangChain ---
+class MyLLMWrapper:
+    def __init__(self):
+        self.tokenizer = tokenizer
+        self.model = model
+
+    def invoke(self, input_str):
+        return self.__call__(input_str)
+
+    def __call__(self, input_str):
+        inputs = self.tokenizer(input_str, return_tensors="pt").to(self.model.device)
+        with torch.no_grad():
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=256,
+                do_sample=True,
+                temperature=0.1,
+                top_p=0.9,
+                eos_token_id=self.tokenizer.eos_token_id,
+                pad_token_id=self.tokenizer.pad_token_id
+            )
+        return self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
+
+
+llm = MyLLMWrapper()
+
+
+# --- System Prompt for the Agent ---
 system_message = """
-You are an expert cricket analyst. You have access to a pandas DataFrame named `df`
-containing ball-by-ball IPL match data. Use Python (pandas) to answer queries about IPL stats
-as efficiently as possible. Do not import extra libraries.
+You are an expert IPL cricket analyst. You have access to a pandas DataFrame named `df` that contains ball-by-ball IPL match data.
+Answer all questions using pandas logic, match stats, and accurate calculations.
 """
+
+# --- LangChain Agent ---
 agent = create_pandas_dataframe_agent(
-    hf_llm,
+    llm,
     _df,
     verbose=False,
-    max_iterations=5,
     handle_parsing_errors=True,
     agent_executor_kwargs={"system_message": system_message},
-    agent_type="openai-tools",
+    agent_type="openai-tools",  # Most compatible with Hugging Face models
     allow_dangerous_code=True,
 )
 
-# --- 3) Define inference function (GPU-enabled) ---
-@spaces.GPU(duration=GPU_DURATION)
-def run_inference(question: str) -> str:
-    torch.cuda.empty_cache()  # free up cached memory
-    result = agent.invoke({"input": question})
-    return result.get("output", "No output returned.")
-
-# --- 4) Build Gradio app ---
-def bot_response(history):
-    query = history[-1][0]
+
+# --- Inference Function ---
+@spaces.GPU(duration=120)
+def predict_answer(question):
+    torch.cuda.empty_cache()
     try:
-        answer = run_inference(query)
+        res = agent.invoke({"input": question})
+        return res.get("output", "No response generated.")
     except Exception as e:
-        answer = f"Error during inference: {e}"
-    history[-1][1] = answer
-    return history
+        return f"Error during inference: {e}"
+
 
+# --- Gradio UI ---
 with gr.Blocks() as demo:
-    gr.Markdown("# IPL Cricket Data Agent")
-    gr.Markdown("Ask me anything about the IPL dataset (e.g., top run-scorers, match outcomes, averages).")
-    chatbot = gr.Chatbot()
-    user_input = gr.Textbox(placeholder="Type your question here...")
-    clear_btn = gr.Button("Clear")
-
-    user_input.submit(
-        lambda msg, chat: (None, chat + [[msg, None]]),
-        [user_input, chatbot],
-        [user_input, chatbot],
-        queue=True
-    ).then(bot_response, chatbot, chatbot)
-
-    clear_btn.click(lambda: [], None, chatbot)
+    gr.Markdown("# 🏏 IPL Cricket Analyst")
+    gr.Markdown(
+        "Ask questions about IPL stats from the dataset. Examples:<br>"
+        "`Top 5 batsmen by total runs`<br>"
+        "`Who scored the most in 2023?`<br>"
+        "`Average runs per over in 2022?`"
+    )
+
+    chatbot = gr.Chatbot(label="Cricket Analyst")
+    msg = gr.Textbox(label="Ask your question here...")
+    clear = gr.Button("Clear")
+
+    def user_input(m, hist):
+        return "", hist + [[m, None]]
+
+    def bot_reply(hist):
+        q = hist[-1][0]
+        a = predict_answer(q)
+        hist[-1][1] = a
+        return hist
+
+    msg.submit(user_input, [msg, chatbot], [msg, chatbot], queue=True).then(
+        bot_reply, chatbot, chatbot
+    )
+    clear.click(lambda: [], None, chatbot)
 
 demo.queue(max_size=20).launch(debug=True)
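Since MyLLMWrapper.__call__ takes a prompt string and returns the decoded completion, the wrapper can be probed directly before the agent is involved. A rough usage sketch in the same module scope (both questions are illustrative only):

# String in, decoded string out; no LangChain machinery involved.
print(llm("Name three IPL teams.")[:300])

# Exercise the same agent path that predict_answer() wraps.
res = agent.invoke({"input": "Which batter scored the most total runs?"})
print(res.get("output", "No response generated."))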