import os import base64 import uvicorn from typing import Optional from fastapi import FastAPI, UploadFile, File, Form, HTTPException from fastapi.middleware.cors import CORSMiddleware from contextlib import asynccontextmanager from langchain_mcp_adapters.client import MultiServerMCPClient from langchain_google_genai import ChatGoogleGenerativeAI from langchain.agents import create_agent from langchain_core.messages import HumanMessage from langgraph.checkpoint.memory import MemorySaver memory_saver = MemorySaver() # --- GLOBAL STATE --- MCP_URL = "https://Codemaster67-ResearchPaperMCP.hf.space/mcp" mcp_tools = [] agent_executor = None @asynccontextmanager async def lifespan(app: FastAPI): """Fetch tool definitions from HF once when the server starts.""" global mcp_tools try: client = MultiServerMCPClient({ "ResearchAgent": { "url": MCP_URL, "transport": "http" } }) mcp_tools = await client.get_tools() print(f"✅ Tools connected: {len(mcp_tools)}") except Exception as e: print(f"❌ MCP Connection Failed: {e}") yield app = FastAPI(lifespan=lifespan) app.add_middleware( CORSMiddleware, allow_origins=[ "https://research-agent-heduc5oop-research-paper-agent.vercel.app", "http://localhost:3000", "http://localhost:5173", ], allow_credentials=True, allow_methods=["GET", "POST"], allow_headers=["*"], ) # --- API ENDPOINTS --- @app.post("/initialize") async def initialize_agent(api_key: str = Form(...), model_name: str = Form(...)): """ Creates the agent ONE TIME. The frontend calls this once when the user submits their settings. """ global agent_executor, mcp_tools try: # Setup the LLM llm = ChatGoogleGenerativeAI( model=model_name, google_api_key=api_key, temperature=0.1 ) llm.invoke("Say cheese") # test to see if api key is valid system_prompt = ( "You are an expert academic and professional research assistant. Your primary goal is to provide accurate, " "comprehensive, and evidence-based answers by effectively utilizing your available tools.\n\n" "### CORE BEHAVIORS\n" "- Accuracy over Guesswork: Do not hallucinate. If you do not know the answer or your tools yield no relevant results, explicitly state that you could not find the information.\n" "- Synthesis: When gathering information from multiple sources, synthesize the findings into a coherent, easy-to-read response rather than just dumping raw summaries.\n\n" "### TOOL USAGE\n" "- Always prioritize using your tools to fetch up-to-date, peer-reviewed, or factual data before attempting to answer from your internal knowledge.\n" "- If a tool query fails or returns insufficient data, try reformulating your search terms or taking a different analytical angle before giving up.\n\n" "### CRITICAL CITATION RULES\n" "You must strictly adhere to the following citation format whenever you use a tool to retrieve information. Failure to do so is unacceptable:\n" "1. Inline Citations: Use bracketed numbers (e.g., [1], [2]) immediately following the specific claim or fact they support.\n" "2. Consistent Numbering: Each unique source must have a unique number. If you reference the same source multiple times, reuse its original number.\n" "3. Mandatory References Section: Append a distinct 'Sources & References' heading at the very bottom of your response.\n" "4. Strict Line Separation: In the references section, EVERY source must be on its own dedicated line. NEVER group multiple URLs on the same line.\n" "5. Reference Format: Use the format `[Number] Title or Source Name - URL` (e.g., `[1] Official Documentation - https://example.com`).\n" ) # Create the Agent and store it globally agent_executor = create_agent( llm, mcp_tools, system_prompt = system_prompt, checkpointer=memory_saver ) return {"status": "Success", "message": f"Agent initialized with {model_name}"} except Exception as e: raise HTTPException(status_code=400, detail=f"Initialization failed: {str(e)}") @app.post("/chat") async def chat( message: str = Form(...), session_id: str = Form("default_thread"), file: Optional[UploadFile] = File(None) ): global agent_executor if agent_executor is None: raise HTTPException(status_code=400, detail="Agent not initialized.") message_content = [{"type": "text", "text": message}] if file: file_bytes = await file.read() encoded_file = base64.b64encode(file_bytes).decode("utf-8") message_content.append({ "type": "media", "mime_type": file.content_type, "data": encoded_file }) try: inputs = {"messages": [HumanMessage(content=message_content)]} # --- NEW: PASS THREAD ID IN CONFIG --- config = {"configurable": {"thread_id": session_id}} response = await agent_executor.ainvoke(inputs, config=config) # --- NEW: PRINT TOOL CALLS --- for msg in response["messages"]: # Check if this message contains tool calls if hasattr(msg, "tool_calls") and msg.tool_calls: for tool_call in msg.tool_calls: print(f"[TOOL CALL]: {tool_call['name']}") print(f"ARGUMENTS]: {tool_call['args']}\n") final_answer = "" # Loop backwards to find the last assistant message with content for msg in reversed(response["messages"]): if msg.content: if isinstance(msg.content, str): final_answer = msg.content elif isinstance(msg.content, list): final_answer = " ".join([ part.get("text", "") for part in msg.content if isinstance(part, dict) and "text" in part ]) if final_answer.strip(): break return {"response": final_answer} except Exception as e: print(f"❌ Agent Error: {str(e)}") # Added print for error visibility return {"error": f"Agent Error: {str(e)}"} @app.delete("/session/{session_id}") async def delete_session(session_id: str): """ Clears the LangGraph MemorySaver checkpoints for a specific session/thread ID when the user disconnects or refreshes the page. """ global memory_saver if memory_saver is None: return {"status": "Ignored", "message": "No memory saver initialized."} try: # MemorySaver in newer versions of LangGraph uses 'checkpoints' and 'writes' dicts. # The keys are typically tuples where the first element is the thread_id. if hasattr(memory_saver, 'checkpoints'): keys_to_delete = [k for k in memory_saver.checkpoints.keys() if k[0] == session_id] for k in keys_to_delete: del memory_saver.checkpoints[k] if hasattr(memory_saver, 'writes'): keys_to_delete = [k for k in memory_saver.writes.keys() if k[0] == session_id] for k in keys_to_delete: del memory_saver.writes[k] if hasattr(memory_saver, 'storage') and session_id in memory_saver.storage: del memory_saver.storage[session_id] print(f"Session {session_id} successfully cleared from memory.") return {"status": "Success", "message": f"Session {session_id} cleared."} except Exception as e: print(f"Error clearing session {session_id}: {str(e)}") raise HTTPException(status_code=500, detail="Failed to clear session") if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=7860)