File size: 8,407 Bytes
a554427
 
 
 
 
 
 
 
 
 
 
e276667
a554427
0d9ff86
a554427
 
0d9ff86
a554427
 
a129ce3
a554427
 
 
 
 
 
 
 
 
 
 
32bde24
a554427
 
 
 
 
 
 
 
 
 
 
 
1ca8d94
 
 
 
 
 
 
a554427
 
1ca8d94
a554427
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25f72e5
 
eb06bfe
4e760a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb06bfe
 
a554427
 
e276667
 
 
6f6ed4a
0d9ff86
e276667
a554427
 
 
 
 
 
 
 
 
 
0d9ff86
a554427
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d9ff86
 
 
 
a554427
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7957cfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6dac955
7957cfa
 
 
6dac955
7957cfa
 
a554427
1ca8d94
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import os
import base64
import uvicorn
from typing import Optional

from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware

from contextlib import asynccontextmanager
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.agents import create_agent
from langchain_core.messages import HumanMessage
from langgraph.checkpoint.memory import MemorySaver


# Process-wide LangGraph checkpointer (MemorySaver): conversation state for
# every session_id is held in this process, keyed by thread_id.
memory_saver = MemorySaver()

# --- GLOBAL STATE ---
# MCP server that provides the research tools. Override with the MCP_URL
# environment variable; the default is the original hosted endpoint.
MCP_URL = os.environ.get(
    "MCP_URL", "https://Codemaster67-ResearchPaperMCP.hf.space/mcp"
)

# Tool definitions fetched once by lifespan() at startup; stays empty if the
# MCP server could not be reached.
mcp_tools = []
# Agent built by /initialize; None until the first successful initialization.
agent_executor = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the MCP tool definitions once, when the application starts.

    Connects to the MCP server at ``MCP_URL`` and stores the tools it exposes
    in the module-level ``mcp_tools`` list. Any failure is reported on stdout
    and startup continues with ``mcp_tools`` unchanged.
    """
    global mcp_tools
    server_config = {"ResearchAgent": {"url": MCP_URL, "transport": "http"}}
    try:
        mcp_client = MultiServerMCPClient(server_config)
        mcp_tools = await mcp_client.get_tools()
        print(f"✅ Tools connected: {len(mcp_tools)}")
    except Exception as err:
        print(f"❌ MCP Connection Failed: {err}")
    # Hand control to FastAPI; there is nothing to tear down on shutdown.
    yield


app = FastAPI(lifespan=lifespan)

# Browser origins allowed to call this API. DELETE must be listed alongside
# GET/POST: it is a non-simple method, so the browser sends a CORS preflight,
# and the middleware rejects preflights for methods missing here. Without it,
# DELETE /session/{session_id} is unreachable from these frontends.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        # NOTE(review): looks like a per-deployment Vercel preview URL; confirm
        # it is stable or consider allow_origin_regex.
        "https://research-agent-heduc5oop-research-paper-agent.vercel.app",
        "http://localhost:3000",
        "http://localhost:5173",
    ],
    allow_credentials=True,
    allow_methods=["GET", "POST", "DELETE"],
    allow_headers=["*"],
)

# --- API ENDPOINTS ---

@app.post("/initialize")
async def initialize_agent(api_key: str = Form(...), model_name: str = Form(...)):
    """
    Build the research agent and store it in the module-level ``agent_executor``.

    The frontend calls this when the user submits their settings.

    Args:
        api_key: Google Generative AI API key for the chat model.
        model_name: Model identifier passed to ``ChatGoogleGenerativeAI``.

    Returns:
        A status dict naming the model the agent was built with.

    Raises:
        HTTPException: 400 if constructing the model, the validation request,
            or the agent fails (e.g. an invalid API key or model name).
    """
    # NOTE(review): agent_executor is process-wide, so each call replaces the
    # agent (and API key) used by every /chat session.
    global agent_executor

    try:
        llm = ChatGoogleGenerativeAI(
            model=model_name,
            google_api_key=api_key,
            temperature=0.1,
        )
        # Cheap round-trip so a bad key/model fails here rather than on the
        # first chat. Use the async API: a sync invoke inside this coroutine
        # would block the event loop for all other requests while it waits.
        await llm.ainvoke("Say cheese")

        system_prompt = (
            "You are an expert academic and professional research assistant. Your primary goal is to provide accurate, "
            "comprehensive, and evidence-based answers by effectively utilizing your available tools.\n\n"

            "### CORE BEHAVIORS\n"
            "- Accuracy over Guesswork: Do not hallucinate. If you do not know the answer or your tools yield no relevant results, explicitly state that you could not find the information.\n"
            "- Synthesis: When gathering information from multiple sources, synthesize the findings into a coherent, easy-to-read response rather than just dumping raw summaries.\n\n"

            "### TOOL USAGE\n"
            "- Always prioritize using your tools to fetch up-to-date, peer-reviewed, or factual data before attempting to answer from your internal knowledge.\n"
            "- If a tool query fails or returns insufficient data, try reformulating your search terms or taking a different analytical angle before giving up.\n\n"

            "### CRITICAL CITATION RULES\n"
            "You must strictly adhere to the following citation format whenever you use a tool to retrieve information. Failure to do so is unacceptable:\n"
            "1. Inline Citations: Use bracketed numbers (e.g., [1], [2]) immediately following the specific claim or fact they support.\n"
            "2. Consistent Numbering: Each unique source must have a unique number. If you reference the same source multiple times, reuse its original number.\n"
            "3. Mandatory References Section: Append a distinct 'Sources & References' heading at the very bottom of your response.\n"
            "4. Strict Line Separation: In the references section, EVERY source must be on its own dedicated line. NEVER group multiple URLs on the same line.\n"
            "5. Reference Format: Use the format `[Number] Title or Source Name - URL` (e.g., `[1] Official Documentation - https://example.com`).\n"
        )

        # mcp_tools is filled by lifespan() at startup; it is empty if the MCP
        # server was unreachable, in which case the agent runs without tools.
        agent_executor = create_agent(
            llm,
            mcp_tools,
            system_prompt=system_prompt,
            checkpointer=memory_saver,
        )

        return {"status": "Success", "message": f"Agent initialized with {model_name}"}
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Initialization failed: {str(e)}") from e



def _content_to_text(content) -> str:
    """Return the plain text carried by a message ``content`` value.

    A string is returned as-is. For a list of content parts, the ``"text"``
    values of dict parts that have one are joined with single spaces. Any
    other shape yields an empty string.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return " ".join(
            part.get("text", "")
            for part in content
            if isinstance(part, dict) and "text" in part
        )
    return ""


