File size: 4,785 Bytes
80dbe44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""
File call_llm.py - LLM node với function calling
"""
import json

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool

from src.state.graph_state import TransportationState
from src.config.setting import settings

@tool
def predict_transportation_mode(country: str, pack_price: float, project_code: str, vendor: str) -> dict:
    """Dự đoán phương thức vận chuyển tối ưu"""
    # NOTE: the docstring above is the tool description the LLM sees at
    # runtime — it must stay as-is (do not translate).
    # Deferred imports — presumably to avoid a circular dependency with the
    # API layer at module load time; confirm before moving to file top.
    from src.app.schema.transportation import TransportationRequest
    from src.app.api.predict import predict_transportation

    # Assemble the request payload and delegate to the prediction endpoint,
    # returning its result as a plain dict.
    req = TransportationRequest(
        country=country,
        pack_price=pack_price,
        project_code=project_code,
        vendor=vendor,
    )
    return predict_transportation(req).dict()

@tool
def get_transportation_options() -> dict:
    """Lấy danh sách options vận chuyển"""
    # NOTE: the docstring above is the runtime tool description — keep as-is.
    # Deferred import (likely circular-import avoidance — confirm); this
    # tool is a thin pass-through to the API-layer helper.
    from src.app.api.predict import get_transportation_options as _fetch_options
    return _fetch_options()

class CallLLMNode:
    """LangGraph node: one function-calling round-trip against Gemini.

    The node sends the user message plus a system prompt to the model with
    the two transportation tools bound.  If the model requests tool calls,
    they are executed locally, their results appended as ``ToolMessage``s,
    and the model is invoked once more for the final natural-language
    answer.  All results are written back into the (mutated) state dict.
    """

    def __init__(self, model: str = "gemini-1.5-pro"):
        """Create the node.

        Args:
            model: Gemini model name.  Defaults to the previously
                hard-coded "gemini-1.5-pro" so existing callers are
                unaffected; now overridable for newer model versions.
        """
        self.llm = ChatGoogleGenerativeAI(
            model=model,
            google_api_key=settings.GEMINI_API_KEY
        )
        # Same client with the two transportation tools exposed so the
        # model can emit structured tool calls.
        self.llm_with_tools = self.llm.bind_tools([
            predict_transportation_mode,
            get_transportation_options
        ])
    
    def _create_system_message(self) -> str:
        """Return the (Vietnamese) system prompt describing both tools.

        NOTE: this is runtime prompt text consumed by the LLM — kept
        verbatim, not translated.
        """
        return """Bạn là chuyên gia logistics và vận chuyển. Bạn có 2 functions:

1. predict_transportation_mode(country, pack_price, project_code, vendor)
   - Dự đoán phương thức vận chuyển tối ưu
   - Cần đủ 4 tham số

2. get_transportation_options()
   - Lấy danh sách options có sẵn
   - Không cần tham số

QUY TẮC:
- "dự đoán", "tối ưu", "nên chọn" → gọi predict_transportation_mode
- "options", "danh sách", "có gì" → gọi get_transportation_options  
- Thiếu thông tin → hỏi ngược ngay
- Trả lời luật/quy định từ kiến thức"""
    
    def __call__(self, state: TransportationState) -> TransportationState:
        """Process ``state["user_message"]`` and return the mutated state.

        On success sets ``ai_response``, appends to ``function_calls_made``,
        and sets ``current_step`` to "completed"; on any exception sets
        ``error_message`` and ``current_step`` to "error" instead of raising.
        """
        try:
            messages = [
                SystemMessage(content=self._create_system_message()),
                HumanMessage(content=state["user_message"])
            ]
            
            response = self.llm_with_tools.invoke(messages)
            
            if response.tool_calls:
                print(f"✅ Function calls detected: {len(response.tool_calls)}")
                
                # Execute each requested tool and collect ToolMessages to
                # feed back to the model.
                tool_messages = []
                for tool_call in response.tool_calls:
                    func_name = tool_call["name"]
                    args = tool_call["args"]
                    
                    print(f"🔧 Calling {func_name} with args: {args}")
                    
                    # Dispatch to the matching tool implementation.
                    if func_name == "predict_transportation_mode":
                        result = predict_transportation_mode.invoke(args)
                    elif func_name == "get_transportation_options":
                        result = get_transportation_options.invoke({})
                    else:
                        result = {"error": f"Unknown function: {func_name}"}
                    
                    print(f"📊 Result: {result}")
                    
                    # Record the call for downstream nodes.  setdefault
                    # guards against an upstream node not having seeded the
                    # list (the bare index previously raised KeyError).
                    state.setdefault("function_calls_made", []).append({
                        "function_name": func_name,
                        "arguments": args,
                        "result": result
                    })
                    
                    # ensure_ascii=False keeps Vietnamese text in tool
                    # results readable for the model instead of \uXXXX.
                    tool_messages.append(ToolMessage(
                        content=json.dumps(result, ensure_ascii=False),
                        tool_call_id=tool_call["id"]
                    ))
                
                # Second round-trip: model summarizes the tool results.
                final_messages = messages + [response] + tool_messages
                final_response = self.llm.invoke(final_messages)
                state["ai_response"] = final_response.content
                print("✅ Function calling completed successfully")
            else:
                # Model answered directly from knowledge — no tools needed.
                state["ai_response"] = response.content
                print("ℹ️ No function calls needed")
            
            state["current_step"] = "completed"
            return state
            
        except Exception as e:
            # Top-level node boundary: convert any failure into an error
            # state so the graph can route accordingly instead of crashing.
            print(f"❌ Error in LLM call: {e}")
            state["error_message"] = f"LLM Error: {str(e)}"
            state["current_step"] = "error"
            return state

def create_call_llm_node() -> CallLLMNode:
    """Factory: build and return a fresh CallLLMNode."""
    node = CallLLMNode()
    return node