Spaces:
Sleeping
Sleeping
Commit ·
c9e9e4d
1
Parent(s): 25aa36b
added gemini model
Browse files- app/services/sql_agent.py +8 -5
- pyproject.toml +2 -1
- requirements.txt +2 -1
- uv.lock +0 -0
app/services/sql_agent.py
CHANGED
|
@@ -352,16 +352,19 @@ from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
|
| 352 |
from langchain_groq import ChatGroq
|
| 353 |
from langchain_core.messages import HumanMessage, AIMessage
|
| 354 |
from langchain_core.prompts import ChatPromptTemplate
|
| 355 |
-
from langchain_core.pydantic_v1 import BaseModel, Field
|
| 356 |
from langgraph.graph import StateGraph, END, MessagesState
|
| 357 |
from typing import TypedDict, Annotated, List, Literal, Dict, Any
|
|
|
|
|
|
|
| 358 |
from dotenv import load_dotenv
|
| 359 |
load_dotenv()
|
| 360 |
import os
|
| 361 |
-
os.environ["GROQ_API_KEY"]=os.getenv("GROQ_API_KEY")
|
|
|
|
| 362 |
|
| 363 |
class SQLAgent:
|
| 364 |
-
def __init__(self, model="
|
| 365 |
|
| 366 |
# Initialize instance variables
|
| 367 |
self.db = None
|
|
@@ -373,8 +376,8 @@ class SQLAgent:
|
|
| 373 |
self.app = None
|
| 374 |
|
| 375 |
# Setting up LLM
|
| 376 |
-
self.llm = ChatGroq(model=model,api_key = os.getenv("GROQ_API_KEY"))
|
| 377 |
-
|
| 378 |
# Register the tool method
|
| 379 |
self.query_to_database = self._create_query_tool()
|
| 380 |
|
|
|
|
| 352 |
from langchain_groq import ChatGroq
|
| 353 |
from langchain_core.messages import HumanMessage, AIMessage
|
| 354 |
from langchain_core.prompts import ChatPromptTemplate
|
| 355 |
+
# from langchain_core.pydantic_v1 import BaseModel, Field
|
| 356 |
from langgraph.graph import StateGraph, END, MessagesState
|
| 357 |
from typing import TypedDict, Annotated, List, Literal, Dict, Any
|
| 358 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 359 |
+
|
| 360 |
from dotenv import load_dotenv
|
| 361 |
load_dotenv()
|
| 362 |
import os
|
| 363 |
+
# os.environ["GROQ_API_KEY"]=os.getenv("GROQ_API_KEY")
|
| 364 |
+
os.environ["GEMINI_API_KEY"]=os.getenv("GEMINI_API_KEY")
|
| 365 |
|
| 366 |
class SQLAgent:
|
| 367 |
+
def __init__(self, model="gemma2-9b-it"):
|
| 368 |
|
| 369 |
# Initialize instance variables
|
| 370 |
self.db = None
|
|
|
|
| 376 |
self.app = None
|
| 377 |
|
| 378 |
# Setting up LLM
|
| 379 |
+
# self.llm = ChatGroq(model=model,api_key = os.getenv("GROQ_API_KEY"))
|
| 380 |
+
self.llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash-lite", google_api_key=os.environ["GEMINI_API_KEY"])
|
| 381 |
# Register the tool method
|
| 382 |
self.query_to_database = self._create_query_tool()
|
| 383 |
|
pyproject.toml
CHANGED
|
@@ -5,13 +5,14 @@ description = "Add your description here"
|
|
| 5 |
readme = "README.md"
|
| 6 |
requires-python = ">=3.13"
|
| 7 |
dependencies = [
|
| 8 |
-
"bcrypt
|
| 9 |
"fastapi>=0.116.1",
|
| 10 |
"ipykernel>=6.29.5",
|
| 11 |
"ipython>=9.4.0",
|
| 12 |
"langchain>=0.3.26",
|
| 13 |
"langchain-community>=0.3.27",
|
| 14 |
"langchain-core>=0.3.68",
|
|
|
|
| 15 |
"langchain-groq>=0.3.6",
|
| 16 |
"langgraph>=0.5.3",
|
| 17 |
"pandas>=2.3.1",
|
|
|
|
| 5 |
readme = "README.md"
|
| 6 |
requires-python = ">=3.13"
|
| 7 |
dependencies = [
|
| 8 |
+
"bcrypt==4.3.0",
|
| 9 |
"fastapi>=0.116.1",
|
| 10 |
"ipykernel>=6.29.5",
|
| 11 |
"ipython>=9.4.0",
|
| 12 |
"langchain>=0.3.26",
|
| 13 |
"langchain-community>=0.3.27",
|
| 14 |
"langchain-core>=0.3.68",
|
| 15 |
+
"langchain-google-genai>=2.1.9",
|
| 16 |
"langchain-groq>=0.3.6",
|
| 17 |
"langgraph>=0.5.3",
|
| 18 |
"pandas>=2.3.1",
|
requirements.txt
CHANGED
|
@@ -15,4 +15,5 @@ ipykernel
|
|
| 15 |
passlib
|
| 16 |
python-multipart
|
| 17 |
bcrypt==4.3.0
|
| 18 |
-
psycopg2-binary
|
|
|
|
|
|
| 15 |
passlib
|
| 16 |
python-multipart
|
| 17 |
bcrypt==4.3.0
|
| 18 |
+
psycopg2-binary
|
| 19 |
+
langchain-google-genai==2.1.9
|
uv.lock
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|