def _current_turn(messages: list) -> list:
    """Return the messages that follow the most recent human message.

    With a checkpointer, the agent's returned state contains the whole thread
    history; this slices it down to what the latest request produced. If no
    human message is present, all messages are returned.
    """
    last_human = max(
        (i for i, msg in enumerate(messages) if getattr(msg, "type", None) == "human"),
        default=-1,
    )
    return messages[last_human + 1:]


@app.post("/chat")
async def chat(
    message: str = Form(...),
    session_id: str = Form("default_thread"),
    file: Optional[UploadFile] = File(None)
):
    """
    Send one user message (optionally with a file) to the agent and return its reply.

    Args:
        message: The user's text.
        session_id: Thread id under which the checkpointer stores this
            conversation's history.
        file: Optional upload, sent to the model base64-encoded as a
            ``"media"`` content part with the upload's content type.

    Returns:
        ``{"response": text}`` with the last non-blank AI message text from
        this turn (empty string if there is none), or ``{"error": ...}`` if the
        agent call fails.

    Raises:
        HTTPException: 400 if ``/initialize`` has not created the agent yet.
    """
    if agent_executor is None:
        raise HTTPException(status_code=400, detail="Agent not initialized.")

    message_content = [{"type": "text", "text": message}]
    if file:
        file_bytes = await file.read()
        message_content.append({
            "type": "media",
            "mime_type": file.content_type,
            "data": base64.b64encode(file_bytes).decode("utf-8"),
        })

    try:
        inputs = {"messages": [HumanMessage(content=message_content)]}
        # The checkpointer keys conversation memory by thread_id.
        config = {"configurable": {"thread_id": session_id}}
        response = await agent_executor.ainvoke(inputs, config=config)

        # Only look at what this request produced; earlier turns are also in
        # the returned state and would otherwise be re-logged / re-answered.
        turn = _current_turn(response["messages"])

        for msg in turn:
            for tool_call in getattr(msg, "tool_calls", None) or []:
                print(f"[TOOL CALL]: {tool_call['name']}")
                print(f"[ARGUMENTS]: {tool_call['args']}\n")

        # Newest assistant ("ai") message with visible text; tool results and
        # the user's own message must never be returned as the answer.
        final_answer = ""
        for msg in reversed(turn):
            if getattr(msg, "type", None) != "ai":
                continue
            text = _content_to_text(msg.content)
            if text.strip():
                final_answer = text
                break

        return {"response": final_answer}
    except Exception as e:
        print(f"❌ Agent Error: {str(e)}")  # surface failures in server logs
        return {"error": f"Agent Error: {str(e)}"}


def _purge_thread_manually(saver, thread_id: str) -> None:
    """Fallback removal of one thread from an in-memory saver's internal dicts.

    ``storage`` is probed as a mapping keyed directly by thread id. The
    ``checkpoints``, ``writes`` and ``blobs`` mappings are probed as dicts
    keyed by tuples whose first element is the thread id. Missing attributes
    are skipped. ``blobs`` holds stored channel values (message history), so
    it must be cleared too or the thread's data stays in memory.
    """
    storage = getattr(saver, "storage", None)
    if storage is not None and thread_id in storage:
        del storage[thread_id]
    for attr in ("checkpoints", "writes", "blobs"):
        table = getattr(saver, attr, None)
        if table is None:
            continue
        for key in [k for k in table if k[0] == thread_id]:
            del table[key]


@app.delete("/session/{session_id}")
async def delete_session(session_id: str):
    """
    Clear all checkpointed state that ``memory_saver`` holds for one session/thread id.

    Intended to be called by the frontend when a session ends (e.g. the user
    disconnects or refreshes the page). Clearing an unknown session id is
    not an error.

    Returns:
        A status dict.

    Raises:
        HTTPException: 500 if clearing the saver raises.
    """
    if memory_saver is None:
        return {"status": "Ignored", "message": "No memory saver initialized."}

    try:
        delete_thread = getattr(memory_saver, "delete_thread", None)
        try:
            if not callable(delete_thread):
                raise NotImplementedError
            # Public checkpointer API; for the in-memory saver it removes the
            # thread's checkpoints, pending writes and stored channel blobs.
            delete_thread(session_id)
        except NotImplementedError:
            # Older/other savers without delete_thread: poke the dicts directly.
            _purge_thread_manually(memory_saver, session_id)

        print(f"Session {session_id} successfully cleared from memory.")
        return {"status": "Success", "message": f"Session {session_id} cleared."}

    except Exception as e:
        print(f"Error clearing session {session_id}: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to clear session") from e

if __name__ == "__main__":
    # Listen on all interfaces; port defaults to 7860 and can be overridden
    # with the PORT environment variable.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))