File size: 4,591 Bytes
0389a81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, END, START
from config import AGENT_MODEL
from utils.state import AgentState
from tools.rag import create_rag_tool
from langchain_core.messages import HumanMessage, AIMessage
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def create_agent_graph(vector_store, tavily_tool, arxiv_tool):
    """Create a simplified agent graph with direct tool invocation."""
    logging.info("Creating agent graph with tools")
    
    try:
        retriever = vector_store.as_retriever(search_kwargs={"k": 5})
        logging.info("Successfully created retriever")
    except Exception as e:
        logging.error(f"Error creating retriever: {e}")
        raise
        
    try:
        ai_rag_tool = create_rag_tool(retriever)
        logging.info("Successfully created RAG tool")
    except Exception as e:
        logging.error(f"Error creating RAG tool: {e}")
        raise
    
    # Create the LLM
    logging.info(f"Creating LLM with model: {AGENT_MODEL}")
    llm = ChatOpenAI(model=AGENT_MODEL, temperature=0)
    
    # Create simplified graph
    graph = StateGraph(AgentState)
    
    # Keywords that trigger RAG for our example document
    evaluation_keywords = [
        "ai", "artificial intelligence", "llm", "language model", "evaluation", 
        "metric", "metrics", "benchmark", "benchmarks", "mmlu", "humaneval", 
        "truthfulqa", "big-bench", "accuracy", "factual", "reasoning", "turing",
        "red team", "red teaming", "challenge", "challenges", "linguistic"
    ]
    
    def agent_node(state):
        """Process user input and decide which tool to use."""
        user_input = state["messages"][-1].content.lower()
        logging.info(f"Processing user input: {user_input}")
        
        # Simple rule-based tool selection
        response = ""
        try:
            if "search" in user_input:
                logging.info("Using search tool")
                try:
                    response = tavily_tool(user_input)
                except Exception as e:
                    logging.error(f"Search tool error: {e}")
                    response = f"I encountered an error with the search tool: {str(e)}"
            elif "arxiv" in user_input or "paper" in user_input or "academic" in user_input:
                logging.info("Using arxiv tool")
                try:
                    response = arxiv_tool(user_input)
                except Exception as e:
                    logging.error(f"ArXiv tool error: {e}")
                    response = f"I encountered an error with the ArXiv tool: {str(e)}"
            elif any(kw in user_input for kw in evaluation_keywords):
                logging.info("Using RAG tool for AI evaluation query")
                try:
                    result = ai_rag_tool.invoke(user_input)
                    if result and "response" in result:
                        response = result["response"]
                    else:
                        response = "I couldn't find specific information about that in our documents."
                except Exception as e:
                    logging.error(f"Error using RAG tool: {e}")
                    response = "I encountered an error when searching our documents. Let me provide a general response instead."
                    # Fall back to direct LLM if RAG fails
                    prompt = f"""
                    The user asked: "{user_input}"
                    
                    Provide a helpful response about LLM evaluation based on standard knowledge.
                    """
                    response = llm.invoke(prompt).content
            else:
                # Simple prompt for direct chat
                logging.info("Using direct LLM response")
                prompt = f"Answer this question concisely: {user_input}"
                response = llm.invoke(prompt).content
        except Exception as e:
            logging.error(f"Error processing request: {e}")
            response = f"I encountered an error processing your request: {str(e)}"
        
        logging.info(f"Generated response: {response[:100]}...")
        return {
            "messages": state["messages"] + [AIMessage(content=response)],
            "context": state.get("context", [])
        }
    
    # Add the agent node
    graph.add_node("agent", agent_node)
    
    # Set up the graph flow
    graph.add_edge(START, "agent")
    graph.add_edge("agent", END)
    
    return graph.compile()