jay0911 commited on
Commit
c5f19aa
·
verified ·
1 Parent(s): 179255e

updated with langchain sql agent

Browse files
Files changed (1) hide show
  1. app.py +18 -90
app.py CHANGED
@@ -1,19 +1,15 @@
1
  import pandas as pd
2
- import torch
3
- from transformers import (
4
- AutoTokenizer,
5
- AutoModelForCausalLM,
6
- BitsAndBytesConfig,
7
- pipeline
8
- )
9
- from langchain_experimental.agents import create_pandas_dataframe_agent
10
- from langchain_community.llms import HuggingFacePipeline
11
  import gradio as gr
12
- import spaces
13
-
14
 
 
15
  # --- Config ---
16
- LLM_MODEL_ID = "HuggingFaceH4/zephyr-7b-beta"
17
  DATA_FILE = "IPL.csv"
18
 
19
 
@@ -32,86 +28,18 @@ def load_df():
32
 
33
  _df = load_df()
34
 
35
-
36
- # --- Load Quantized Model ---
37
- bnb_config = BitsAndBytesConfig(
38
- load_in_4bit=True,
39
- bnb_4bit_quant_type="nf4",
40
- bnb_4bit_compute_dtype=torch.float16,
41
- bnb_4bit_use_double_quant=False,
42
- )
43
-
44
- tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID, trust_remote_code=True)
45
- tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
46
-
47
- model = AutoModelForCausalLM.from_pretrained(
48
- LLM_MODEL_ID,
49
- torch_dtype=torch.float16,
50
- device_map="auto",
51
- quantization_config=bnb_config,
52
- trust_remote_code=True,
53
- )
54
-
55
-
56
- # --- LLM Wrapper for LangChain ---
57
- class MyLLMWrapper:
58
- def __init__(self):
59
- self.tokenizer = tokenizer
60
- self.model = model
61
-
62
- def invoke(self, input_str):
63
- return self.__call__(input_str)
64
-
65
- def __call__(self, input_str):
66
- inputs = self.tokenizer(input_str, return_tensors="pt").to(self.model.device)
67
- with torch.no_grad():
68
- outputs = self.model.generate(
69
- **inputs,
70
- max_new_tokens=256,
71
- do_sample=True,
72
- temperature=0.1,
73
- top_p=0.9,
74
- eos_token_id=self.tokenizer.eos_token_id,
75
- pad_token_id=self.tokenizer.pad_token_id
76
- )
77
- return self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
78
-
79
-
80
- llm = MyLLMWrapper()
81
-
82
-
83
- # --- System Prompt for the Agent ---
84
- system_message = """
85
- You are an expert IPL cricket analyst. You have access to a pandas DataFrame named `df` that contains ball-by-ball IPL match data.
86
- Answer all questions using pandas logic, match stats, and accurate calculations.
87
- """
88
-
89
- # --- LangChain Agent ---
90
- agent = create_pandas_dataframe_agent(
91
- llm,
92
- _df,
93
- verbose=False,
94
- max_execution_time=None,
95
- early_stopping_method="force",
96
- include_df_in_prompt=True,
97
- number_of_head_rows=5,
98
- extra_tools=(),
99
- # handle_parsing_errors=True,
100
- agent_executor_kwargs={"system_message": system_message},
101
- agent_type="openai-tools", # Most compatible with Hugging Face models
102
- allow_dangerous_code=True,
103
- )
104
 
105
 
106
- # --- Inference Function ---
107
- @spaces.GPU(duration=120)
108
- def predict_answer(question):
109
- torch.cuda.empty_cache()
110
  try:
111
- res = agent.invoke({"input": question})
112
- return res.get("output", "No response generated.")
113
- except Exception as e:
114
- return f" Error during inference: {e}"
115
 
116
 
117
  # --- Gradio UI ---
@@ -133,7 +61,7 @@ with gr.Blocks() as demo:
133
 
134
  def bot_reply(hist):
135
  q = hist[-1][0]
136
- a = predict_answer(q)
137
  hist[-1][1] = a
138
  return hist
139
 
 
1
  import pandas as pd
2
+ import os
3
+ from dotenv import load_dotenv
4
+ from langchain.chat_models import init_chat_model
5
+ from langchain_community.agent_toolkits import create_sql_agent
6
+ from langchain_community.utilities import SQLDatabase
7
+ from sqlalchemy import create_engine
 
 
 
8
  import gradio as gr
9
+ load_dotenv()
 
10
 
11
+ llm = init_chat_model("gemini-2.5-flash", model_provider="google_genai")
12
  # --- Config ---
 
13
  DATA_FILE = "IPL.csv"
14
 
15
 
 
28
 
29
  _df = load_df()
30
 
31
+ engine = create_engine("sqlite:///ipl.db")
32
+ _df.to_sql("ipl", engine, index=False)
33
+ db = SQLDatabase(engine=engine)
34
+ print("Db created successfully")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
 
37
+ def main(query):
 
 
 
38
  try:
39
+ agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)
40
+ return agent_executor.invoke({"input": query})
41
+ except:
42
+ return "Failed to fetch the required info"
43
 
44
 
45
  # --- Gradio UI ---
 
61
 
62
  def bot_reply(hist):
63
  q = hist[-1][0]
64
+ a = main(q)
65
  hist[-1][1] = a
66
  return hist
67