|
|
import pandas as pd |
|
|
import os |
|
|
from langchain.chat_models import init_chat_model |
|
|
from langchain_community.agent_toolkits import create_sql_agent |
|
|
from langchain_community.utilities import SQLDatabase |
|
|
from sqlalchemy import create_engine |
|
|
import gradio as gr |
|
|
|
|
|
# The Gemini chat model below presumably authenticates via GOOGLE_API_KEY in
# the environment. Fail fast with an actionable message when it is missing,
# rather than letting a missing key surface as an opaque error later
# (assigning None into os.environ, for instance, raises TypeError).
if not os.getenv("GOOGLE_API_KEY"):
    raise RuntimeError(
        "GOOGLE_API_KEY is not set; export your Google Generative AI API key "
        "before starting the app."
    )

# Shared chat model used by the SQL agent.
llm = init_chat_model("gemini-2.5-flash", model_provider="google_genai")

# CSV dataset that is loaded and mirrored into SQLite at start-up.
DATA_FILE = "IPL.csv"
|
|
|
|
|
|
|
|
|
|
|
def load_df(path=None):
    """Load the IPL CSV into a DataFrame with normalized column names.

    Column names have surrounding whitespace stripped, inner spaces replaced
    by underscores, and are lower-cased so they are easy to reference from
    SQL. When present, ``date`` is parsed to datetimes (unparseable values
    become ``NaT``). When both ``runs_batter`` and ``runs_extras`` are present,
    they are coerced to numbers (unparseable values become 0) and summed into
    a derived ``total_runs_this_ball`` column.

    Args:
        path: CSV path or file-like object accepted by ``pd.read_csv``.
            Defaults to the module-level ``DATA_FILE``.

    Returns:
        The cleaned ``pandas.DataFrame``.
    """
    # Resolve the default at call time so DATA_FILE can be changed/overridden.
    if path is None:
        path = DATA_FILE
    df = pd.read_csv(path, low_memory=False)

    # Strip first so a header like " Runs Extras" becomes "runs_extras"
    # instead of "_runs_extras" (which would skip the runs logic below).
    df.columns = df.columns.str.strip().str.replace(" ", "_").str.lower()

    if "date" in df.columns:
        df["date"] = pd.to_datetime(df["date"], errors="coerce")

    if {"runs_batter", "runs_extras"}.issubset(df.columns):
        for col in ("runs_batter", "runs_extras"):
            # Non-numeric cells would otherwise poison the sum with NaN.
            df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0)
        df["total_runs_this_ball"] = df["runs_batter"] + df["runs_extras"]

    return df
|
|
|
|
|
|
|
|
# Load the CSV once at start-up and mirror it into a local SQLite file that
# the SQL agent queries.
_df = load_df()

engine = create_engine("sqlite:///ipl.db")

# if_exists="replace" rebuilds the table from the CSV on every start. With the
# default ("fail"), a second launch raises ValueError because ipl.db already
# contains the "ipl" table.
_df.to_sql("ipl", engine, index=False, if_exists="replace")

# LangChain wrapper around the engine; passed to the SQL agent in main().
db = SQLDatabase(engine=engine)


print("Db created successfully")
|
|
|
|
|
|
|
|
def main(query):
    """Answer a natural-language question about the IPL data via a SQL agent.

    Builds a LangChain SQL agent over the module-level ``db`` using ``llm``
    and returns the agent's final answer text.

    Args:
        query: The user's question.

    Returns:
        The agent's answer. Returns a prompt to enter a question when ``query``
        is empty or whitespace, and a generic failure message if building or
        running the agent raises.
    """
    import logging

    # Skip the LLM round-trip entirely for blank submissions.
    if not query or not query.strip():
        return "Please enter a question about the IPL dataset."

    try:
        agent_executor = create_sql_agent(
            llm, db=db, agent_type="openai-tools", verbose=True
        )
        response = agent_executor.invoke({"input": query})
        return response["output"]
    except Exception:
        # Catch Exception rather than using a bare except, so Ctrl-C and
        # SystemExit still propagate. Log the traceback instead of silently
        # discarding it, so failures are diagnosable.
        logging.getLogger(__name__).exception(
            "SQL agent failed for query %r", query
        )
        return "Failed to fetch the required info"
|
|
|
|
|
|
|
|
|
|
|
# Chat UI: each submitted question is echoed into the history immediately,
# then answered by main() in a follow-up event.
with gr.Blocks() as demo:
    # Title emoji restored from mojibake (UTF-8 bytes decoded as cp874).
    gr.Markdown("# 🏏 IPL Cricket Analyst")
    gr.Markdown(
        "Ask questions about IPL stats from the dataset. Examples:<br>"
        "`Top 5 batsmen by total runs`<br>"
        "`Who scored the most in 2023?`<br>"
        "`Average runs per over in 2022?`"
    )

    chatbot = gr.Chatbot(label="Cricket Analyst")
    msg = gr.Textbox(label="Ask your question here...")
    clear = gr.Button("Clear")

    def user_input(m, hist):
        """Clear the textbox and append the question with a pending answer."""
        # History uses [question, answer] pairs; None marks "not answered yet".
        return "", hist + [[m, None]]

    def bot_reply(hist):
        """Fill in the answer for the most recent question in the history."""
        q = hist[-1][0]
        a = main(q)
        hist[-1][1] = a
        return hist

    # Chain the events so the question shows up before the (slow) agent call.
    msg.submit(user_input, [msg, chatbot], [msg, chatbot], queue=True).then(
        bot_reply, chatbot, chatbot
    )

    clear.click(lambda: [], None, chatbot)


# Only start the server when run as a script, so importing this module
# (e.g. for Gradio's reload mode) does not launch a second instance.
if __name__ == "__main__":
    demo.queue(max_size=20).launch(debug=True)
|
|
|