daai_rangdong / api_v0.py
tlong-ds's picture
update: api now read schema.txt
df45eec
from fastapi import FastAPI, Body
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.utilities import SQLDatabase
from langgraph.prebuilt import create_react_agent
from langchain import hub
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from typing import List, Optional
import json
import difflib
import os
import sys
from utils import *
from dotenv import load_dotenv
load_dotenv()
app = FastAPI()
# Code
GEMINI_MODELS = ["gemini-2.0-flash"]
OPEN_AI_MODELS = ["gpt-4o", "gpt-4o-mini", "o3-mini"]
ANTHROPIC_AI_MODELS = [
os.getenv('CLAUDE_3_5_SONNET'),
os.getenv('CLAUDE_3_5_HAIKU'),
os.getenv('CLAUDE_3_7_SONNET'),
os.getenv('CLAUDE_3_OPUS'),
os.getenv('CLAUDE_3_HAIKU'),
]
STANDARD_MODEL = "gemini-2.0-flash"
QUICK_MODEL = "gemini-2.0-flash"
REASONING_MODEL = "gemini-2.0-flash"
DEFAULT_SYSTEM_PROMPT = """Bạn là một trợ lý hữu ích. Hãy luôn trả lời bằng tiếng Việt."""
class QueryAgent:
def __init__(self):
# Initialize database connection
self.db = SQLDatabase.from_uri(
f"mysql+pymysql://{os.getenv('MYSQL_USER')}:{os.getenv('MYSQL_PASSWORD')}@"
f"{os.getenv('MYSQL_HOST')}:{os.getenv('MYSQL_PORT')}/{os.getenv('MYSQL_DB')}"
)
# Load templates
with open("templates/sql_templates.json", "r", encoding="utf-8") as f:
self.templates = json.load(f)
# Initialize LLM
# CLAUDE_3_5_SONNET = os.getenv('CLAUDE_3_5_SONNET')
# Lấy model name từ biến môi trường, ví dụ: 'gemini-pro'
GEMINI_MODEL = "gemini-2.0-flash"
GEMINI_API_KEY = os.getenv('GOOGLE_API_KEY')
# Khởi tạo Gemini
self.llm = ChatGoogleGenerativeAI(
model=GEMINI_MODEL,
google_api_key=GEMINI_API_KEY,
temperature=0
)
# Initialize SQL toolkit and get tools
self.toolkit = SQLDatabaseToolkit(db=self.db, llm=self.llm)
self.tools = self.toolkit.get_tools()
# Load system prompt for the agent
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
# Add custom instructions to the prompt template
today = datetime.now().strftime("%Y-%m-%d")
additional_instructions = f"""
Additional Instructions:
1. When dealing with financial ratios like ROE, ROA, or Net Income Margin (NIM), they are stored as percentages but in decimal form (E.g: 15% is stored as 0.15) in the database.
2. Format your responses with appropriate units (Bn. VND for money, % for ratios).
3. Khi user hỏi về 1 ngành nào đó, tên ngành chưa chắc đúng, tìm và match với ngành gần nhất trong bảng vn100_listing_by_industry
4. Các câu hỏi về "top" cần trả vể kết quả unique (ví dụ: "Top 3 công ty có lợi nhuận cao nhất" phải trả về 3 công ty, không được trả về 1 công ty)
Today is (YYYY-MM-DD): {today}
Your final answer should be in Vietnamese.
"""
system_message = prompt_template.format(dialect="MySQL", top_k=5) + additional_instructions
# Create ReAct agent
self.agent_executor = create_react_agent(
model=self.llm,
tools=self.tools,
prompt=system_message
)
def match_template(self, question: str):
"""Return SQL if user question closely matches a predefined template"""
questions = [t['question'] for t in self.templates]
match = difflib.get_close_matches(question.strip(), questions, n=1, cutoff=0.85)
if match:
for t in self.templates:
if t["question"] == match[0]:
return t["sql"]
return None
def query(self, question: str, history: Optional[List[dict]] = None) -> str:
"""Template-based if match, else fallback to LLM"""
matched_sql = self.match_template(question)
if matched_sql:
print(f"\n⚡ Using template for: {question}")
try:
result = self.db.run(matched_sql)
return f"[Kết quả từ template SQL]\n{result}"
except Exception as e:
return f"⚠️ Lỗi khi chạy SQL template: {str(e)}"
print(f"\n🤖 Using LLM for: {question}")
agent_input = {"messages": history + [{"role": "user", "content": question}] if history else [{"role": "user", "content": question}]}
response = []
for step in self.agent_executor.stream(agent_input, stream_mode="values"):
print("\nStep:", step["messages"][-1])
response = step["messages"][-1]
return response.content if hasattr(response, 'content') else str(response)
agent = QueryAgent()
@app.get("/")
def root():
return {"response": """
Chào mừng bạn đến với Chatbot Rangdong!
Nhập 'quit' hoặc 'exit' để thoát ra khỏi chương trình.
Nhập 'reset' để bắt đàu cuộc trò chuyện mới
"""}
@app.post("/chat")
async def chat_endpoint(request: dict = Body(..., examples={"message": "Xin chào. Bạn có thể làm gì?"})):
global agent
user_message = request.get("message", "")
history = request.get("history",[])
if not user_message:
return {"error": "Không có tin nhắn."}
response = agent.query(user_message, history)
return {"response": response}