Spaces:
No application file
No application file
| """ | |
| File call_llm.py - LLM node với function calling | |
| """ | |
| import json | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage | |
| from langchain_core.tools import tool | |
| from src.state.graph_state import TransportationState | |
| from src.config.setting import settings | |
@tool
def predict_transportation_mode(country: str, pack_price: float, project_code: str, vendor: str) -> dict:
    """Predict the optimal transportation mode for a shipment.

    Registered as a LangChain tool: the ``@tool`` decorator is required because
    the node both binds this function via ``bind_tools`` and executes it with
    ``.invoke(args)`` — a plain function has no ``invoke`` and would raise
    ``AttributeError`` at tool-execution time.

    Args:
        country: Destination country of the shipment.
        pack_price: Price of the package (currency per project convention — confirm with caller).
        project_code: Internal project identifier.
        vendor: Vendor name.

    Returns:
        The prediction response serialized to a plain ``dict``.
    """
    # Imports are deferred to call time to avoid a circular import with the API layer.
    from src.app.schema.transportation import TransportationRequest
    from src.app.api.predict import predict_transportation

    request = TransportationRequest(
        country=country,
        pack_price=pack_price,
        project_code=project_code,
        vendor=vendor
    )
    response = predict_transportation(request)
    # NOTE(review): .dict() is the Pydantic-v1 spelling; if the project is on
    # Pydantic v2, migrate to .model_dump() in a coordinated change.
    return response.dict()
@tool
def get_transportation_options() -> dict:
    """Return the list of available transportation options.

    Registered as a LangChain tool: the ``@tool`` decorator is required because
    the node executes this via ``.invoke({})``; without it the call raises
    ``AttributeError`` on a plain function.

    Returns:
        The options payload from the predict API, unchanged.
    """
    # Deferred import avoids a circular dependency with the API layer.
    from src.app.api.predict import get_transportation_options as get_options
    return get_options()
class CallLLMNode:
    """LangGraph node that answers a user message with Gemini, executing the
    transportation tools when the model requests them.

    The node mutates the incoming ``TransportationState`` in place (writes
    ``ai_response``, appends to ``function_calls_made``, sets ``current_step``)
    and returns the same state object.
    """

    def __init__(self) -> None:
        # Base Gemini chat model; API key comes from application settings.
        self.llm = ChatGoogleGenerativeAI(
            model="gemini-1.5-pro",
            google_api_key=settings.GEMINI_API_KEY
        )
        # Same model with the two transportation functions bound so Gemini
        # can emit structured tool calls for them.
        self.llm_with_tools = self.llm.bind_tools([
            predict_transportation_mode,
            get_transportation_options
        ])

    def _create_system_message(self) -> str:
        """Return the system prompt for the transportation AI.

        The prompt text is user-facing behavior and is intentionally kept in
        Vietnamese; it describes the two available functions and the routing
        rules the model should follow.
        """
        return """Bạn là chuyên gia logistics và vận chuyển. Bạn có 2 functions:
1. predict_transportation_mode(country, pack_price, project_code, vendor)
- Dự đoán phương thức vận chuyển tối ưu
- Cần đủ 4 tham số
2. get_transportation_options()
- Lấy danh sách options có sẵn
- Không cần tham số
QUY TẮC:
- "dự đoán", "tối ưu", "nên chọn" → gọi predict_transportation_mode
- "options", "danh sách", "có gì" → gọi get_transportation_options
- Thiếu thông tin → hỏi ngược ngay
- Trả lời luật/quy định từ kiến thức"""

    def __call__(self, state: TransportationState) -> TransportationState:
        """Run one LLM turn for the given state.

        Invokes Gemini with the system prompt plus the user's message; if the
        model requests tool calls, executes each tool, records the call in
        ``state["function_calls_made"]`` (assumes the key is pre-initialized
        to a list by the graph — confirm in state setup), feeds the tool
        results back for a final answer, and stores the reply in
        ``state["ai_response"]``. On any exception, records the error in
        ``state["error_message"]`` and marks ``current_step`` as ``"error"``.
        """
        try:
            messages = [
                SystemMessage(content=self._create_system_message()),
                HumanMessage(content=state["user_message"])
            ]
            response = self.llm_with_tools.invoke(messages)
            if response.tool_calls:
                print(f"✅ Function calls detected: {len(response.tool_calls)}")
                # Execute tools and store the results
                tool_messages = []
                for tool_call in response.tool_calls:
                    func_name = tool_call["name"]
                    args = tool_call["args"]
                    print(f"🔧 Calling {func_name} with args: {args}")
                    # Manual dispatch by tool name; unknown names become an
                    # error payload instead of raising.
                    if func_name == "predict_transportation_mode":
                        result = predict_transportation_mode.invoke(args)
                    elif func_name == "get_transportation_options":
                        # Takes no parameters, so model-supplied args are ignored.
                        result = get_transportation_options.invoke({})
                    else:
                        result = {"error": f"Unknown function: {func_name}"}
                    print(f"📊 Result: {result}")
                    # Save function call info for downstream nodes / auditing
                    state["function_calls_made"].append({
                        "function_name": func_name,
                        "arguments": args,
                        "result": result
                    })
                    # Tool result must echo the tool_call_id so the model can
                    # match it to its request.
                    tool_messages.append(ToolMessage(
                        content=json.dumps(result),
                        tool_call_id=tool_call["id"]
                    ))
                # Get final response: replay the conversation (including the
                # model's tool-call turn and the tool outputs) through the
                # plain model so it can phrase the answer.
                final_messages = messages + [response] + tool_messages
                final_response = self.llm.invoke(final_messages)
                state["ai_response"] = final_response.content
                print("✅ Function calling completed successfully")
            else:
                # No tools needed — the model's direct answer is the reply.
                state["ai_response"] = response.content
                print("ℹ️ No function calls needed")
            state["current_step"] = "completed"
            return state
        except Exception as e:
            # Broad catch is deliberate: this is a graph-boundary node and the
            # error is surfaced via state rather than crashing the graph.
            print(f"❌ Error in LLM call: {e}")
            state["error_message"] = f"LLM Error: {str(e)}"
            state["current_step"] = "error"
            return state
def create_call_llm_node() -> CallLLMNode:
    """Build and return a fresh ``CallLLMNode`` instance."""
    node = CallLLMNode()
    return node