""" File call_llm.py - LLM node với function calling """ import json from langchain_google_genai import ChatGoogleGenerativeAI from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage from langchain_core.tools import tool from src.state.graph_state import TransportationState from src.config.setting import settings @tool def predict_transportation_mode(country: str, pack_price: float, project_code: str, vendor: str) -> dict: """Dự đoán phương thức vận chuyển tối ưu""" from src.app.schema.transportation import TransportationRequest from src.app.api.predict import predict_transportation request = TransportationRequest( country=country, pack_price=pack_price, project_code=project_code, vendor=vendor ) response = predict_transportation(request) return response.dict() @tool def get_transportation_options() -> dict: """Lấy danh sách options vận chuyển""" from src.app.api.predict import get_transportation_options as get_options return get_options() class CallLLMNode: def __init__(self): self.llm = ChatGoogleGenerativeAI( model="gemini-1.5-pro", google_api_key=settings.GEMINI_API_KEY ) self.llm_with_tools = self.llm.bind_tools([ predict_transportation_mode, get_transportation_options ]) def _create_system_message(self) -> str: """System prompt cho transportation AI""" return """Bạn là chuyên gia logistics và vận chuyển. Bạn có 2 functions: 1. predict_transportation_mode(country, pack_price, project_code, vendor) - Dự đoán phương thức vận chuyển tối ưu - Cần đủ 4 tham số 2. get_transportation_options() - Lấy danh sách options có sẵn - Không cần tham số QUY TẮC: - "dự đoán", "tối ưu", "nên chọn" → gọi predict_transportation_mode - "options", "danh sách", "có gì" → gọi get_transportation_options - Thiếu thông tin → hỏi ngược ngay - Trả lời luật/quy định từ kiến thức""" def __call__(self, state: TransportationState) -> TransportationState: try: messages = [ SystemMessage(content=self._create_system_message()), HumanMessage(content=state["user_message"]) ] response = self.llm_with_tools.invoke(messages) if response.tool_calls: print(f"✅ Function calls detected: {len(response.tool_calls)}") # Execute tools và lưu results tool_messages = [] for tool_call in response.tool_calls: func_name = tool_call["name"] args = tool_call["args"] print(f"🔧 Calling {func_name} with args: {args}") if func_name == "predict_transportation_mode": result = predict_transportation_mode.invoke(args) elif func_name == "get_transportation_options": result = get_transportation_options.invoke({}) else: result = {"error": f"Unknown function: {func_name}"} print(f"📊 Result: {result}") # Save function call info state["function_calls_made"].append({ "function_name": func_name, "arguments": args, "result": result }) tool_messages.append(ToolMessage( content=json.dumps(result), tool_call_id=tool_call["id"] )) # Get final response final_messages = messages + [response] + tool_messages final_response = self.llm.invoke(final_messages) state["ai_response"] = final_response.content print("✅ Function calling completed successfully") else: state["ai_response"] = response.content print("ℹ️ No function calls needed") state["current_step"] = "completed" return state except Exception as e: print(f"❌ Error in LLM call: {e}") state["error_message"] = f"LLM Error: {str(e)}" state["current_step"] = "error" return state def create_call_llm_node() -> CallLLMNode: """Factory function to create CallLLMNode""" return CallLLMNode()