T-K-O-H
Fix tool imports and ResearchAgent initialization in main.py
f58c7f1
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import List, Optional
import os
import json
from app.agent.agent import ResearchAgent
from app.agent.tools.semantic_search import SemanticSearchTool
from app.agent.tools.citation_analyzer import CitationAnalyzerTool
from app.agent.tools.workout_planner import WorkoutPlannerTool
app = FastAPI()

# Add CORS middleware.
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS environment
# variable. When it is unset or blank, every origin ("*") is allowed, which
# matches the previous hard-coded configuration.
# NOTE(review): with allow_origins=["*"] and allow_credentials=True, Starlette's
# CORSMiddleware reflects the caller's Origin header instead of sending "*".
# Credentialed cross-site requests are then accepted from any site, so set
# CORS_ALLOW_ORIGINS explicitly in production.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Build the agent's tool set once at import time; the three tools are created
# in the same order they are handed to the agent.
semantic_search, citation_analyzer, workout_planner = (
    SemanticSearchTool(),
    CitationAnalyzerTool(),
    WorkoutPlannerTool(),
)
tools = [semantic_search, citation_analyzer, workout_planner]

# NOTE(review): os.environ.get returns None when OPENAI_API_KEY is unset; whether
# ResearchAgent tolerates a None key is not visible here — confirm.
_openai_api_key = os.environ.get("OPENAI_API_KEY")
agent = ResearchAgent(tools=tools, openai_api_key=_openai_api_key)
class Message(BaseModel):
    """One chat turn as sent by the client."""

    # Speaker of the turn; the /chat handler compares the final turn against "user".
    role: str
    # Raw text of the turn.
    content: str
class ChatRequest(BaseModel):
    """Request body accepted by the /chat endpoint."""

    # Conversation so far, oldest first; only the last entry is forwarded to the agent.
    messages: List[Message]
    # Streaming flag; defaults to False. NOTE(review): the /chat handler in this
    # module does not currently read it.
    stream: Optional[bool] = False
@app.post("/chat")
async def chat(request: ChatRequest):
    """Answer the latest user message with the research agent.

    Only the final entry of ``request.messages`` is forwarded to the agent.
    Earlier turns are not passed along, and ``request.stream`` is not consulted.

    Returns:
        JSONResponse whose body has the shape
        ``{"choices": [{"message": {"role": "assistant", "content": <agent reply>}}]}``.

    Raises:
        HTTPException: 400 if ``messages`` is empty or the last message is not
            from the user; 500 if the agent call or response building fails.
    """
    # Validate before entering the try block. Otherwise the generic handler
    # below would catch these 400s and re-raise them as 500s.
    if not request.messages:
        raise HTTPException(status_code=400, detail="At least one message is required")
    last_message = request.messages[-1]
    if last_message.role != "user":
        raise HTTPException(status_code=400, detail="Last message must be from user")

    try:
        # NOTE(review): process_message is called synchronously inside an async
        # handler. If it performs blocking network I/O it stalls the event loop;
        # consider running it in a threadpool. Confirm whether it is sync or async.
        response = agent.process_message(last_message.content)
        return JSONResponse(content={
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": response,
                }
            }]
        })
    except Exception as e:
        # NOTE(review): str(e) exposes internal error text to clients.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """Save an uploaded file into the ``data/uploads`` directory.

    Only the final path component of the client-supplied filename is used, so
    names such as ``../../etc/passwd`` or ``/etc/passwd`` cannot escape the
    upload directory. An existing file with the same name is overwritten.

    Returns:
        JSONResponse with a success message and the name the file was stored under.

    Raises:
        HTTPException: 400 if the filename is missing or unusable; 500 if the
            file cannot be read or written.
    """
    # Relative to the process working directory.
    upload_dir = "data/uploads"

    # The filename is client-controlled. Normalise Windows separators, then keep
    # only the last component so os.path.join cannot be steered elsewhere.
    safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        # Raised outside the try so it is not re-wrapped as a 500.
        raise HTTPException(status_code=400, detail="Invalid filename")

    try:
        os.makedirs(upload_dir, exist_ok=True)
        file_path = os.path.join(upload_dir, safe_name)
        # Read before opening, so a failed read does not leave an empty file behind.
        content = await file.read()
        with open(file_path, "wb") as buffer:
            buffer.write(content)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return JSONResponse(content={
        "message": "File uploaded successfully",
        "filename": safe_name,
    })
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this module is executed directly.
    import uvicorn

    # Listen on every interface, port 8000.
    host, port = "0.0.0.0", 8000
    uvicorn.run(app, host=host, port=port)