| | from fastapi import FastAPI, HTTPException
|
| | from pydantic import BaseModel
|
| | import sqlite3
|
| | import pandas as pd
|
| | import os
|
| | from dotenv import load_dotenv
|
| | import google.generativeai as genai
|
| |
|
# Pull variables from a local .env file into the process environment before
# any configuration is read from it.
load_dotenv()

# Authenticate the Gemini client with the key from the GOOGLE_API_KEY variable.
genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))

# The ASGI application that the /query route below is registered on.
app = FastAPI()
|
| |
|
class Query(BaseModel):
    # Request body accepted by POST /query.

    # Free-text question to translate into a database query.
    question: str
    # Backend selector: "SQL Database" routes to SQLite; any other value routes
    # to the employee CSV (see process_query).
    data_source: str
|
| |
|
def get_gemini_response(question, prompt):
    """Ask the 'gemini-pro' model to answer *question* under *prompt*.

    The instructions in *prompt* are sent first and the user's *question*
    second, as a two-part request.

    Returns:
        The text of the model's reply.
    """
    contents = [prompt, question]
    reply = genai.GenerativeModel('gemini-pro').generate_content(contents)
    return reply.text
|
| |
|
def get_csv_columns(csv_path='employee.csv'):
    """Return the column names of the employee CSV, in file order.

    Args:
        csv_path: CSV file to inspect; defaults to 'employee.csv' relative to
            the current working directory.

    Returns:
        The header row as a list of column names.
    """
    # nrows=0 parses only the header row; the data rows are not needed here.
    return pd.read_csv(csv_path, nrows=0).columns.tolist()
|
| |
|
# Header of employee.csv, read once at import time (so a missing file stops the
# app from starting). It is interpolated into csv_prompt and echoed back in CSV
# query responses.
csv_columns = get_csv_columns()

# Instructions for translating a question into one SQLite statement against
# the STUDENT table. cursor.execute() runs a single statement, and SQLite
# string literals take single quotes, hence the wording and examples below.
sql_prompt = """
You are an expert in converting English questions to SQL code!
The SQL database has the name STUDENT and has the following Columns - NAME, CLASS, SECTION

For example:
- How many entries of records are present? SQL command: SELECT COUNT(*) FROM STUDENT;
- Tell me all the students studying in Data Science class? SQL command: SELECT * FROM STUDENT WHERE CLASS = 'Data Science';

Also, the SQL code should not have ``` at the beginning or at the end, and the output should not include the word SQL as a label.
Return exactly one SQL statement.
Ensure that you only generate valid SQL queries, not pandas or Python code.
"""

# Instructions for translating a question into a single pandas expression; the
# reply is passed to eval() with the DataFrame bound to the name df.
csv_prompt = f"""
You are an expert in analyzing CSV data and converting English questions to pandas query syntax.
The CSV file is named 'employee.csv' and contains employee information.
The available columns in the CSV file are: {', '.join(csv_columns)}

For example:
- How many employees are there? Pandas query: len(df)
- List all employees in the Sales department. Pandas query: df[df['Department'] == 'Sales']
- Show the employee with ID 101. Pandas query: df[df['ID'] == 101]

Provide only the pandas query syntax without any additional explanation or markdown formatting.
Respond with a single Python expression that can be evaluated with eval(); the DataFrame is available as df.
Do not include 'df = ' or any variable assignment in your response.
Make sure to use only the columns that are available in the CSV file.
Ensure that you only generate valid pandas queries, not SQL or other types of code.
"""
|
| |
|
def execute_sql_query(query, db_path='student.db'):
    """Run one model-generated SQL statement read-only and return its rows.

    Args:
        query: A single SQL statement (sqlite3's cursor.execute rejects
            multiple statements).
        db_path: SQLite database file; defaults to 'student.db'.

    Returns:
        All result rows, as returned by cursor.fetchall() (a list of tuples).

    Raises:
        HTTPException: 400 on any sqlite3 error, including attempts to modify
            the database.
    """
    conn = sqlite3.connect(db_path)
    try:
        # The SQL comes from an LLM driven by user text, so forbid writes.
        # Without this, DDL such as DROP TABLE runs outside an implicit
        # transaction and persists even though nothing is committed.
        conn.execute("PRAGMA query_only = ON")
        cursor = conn.cursor()
        cursor.execute(query)
        return cursor.fetchall()
    except sqlite3.Error as e:
        raise HTTPException(status_code=400, detail=f"SQL Error: {str(e)}") from e
    finally:
        conn.close()
