import os
import base64
import uvicorn
from typing import Optional
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.agents import create_agent
from langchain_core.messages import HumanMessage
from langgraph.checkpoint.memory import MemorySaver
# In-memory LangGraph checkpointer. It is handed to the agent in
# `initialize_agent` and purged per thread by `delete_session`.
memory_saver = MemorySaver()
# --- GLOBAL STATE ---
# Endpoint of the remote MCP server that `lifespan` queries for tools.
MCP_URL = "https://Codemaster67-ResearchPaperMCP.hf.space/mcp"
# Tool definitions fetched at startup; stays empty if the MCP fetch fails.
mcp_tools = []
# Agent built by POST /initialize; `None` until that endpoint succeeds.
agent_executor = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the MCP tool definitions a single time while the server boots.

    The tools returned by the MCP server at ``MCP_URL`` are stored in the
    module-level ``mcp_tools`` list. Any failure is reported on stdout and
    startup continues with whatever ``mcp_tools`` already held.
    """
    global mcp_tools
    server_config = {"ResearchAgent": {"url": MCP_URL, "transport": "http"}}
    try:
        mcp_client = MultiServerMCPClient(server_config)
        mcp_tools = await mcp_client.get_tools()
        print(f"✅ Tools connected: {len(mcp_tools)}")
    except Exception as exc:
        # Best effort: the app still starts when the MCP server is unreachable.
        print(f"❌ MCP Connection Failed: {exc}")
    # Hand control to the application; nothing is torn down on shutdown.
    yield
# Application instance; `lifespan` fetches the MCP tools during startup.
app = FastAPI(lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    # Browser origins that may call this API (deployed frontend + local dev).
    allow_origins=[
        "https://research-agent-heduc5oop-research-paper-agent.vercel.app",
        "http://localhost:3000",
        "http://localhost:5173",
    ],
    allow_credentials=True,
    # DELETE must be listed: the frontend clears a conversation through
    # DELETE /session/{session_id}, and the CORS preflight for that request
    # is rejected when the method is missing here.
    allow_methods=["GET", "POST", "DELETE"],
    allow_headers=["*"],
)
# --- API ENDPOINTS ---
@app.post("/initialize")
async def initialize_agent(api_key: str = Form(...), model_name: str = Form(...)):
    """
    Build the agent and store it in the module-level ``agent_executor``.

    The frontend calls this once when the user submits their settings.

    Args:
        api_key: Google Generative AI API key used to construct the LLM.
        model_name: Name of the Gemini model to use.

    Returns:
        A dict with ``status`` and ``message`` keys on success.

    Raises:
        HTTPException: 400 when the LLM cannot be created, the probe request
            fails (e.g. invalid API key), or the agent cannot be built.
    """
    global agent_executor
    try:
        # Setup the LLM
        llm = ChatGoogleGenerativeAI(
            model=model_name,
            google_api_key=api_key,
            temperature=0.1,
        )
        # Probe request to check that the API key is valid. It is awaited
        # (ainvoke) so the network round-trip does not block the event loop
        # and stall every other request handled by this worker.
        await llm.ainvoke("Say cheese")
        system_prompt = (
            "You are an expert academic and professional research assistant. Your primary goal is to provide accurate, "
            "comprehensive, and evidence-based answers by effectively utilizing your available tools.\n\n"
            "### CORE BEHAVIORS\n"
            "- Accuracy over Guesswork: Do not hallucinate. If you do not know the answer or your tools yield no relevant results, explicitly state that you could not find the information.\n"
            "- Synthesis: When gathering information from multiple sources, synthesize the findings into a coherent, easy-to-read response rather than just dumping raw summaries.\n\n"
            "### TOOL USAGE\n"
            "- Always prioritize using your tools to fetch up-to-date, peer-reviewed, or factual data before attempting to answer from your internal knowledge.\n"
            "- If a tool query fails or returns insufficient data, try reformulating your search terms or taking a different analytical angle before giving up.\n\n"
            "### CRITICAL CITATION RULES\n"
            "You must strictly adhere to the following citation format whenever you use a tool to retrieve information. Failure to do so is unacceptable:\n"
            "1. Inline Citations: Use bracketed numbers (e.g., [1], [2]) immediately following the specific claim or fact they support.\n"
            "2. Consistent Numbering: Each unique source must have a unique number. If you reference the same source multiple times, reuse its original number.\n"
            "3. Mandatory References Section: Append a distinct 'Sources & References' heading at the very bottom of your response.\n"
            "4. Strict Line Separation: In the references section, EVERY source must be on its own dedicated line. NEVER group multiple URLs on the same line.\n"
            "5. Reference Format: Use the format `[Number] Title or Source Name - URL` (e.g., `[1] Official Documentation - https://example.com`).\n"
        )
        # Create the Agent and store it globally; the shared checkpointer
        # keeps per-thread conversation history between /chat calls.
        agent_executor = create_agent(
            llm,
            mcp_tools,
            system_prompt=system_prompt,
            checkpointer=memory_saver,
        )
        return {"status": "Success", "message": f"Agent initialized with {model_name}"}
    except Exception as e:
        # Any failure is reported to the client as a 400; keep the original
        # exception chained for server-side tracebacks.
        raise HTTPException(status_code=400, detail=f"Initialization failed: {str(e)}") from e
def _text_from_content(content) -> str:
    """Flatten message content (a string or a list of parts) to plain text."""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                pieces.append(part["text"])
        return " ".join(pieces)
    return ""


def _last_assistant_text(messages) -> str:
    """Return the text of the most recent assistant message that has any."""
    for msg in reversed(messages):
        # Skip human and tool messages so the user's own input or raw tool
        # output is never returned as the answer.
        if getattr(msg, "type", None) != "ai":
            continue
        text = _text_from_content(msg.content)
        if text.strip():
            return text
    return ""


@app.post("/chat")
async def chat(
    message: str = Form(...),
    session_id: str = Form("default_thread"),
    file: Optional[UploadFile] = File(None),
):
    """
    Send one user turn (text plus an optional file) to the agent.

    Args:
        message: The user's text message.
        session_id: Thread id under which the checkpointer keeps the
            conversation history.
        file: Optional upload, forwarded to the model base64-encoded together
            with its declared MIME type.

    Returns:
        ``{"response": <text>}`` on success, or ``{"error": <description>}``
        when the agent run raises.

    Raises:
        HTTPException: 400 when the agent has not been initialized yet.
    """
    if agent_executor is None:
        raise HTTPException(status_code=400, detail="Agent not initialized.")
    message_content = [{"type": "text", "text": message}]
    if file:
        file_bytes = await file.read()
        encoded_file = base64.b64encode(file_bytes).decode("utf-8")
        message_content.append({
            "type": "media",
            "mime_type": file.content_type,
            "data": encoded_file,
        })
    try:
        inputs = {"messages": [HumanMessage(content=message_content)]}
        # The thread id selects which stored conversation this turn extends.
        config = {"configurable": {"thread_id": session_id}}
        response = await agent_executor.ainvoke(inputs, config=config)
        # Log every tool call the agent made, for debugging.
        for msg in response["messages"]:
            for tool_call in getattr(msg, "tool_calls", None) or []:
                print(f"[TOOL CALL]: {tool_call['name']}")
                print(f"[ARGUMENTS]: {tool_call['args']}\n")
        return {"response": _last_assistant_text(response["messages"])}
    except Exception as e:
        print(f"❌ Agent Error: {str(e)}")  # Added print for error visibility
        # Errors are returned in the body (not raised) so the frontend can
        # display them in the conversation.
        return {"error": f"Agent Error: {str(e)}"}
def _purge_thread(saver, thread_id: str) -> None:
    """Remove everything the in-memory checkpointer holds for one thread."""
    delete_thread = getattr(saver, "delete_thread", None)
    if callable(delete_thread):
        try:
            # Preferred path: let the checkpointer clean up its own stores.
            delete_thread(thread_id)
            return
        except NotImplementedError:
            pass  # Saver lacks an implementation; fall back to manual cleanup.
    # Manual fallback for savers without delete_thread. These stores are keyed
    # by tuples whose first element is the thread id; `blobs` is included so
    # stored channel values are released as well.
    for attr in ("checkpoints", "writes", "blobs"):
        store = getattr(saver, attr, None)
        if store is None:
            continue
        for key in [k for k in store if k[0] == thread_id]:
            del store[key]
    # `storage` is keyed directly by thread id.
    storage = getattr(saver, "storage", None)
    if storage is not None and thread_id in storage:
        del storage[thread_id]


@app.delete("/session/{session_id}")
async def delete_session(session_id: str):
    """
    Clear the LangGraph MemorySaver checkpoints for a specific session/thread
    ID when the user disconnects or refreshes the page.

    Args:
        session_id: Thread id whose stored conversation state is removed.

    Returns:
        A dict with ``status`` and ``message`` keys.

    Raises:
        HTTPException: 500 when the cleanup itself fails.
    """
    if memory_saver is None:
        return {"status": "Ignored", "message": "No memory saver initialized."}
    try:
        _purge_thread(memory_saver, session_id)
        print(f"Session {session_id} successfully cleared from memory.")
        return {"status": "Success", "message": f"Session {session_id} cleared."}
    except Exception as e:
        print(f"Error clearing session {session_id}: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to clear session") from e
if __name__ == "__main__":
    # Serve the API on all interfaces, port 7860, when run as a script.
    uvicorn.run(app, host="0.0.0.0", port=7860)