Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, Body | |
| from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage | |
| from langchain_openai import ChatOpenAI | |
| from langchain_anthropic import ChatAnthropic | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_community.utilities import SQLDatabase | |
| from langgraph.prebuilt import create_react_agent | |
| from langchain import hub | |
| from langchain_community.agent_toolkits import SQLDatabaseToolkit | |
| from typing import List, Optional | |
| import json | |
| import difflib | |
| import os | |
| import sys | |
| from utils import * | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| app = FastAPI() | |
| # Code | |
| GEMINI_MODELS = ["gemini-2.0-flash"] | |
| OPEN_AI_MODELS = ["gpt-4o", "gpt-4o-mini", "o3-mini"] | |
| ANTHROPIC_AI_MODELS = [ | |
| os.getenv('CLAUDE_3_5_SONNET'), | |
| os.getenv('CLAUDE_3_5_HAIKU'), | |
| os.getenv('CLAUDE_3_7_SONNET'), | |
| os.getenv('CLAUDE_3_OPUS'), | |
| os.getenv('CLAUDE_3_HAIKU'), | |
| ] | |
| STANDARD_MODEL = "gemini-2.0-flash" | |
| QUICK_MODEL = "gemini-2.0-flash" | |
| REASONING_MODEL = "gemini-2.0-flash" | |
| DEFAULT_SYSTEM_PROMPT = """Bạn là một trợ lý hữu ích. Hãy luôn trả lời bằng tiếng Việt.""" | |
| class QueryAgent: | |
| def __init__(self): | |
| # Initialize database connection | |
| self.db = SQLDatabase.from_uri( | |
| f"mysql+pymysql://{os.getenv('MYSQL_USER')}:{os.getenv('MYSQL_PASSWORD')}@" | |
| f"{os.getenv('MYSQL_HOST')}:{os.getenv('MYSQL_PORT')}/{os.getenv('MYSQL_DB')}" | |
| ) | |
| # Load templates | |
| with open("templates/sql_templates.json", "r", encoding="utf-8") as f: | |
| self.templates = json.load(f) | |
| # Initialize LLM | |
| # CLAUDE_3_5_SONNET = os.getenv('CLAUDE_3_5_SONNET') | |
| # Lấy model name từ biến môi trường, ví dụ: 'gemini-pro' | |
| GEMINI_MODEL = "gemini-2.0-flash" | |
| GEMINI_API_KEY = os.getenv('GOOGLE_API_KEY') | |
| # Khởi tạo Gemini | |
| self.llm = ChatGoogleGenerativeAI( | |
| model=GEMINI_MODEL, | |
| google_api_key=GEMINI_API_KEY, | |
| temperature=0 | |
| ) | |
| # Initialize SQL toolkit and get tools | |
| self.toolkit = SQLDatabaseToolkit(db=self.db, llm=self.llm) | |
| self.tools = self.toolkit.get_tools() | |
| # Load system prompt for the agent | |
| prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt") | |
| # Add custom instructions to the prompt template | |
| today = datetime.now().strftime("%Y-%m-%d") | |
| additional_instructions = f""" | |
| Additional Instructions: | |
| 1. When dealing with financial ratios like ROE, ROA, or Net Income Margin (NIM), they are stored as percentages but in decimal form (E.g: 15% is stored as 0.15) in the database. | |
| 2. Format your responses with appropriate units (Bn. VND for money, % for ratios). | |
| 3. Khi user hỏi về 1 ngành nào đó, tên ngành chưa chắc đúng, tìm và match với ngành gần nhất trong bảng vn100_listing_by_industry | |
| 4. Các câu hỏi về "top" cần trả vể kết quả unique (ví dụ: "Top 3 công ty có lợi nhuận cao nhất" phải trả về 3 công ty, không được trả về 1 công ty) | |
| Today is (YYYY-MM-DD): {today} | |
| Your final answer should be in Vietnamese. | |
| """ | |
| system_message = prompt_template.format(dialect="MySQL", top_k=5) + additional_instructions | |
| # Create ReAct agent | |
| self.agent_executor = create_react_agent( | |
| model=self.llm, | |
| tools=self.tools, | |
| prompt=system_message | |
| ) | |
| def match_template(self, question: str): | |
| """Return SQL if user question closely matches a predefined template""" | |
| questions = [t['question'] for t in self.templates] | |
| match = difflib.get_close_matches(question.strip(), questions, n=1, cutoff=0.85) | |
| if match: | |
| for t in self.templates: | |
| if t["question"] == match[0]: | |
| return t["sql"] | |
| return None | |
| def query(self, question: str, history: Optional[List[dict]] = None) -> str: | |
| """Template-based if match, else fallback to LLM""" | |
| matched_sql = self.match_template(question) | |
| if matched_sql: | |
| print(f"\n⚡ Using template for: {question}") | |
| try: | |
| result = self.db.run(matched_sql) | |
| return f"[Kết quả từ template SQL]\n{result}" | |
| except Exception as e: | |
| return f"⚠️ Lỗi khi chạy SQL template: {str(e)}" | |
| print(f"\n🤖 Using LLM for: {question}") | |
| agent_input = {"messages": history + [{"role": "user", "content": question}] if history else [{"role": "user", "content": question}]} | |
| response = [] | |
| for step in self.agent_executor.stream(agent_input, stream_mode="values"): | |
| print("\nStep:", step["messages"][-1]) | |
| response = step["messages"][-1] | |
| return response.content if hasattr(response, 'content') else str(response) | |
| agent = QueryAgent() | |
| def root(): | |
| return {"response": """ | |
| Chào mừng bạn đến với Chatbot Rangdong! | |
| Nhập 'quit' hoặc 'exit' để thoát ra khỏi chương trình. | |
| Nhập 'reset' để bắt đàu cuộc trò chuyện mới | |
| """} | |
| async def chat_endpoint(request: dict = Body(..., examples={"message": "Xin chào. Bạn có thể làm gì?"})): | |
| global agent | |
| user_message = request.get("message", "") | |
| history = request.get("history",[]) | |
| if not user_message: | |
| return {"error": "Không có tin nhắn."} | |
| response = agent.query(user_message, history) | |
| return {"response": response} | |