File size: 6,470 Bytes
10e9b7d
4e95936
eccf8e4
4e95936
 
 
3c4371f
10e9b7d
4e95936
3c4371f
4e95936
 
 
 
 
 
 
 
3c4371f
4e95936
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eccf8e4
4e95936
7d65c66
4e95936
 
 
7d65c66
 
3c4371f
31243f4
 
 
 
4e95936
31243f4
4e95936
7d65c66
4e95936
31243f4
4e95936
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import os
import tempfile
import requests
from typing import Dict, Any, Annotated
from typing_extensions import TypedDict
import gradio as gr
import pandas as pd

# Your constants + imports stay

# New imports for the stack
from smolagents import CodeAgent, HfApiModel  # Smolagents for code/web agents
from smolagents.tools import DuckDuckGoSearchResults  # Built-in web tool
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage, AIMessage
from transformers import pipeline  # For lightweight LLM routing

# --- Enhanced Agent with LangGraph + Smolagents ---
class CrmAgent:
    """GAIA-style question-answering agent.

    A small LangGraph state machine routes each question either to a
    smolagents ``CodeAgent`` (which can run code and hit the web via
    DuckDuckGo) or to a direct-answer fallback. A tiny GPT-2 pipeline acts
    as the router LLM.

    Public interface (unchanged): construct with no args, then call the
    instance with ``agent(question, task_id=None) -> str``.
    """

    # Route keywords the router prompt asks the LLM to emit.
    _ROUTES = {"search", "code", "both", "direct"}

    def __init__(self):
        print("CrmAgent initialized with LangGraph + Smolagents.")
        # Lightweight router LLM on CPU (free HF inference; GPT-2 is tiny).
        self.router = pipeline("text-generation", model="gpt2", device=-1)
        # Smolagents CodeAgent with a web-search tool.
        # NOTE(review): depending on the smolagents version, CodeAgent may
        # expect `model=` rather than `llm=`, and the search tool may be
        # `DuckDuckGoSearchTool` — confirm against the installed release.
        self.llm = HfApiModel(model_id="microsoft/DialoGPT-medium")
        search_tool = DuckDuckGoSearchResults(num_results=3)  # quick web hits
        self.code_agent = CodeAgent(llm=self.llm, tools=[search_tool])
        # Scratch dir for downloaded task attachments.
        # NOTE(review): never cleaned up — consider shutil.rmtree on shutdown.
        self.temp_dir = tempfile.mkdtemp()

    def download_file(self, task_id: str) -> str:
        """Download the attachment for *task_id* (if any) into the temp dir.

        Returns a short human-readable status string that is appended to the
        question prompt, never raises.

        Fix vs. the original draft: this was decorated with LangChain's
        ``@tool``, which does not support bound methods — the generated
        schema included ``self`` and ``self.download_file.invoke(...)``
        could never bind the instance. It is now a plain method.
        """
        # DEFAULT_API_URL is a module-level constant defined elsewhere in
        # this file ("Your constants + imports stay").
        url = f"{DEFAULT_API_URL}/files/{task_id}"
        try:
            resp = requests.get(url, timeout=10)
            if resp.status_code == 200:
                file_path = os.path.join(self.temp_dir, f"{task_id}_file")
                with open(file_path, "wb") as f:
                    f.write(resp.content)
                return f"File downloaded: {file_path}"
            return "No file found."
        except Exception as e:
            # Best-effort: a failed download degrades the prompt, not the run.
            return f"Download error: {e}"

    def router_node(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Graph node: ask the router LLM which path the question needs.

        Writes ``state["route"]`` (one of ``_ROUTES``; unknown output falls
        back to ``direct`` in :meth:`conditional_route`).
        """
        question = state["question"]
        prompt = (
            f"Given question: '{question[:100]}...'. Respond with route: "
            "'search' if needs web info, 'code' if math/file/code, "
            "'both' if both, 'direct' if obvious."
        )
        # Fixes vs. the original draft:
        #  * max_length=20 was shorter than the prompt itself (gpt2 would
        #    truncate/warn) — use max_new_tokens for the completion budget.
        #  * generated_text echoes the prompt, so taking the last word of it
        #    almost never yielded a valid route; strip the prompt first and
        #    scan the completion for a known route keyword.
        generated = self.router(prompt, max_new_tokens=8, num_return_sequences=1)[0][
            "generated_text"
        ]
        completion = generated[len(prompt):].strip().lower()
        route = next((tok for tok in completion.split() if tok in self._ROUTES), "direct")
        state["route"] = route
        print(f"Routed to: {route}")
        return state

    def search_node(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Graph node: delegate to the smolagents CodeAgent (web + code)."""
        question = state["question"]
        try:
            # The agent picks tools internally (search and/or code exec).
            # Coerce to str: smolagents may return non-string results, and
            # downstream answer extraction does string operations on it.
            result = str(self.code_agent.run(question))
            state["search_results"] = result
            print(f"Search/code output: {result[:100]}...")
        except Exception as e:
            state["search_results"] = f"Error: {e}"
        return state

    def direct_node(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Graph node: fallback when no tool path applies (placeholder)."""
        # TODO: replace with a real heuristic / direct-LLM answer.
        state["final_answer"] = "Direct answer needed—implement heuristic here."
        return state

    def conditional_route(self, state: Dict[str, Any]) -> str:
        """Conditional-edge selector: map the router's label to a node name."""
        route = state.get("route", "direct")
        # "code" and "both" also go to search_node — the CodeAgent handles
        # code execution itself.
        if route in ("search", "both", "code"):
            return "search"
        return "direct"

    def build_graph(self):
        """Compile the LangGraph workflow: router -> (search | direct) -> END."""

        class AgentState(TypedDict):
            question: str
            route: str
            search_results: str
            final_answer: str

        workflow = StateGraph(AgentState)
        workflow.add_node("router", self.router_node)
        workflow.add_node("search", self.search_node)
        workflow.add_node("direct", self.direct_node)

        workflow.set_entry_point("router")
        workflow.add_conditional_edges(
            "router", self.conditional_route, {"search": "search", "direct": "direct"}
        )
        workflow.add_edge("search", END)
        workflow.add_edge("direct", END)

        self.graph = workflow.compile()

    def __call__(self, question: str, task_id: str | None = None) -> str:
        """Answer *question*, optionally downloading the task's attachment.

        Returns the extracted answer string (best effort — see below).
        """
        if not hasattr(self, "graph"):
            self.build_graph()  # lazy compile on first call
        if task_id:
            # Plain method call now that download_file is no longer a @tool.
            file_info = self.download_file(task_id)
            question += f" [File info: {file_info}]"

        initial_state = {
            "question": question,
            "route": "",
            "search_results": "",
            "final_answer": "",
        }
        final_state = self.graph.invoke(initial_state)

        # Prefer the tool path's output; fall back to the direct node's.
        answer = final_state.get(
            "search_results", final_state.get("final_answer", "No answer generated.")
        )
        # Extract the token after a "final answer" marker, if present.
        # Fix vs. the original draft: the membership test was
        # case-insensitive but the split was case-sensitive, so
        # "Final Answer: X" was detected yet never extracted.
        lowered = answer.lower()
        marker = "final answer"
        idx = lowered.find(marker)
        if idx != -1:
            tail = answer[idx + len(marker):].lstrip(" :").strip()
            if tail:
                answer = tail.split()[0]
        print(f"Agent final: {answer}")
        return answer

# --- Update run_and_submit_all (minor tweak for task_id) ---
def run_and_submit_all(profile: gr.OAuthProfile | None):
    """Fetch all questions, run CrmAgent on each, and submit the answers.

    NOTE(review): the original draft elided several suites with bare
    ``# ...`` comments, which is a syntax error in Python (an ``except``/
    ``if`` body cannot be empty). Minimal intent-preserving statements are
    filled in below; splice your existing fetch/login/submit code back into
    the marked sections. ``questions_data`` is expected to come from the
    elided fetch step.
    """
    # ... (keep all your existing code up to agent init:
    #      login check, fetching questions_data from the API, etc.)

    # 1. Instantiate Agent
    try:
        agent = CrmAgent()
    except Exception as e:
        print(f"Error instantiating agent: {e}")
        return f"Error initializing agent: {e}", None

    # 3. Run the Agent on every question (pass task_id so attachments
    #    can be downloaded).
    results_log = []
    answers_payload = []
    print(f"Running agent on {len(questions_data)} questions...")
    for item in questions_data:
        task_id = item.get("task_id")
        question_text = item.get("question")
        if not task_id or question_text is None:
            # Malformed item — skip it rather than crash the whole run.
            print(f"Skipping malformed item: {item}")
            continue
        try:
            submitted_answer = agent(question_text, task_id)  # task_id enables file download
            answers_payload.append(
                {"task_id": task_id, "submitted_answer": submitted_answer}
            )
            results_log.append(
                {
                    "Task ID": task_id,
                    "Question": question_text[:50] + "...",
                    "Submitted Answer": submitted_answer,
                }
            )
        except Exception as e:
            # Record the failure so the results table stays complete.
            print(f"Error running agent on task {task_id}: {e}")
            results_log.append(
                {
                    "Task ID": task_id,
                    "Question": question_text[:50] + "...",
                    "Submitted Answer": f"AGENT ERROR: {e}",
                }
            )

    # ... (rest unchanged—submit answers_payload as before)