File size: 2,103 Bytes
1f725d8
5551822
 
1f725d8
 
 
 
5551822
 
 
 
 
 
1f725d8
 
5551822
1f725d8
5551822
 
1f725d8
5551822
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f725d8
5551822
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import logging
import json
import re
from src.MultiRag.models.rag_model import State
from utils.asyncHandler import asyncHandler
from src.MultiRag.llm.llm_loader import llm
from src.MultiRag.prompts.prompt_templates import CHAT_PROMPT
from langchain_core.messages import SystemMessage, AIMessage
from src.MultiRag.tools.web_search import WebSearch

# Bound `search` method of one module-level WebSearch instance; chat_node
# passes it to `llm.bind_tools` when tool use is allowed.
web_search_tool = WebSearch().search



# Phrases that are answered conversationally, without binding any tools.
_GREETING_PHRASES = frozenset({"hi", "hello", "hey", "how are you", "who are you"})


def _is_greeting(messages) -> bool:
    """Return True when the last message is one of the known greeting phrases."""
    if not messages:
        return False
    content = getattr(messages[-1], "content", None)
    # Message content may be a list of content parts instead of a string;
    # only plain text can match a greeting phrase.
    if not isinstance(content, str):
        return False
    # Tolerate surrounding whitespace and trailing punctuation ("Hi!", "how are you?").
    normalized = content.strip().rstrip("!?.,").strip().lower()
    return normalized in _GREETING_PHRASES


@asyncHandler
async def chat_node(state: State):
    """Produce the chat model's reply for the current state.

    The LLM is invoked without tools when any of the following holds:
    ``state["jump_to"] == "end"`` (treated as the tool-call limit being hit),
    ``worker_result`` or ``evidence`` is non-empty, or the last message is a
    plain greeting. Otherwise the LLM is bound to ``web_search_tool`` first.

    Args:
        state: Mapping-like state; the keys read are ``jump_to``,
            ``worker_result``, ``evidence`` and ``messages``. Missing or
            ``None`` values are treated as empty.

    Returns:
        ``{"messages": [res]}`` where ``res`` is the value returned by the
        LLM's ``ainvoke``.
        NOTE(review): the function is wrapped by ``asyncHandler``, whose effect
        on the return value / raised exceptions is not visible here - confirm.
    """
    logging.info("Executing chat node...")

    # `or []` guards against keys that are present but set to None.
    messages = state.get("messages") or []
    tool_limit_hit = state.get("jump_to") == "end"
    has_context = bool(state.get("worker_result") or state.get("evidence"))
    # Greeting detection only matters when there is no retrieved context.
    is_greeting = not has_context and _is_greeting(messages)

    if tool_limit_hit or has_context or is_greeting:
        if has_context:
            logging.info("Context found from workers. Disabling web search to prevent redundant searches.")
        elif is_greeting:
            logging.info("Greeting detected. Disabling tools for natural conversation.")
        else:
            logging.info("Tool call limit hit. Invoking LLM without tools.")
        chat_llm = llm
    else:
        logging.info("Binding chat LLM with web search tool (limit check enabled)")
        chat_llm = llm.bind_tools([web_search_tool])

    prompt = [
        SystemMessage(content=CHAT_PROMPT + "\nIMPORTANT: Do NOT write JSON tool calls manually. If you want to use a tool, use the native tool-calling function. If you are just chatting or greeting, respond only in plain, friendly Markdown text."),
        *messages,
    ]

    # The prompt always holds at least the system message, so indexing is safe.
    last_msg = prompt[-1]
    logging.info("Last message in prompt: %s...", last_msg.content[:200])

    logging.info("Invoking chat LLM...")
    res = await chat_llm.ainvoke(prompt)

    # An empty content is logged as 'Tool Call' (assumes the model answered with tool calls only).
    logging.info("Response retrieved from chat_llm: %s", res.content if res.content else 'Tool Call')
    return {"messages": [res]}