|
| |
|
def execute_pandas_query(query, csv_path='employee.csv'):
    """Evaluate a model-generated pandas expression against the employee CSV.

    The CSV is re-read on every call and exposed to the expression as ``df``,
    with the pandas module available as ``pd``.

    Args:
        query: A single Python expression, e.g. ``len(df)`` or
            ``df[df['Department'] == 'Sales']``.
        csv_path: CSV file to load; defaults to 'employee.csv'.

    Returns:
        A list of row dicts for a DataFrame result, a dict for a Series, a
        plain Python list/scalar for anything exposing ``tolist()`` (NumPy
        scalars and arrays, pandas Index), and the raw value otherwise.

    Raises:
        HTTPException: 400 if evaluating or converting the expression fails.
    """
    df = pd.read_csv(csv_path)
    try:
        # SECURITY: eval() executes arbitrary Python produced by the LLM from
        # user-supplied text, and because no '__builtins__' key is given the
        # full builtins (including __import__) are available. Treat this as
        # remote code execution; replace with a restricted evaluator or a
        # sandbox before exposing the service to untrusted users.
        result = eval(query, {'df': df, 'pd': pd})
        if isinstance(result, pd.DataFrame):
            return result.to_dict(orient='records')
        if isinstance(result, pd.Series):
            return result.to_dict()
        # NumPy scalars (e.g. from .sum()) and arrays/Index (e.g. from
        # .unique()) are not handled by FastAPI's default JSON encoding;
        # tolist() converts them to native Python values.
        if hasattr(result, 'tolist'):
            return result.tolist()
        return result
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Pandas Error: {str(e)}") from e
|
| |
|
def _strip_code_fences(text):
    """Return *text* with surrounding whitespace and any Markdown code fence removed.

    The prompts ask the model not to use Markdown fences, but that is not
    guaranteed; a reply such as ```sql\\nSELECT ...\\n``` would otherwise fail
    to execute as SQL or as a Python expression.
    """
    cleaned = text.strip()
    if len(cleaned) >= 6 and cleaned.startswith("```") and cleaned.endswith("```"):
        inner = cleaned[3:-3]
        tag, newline, body = inner.partition("\n")
        # An opening fence like ```sql or ```python carries only a language tag.
        if newline and tag.strip().isalnum():
            inner = body
        cleaned = inner.strip()
    return cleaned


@app.post("/query")
async def process_query(query: Query):
    """Translate a question into a query with Gemini, run it, and return both.

    ``data_source == "SQL Database"`` targets the SQLite STUDENT table; any
    other value targets the employee CSV through a pandas expression.

    Returns:
        ``{"query": <generated query>, "result": <rows/value>}``, plus
        ``"columns"`` (the CSV header) for CSV questions.

    Raises:
        HTTPException: 502 if the Gemini request fails; 400 if the generated
            query fails to execute.
    """
    # The Gemini client call, sqlite3 and pandas all block; run them in a
    # worker thread so this async endpoint does not stall the event loop.
    from fastapi.concurrency import run_in_threadpool

    use_sql = query.data_source == "SQL Database"
    prompt = sql_prompt if use_sql else csv_prompt

    try:
        raw_reply = await run_in_threadpool(get_gemini_response, query.question, prompt)
    except Exception as e:
        # Top-level boundary: surface model/API failures as a gateway error
        # instead of an undetailed 500.
        raise HTTPException(status_code=502, detail=f"Gemini request failed: {e}") from e
    ai_response = _strip_code_fences(raw_reply)

    if use_sql:
        try:
            result = await run_in_threadpool(execute_sql_query, ai_response)
        except HTTPException as e:
            raise HTTPException(status_code=400, detail=f"Error in SQL query: {e.detail}") from e
        return {"query": ai_response, "result": result}

    try:
        result = await run_in_threadpool(execute_pandas_query, ai_response)
    except HTTPException as e:
        raise HTTPException(status_code=400, detail=f"Error in pandas query: {e.detail}") from e
    return {"query": ai_response, "result": result, "columns": csv_columns}