"""FastAPI service exposing the research agent over HTTP.

Endpoints:
    POST /chat    -- forward the latest user message to the ResearchAgent.
    POST /upload  -- store an uploaded file under ``data/uploads``.
"""

import json
import logging
import os
from typing import List, Optional

from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from app.agent.agent import ResearchAgent
from app.agent.tools.semantic_search import SemanticSearchTool
from app.agent.tools.citation_analyzer import CitationAnalyzerTool
from app.agent.tools.workout_planner import WorkoutPlannerTool

logger = logging.getLogger(__name__)

# Directory (relative to the process working directory) where uploads are saved.
UPLOAD_DIR = "data/uploads"
# Number of bytes read per iteration when streaming an upload to disk (1 MiB).
_UPLOAD_CHUNK_SIZE = 1024 * 1024

app = FastAPI()

# Add CORS middleware.
# NOTE(review): wildcard origins together with allow_credentials=True is
# maximally permissive; confirm which origins should really be allowed before
# exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize tools and agent once at import time; every request shares them.
semantic_search = SemanticSearchTool()
citation_analyzer = CitationAnalyzerTool()
workout_planner = WorkoutPlannerTool()
tools = [semantic_search, citation_analyzer, workout_planner]
agent = ResearchAgent(tools=tools, openai_api_key=os.getenv("OPENAI_API_KEY"))


class Message(BaseModel):
    """A single chat message."""

    role: str  # /chat requires the final message's role to be "user"
    content: str


class ChatRequest(BaseModel):
    """Request body for ``POST /chat``."""

    messages: List[Message]
    # NOTE(review): accepted but not used anywhere below; responses are never
    # streamed.
    stream: Optional[bool] = False


@app.post("/chat")
async def chat(request: ChatRequest):
    """Answer the most recent user message with the research agent.

    Only the last message's content is passed to ``agent.process_message``;
    earlier messages are not forwarded.

    Returns:
        JSON shaped like ``{"choices": [{"message": {"role": "assistant",
        "content": <agent response>}}]}``.

    Raises:
        HTTPException: 400 if ``messages`` is empty or the last message is not
            from the user; 500 if processing or serialization fails.
    """
    # Validate before the try block so these 400s are not swallowed by the
    # generic handler below and re-reported as 500s.
    if not request.messages:
        raise HTTPException(status_code=400, detail="Messages must not be empty")
    last_message = request.messages[-1]
    if last_message.role != "user":
        raise HTTPException(status_code=400, detail="Last message must be from user")

    try:
        # NOTE(review): process_message is called synchronously inside an async
        # handler, so it blocks the event loop while it runs. Consider
        # fastapi.concurrency.run_in_threadpool once the agent is confirmed
        # safe to call from multiple threads.
        response = agent.process_message(last_message.content)
        return JSONResponse(content={
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": response,
                },
            }],
        })
    except Exception as e:
        logger.exception("Failed to process chat message")
        raise HTTPException(status_code=500, detail=str(e)) from e


def _safe_upload_path(filename: Optional[str]) -> str:
    """Return the path inside UPLOAD_DIR for a client-supplied file name.

    Directory components are stripped so names such as ``../../etc/passwd``
    or ``/abs/path`` cannot escape the upload directory.

    Raises:
        HTTPException: 400 if no usable file name remains.
    """
    name = os.path.basename(filename or "")
    if name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Invalid file name")
    return os.path.join(UPLOAD_DIR, name)


@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """Save an uploaded file into UPLOAD_DIR.

    The file is stored under its base name (any directory part of the
    client-supplied name is discarded). An existing file with the same name
    is overwritten.

    Returns:
        JSON with a success ``message`` and the stored ``filename``.

    Raises:
        HTTPException: 400 for an unusable file name; 500 if writing fails.
    """
    file_path = _safe_upload_path(file.filename)
    try:
        os.makedirs(UPLOAD_DIR, exist_ok=True)
        with open(file_path, "wb") as buffer:
            # Stream in chunks instead of reading the whole upload into memory.
            while True:
                chunk = await file.read(_UPLOAD_CHUNK_SIZE)
                if not chunk:
                    break
                buffer.write(chunk)
    except Exception as e:
        logger.exception("Failed to save upload to %s", file_path)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return JSONResponse(content={
        "message": "File uploaded successfully",
        "filename": os.path.basename(file_path),
    })


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)