Spaces:
Sleeping
Sleeping
| """FastAPI web application for the Agent Framework.""" | |
| import os | |
| import sys | |
| import uuid | |
| import shutil | |
| from pathlib import Path | |
| from typing import Optional, List | |
| from datetime import datetime | |
| # Add parent directory to path | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import HTMLResponse, FileResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from dotenv import load_dotenv | |
| from agent_framework import ( | |
| Agent, LlmClient, InMemorySessionManager, | |
| display_trace, ExecutionContext, format_trace | |
| ) | |
| from agent_tools import calculator, search_web, read_file, list_files, unzip_file, read_media_file | |
# Load environment variables (python-dotenv) at import time
load_dotenv()

app = FastAPI(title="Agent Chat", description="AI Agent with Tools")

# Enable CORS for any origin.
# Credentials are deliberately disabled: a wildcard origin must not be
# combined with credentialed requests, and this API carries the conversation
# identity as `session_id` in the request body rather than in cookies.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global session manager (shared across requests)
session_manager = InMemorySessionManager()

# Upload directory for files: an "uploads" folder next to this module,
# created at import time if it does not exist yet
UPLOAD_DIR = Path(__file__).parent.joinpath("uploads")
UPLOAD_DIR.mkdir(exist_ok=True)

# Available tools (handed to every agent and listed by get_tools)
TOOLS = [
    calculator,
    search_web,
    read_file,
    list_files,
    unzip_file,
    read_media_file,
]
| # Create agent | |
def create_agent(use_session: bool = True) -> Agent:
    """Build a new agent configured with the shared tool list.

    The system instructions embed the absolute path of the upload directory
    so the model knows where user-uploaded files are stored. When
    ``use_session`` is false the agent is created without a session manager.
    """
    # Include the actual upload directory path in instructions
    upload_path = str(UPLOAD_DIR.absolute())
    instructions = f"""You are a helpful AI assistant with access to various tools.
You can:
- Perform calculations using the calculator
- Search the web for current information
- Read excel files using the read_file tool
- List files in directories using the list_files tool
- Extract zip files using the unzip_file tool
- Read pdf using read_media_file
IMPORTANT - Uploaded files location:
Files uploaded by users are stored at: {upload_path}
To see uploaded files, use: list_files("{upload_path}")
To read a file, use: read_file("{upload_path}/filename.ext")
Always be helpful and use your tools when needed to provide accurate answers."""

    llm = LlmClient(model="gpt-4o-mini")

    # Only attach the shared session manager when history should be kept.
    if use_session:
        manager = session_manager
    else:
        manager = None

    return Agent(
        model=llm,
        tools=TOOLS,
        instructions=instructions,
        max_steps=10,
        session_manager=manager,
    )
| # Pydantic models for API | |
class ChatRequest(BaseModel):
    # Request body consumed by the chat() handler.
    #   message     - user text passed to agent.run().
    #   session_id  - conversation to continue; chat() generates a uuid4
    #                 string when this is missing or empty.
    #   use_session - when False, the agent is created without the session
    #                 manager and run without a session_id.
    message: str
    session_id: Optional[str] = None
    use_session: bool = True
class ChatResponse(BaseModel):
    # Response body produced by the chat() handler.
    #   response     - str() of the agent's output, or a fixed fallback
    #                  sentence when the output is empty.
    #   session_id   - the id supplied by the client or generated by chat().
    #   events_count - number of events in the run's execution context.
    #   tools_used   - distinct names of the tools called during the run.
    #   trace_text   - output of format_trace() for the run's context.
    response: str
    session_id: str
    events_count: int
    tools_used: List[str]
    trace_text: str = ""  # Simple text-based trace like display_trace
class ToolInfo(BaseModel):
    # One entry of the tool listing returned by get_tools().
    #   name        - the tool's name attribute.
    #   description - the tool's description; get_tools() cuts it to the
    #                 first 100 characters plus "..." when it is longer.
    name: str
    description: str
class SessionInfo(BaseModel):
    # One entry of the session listing returned by list_sessions().
    #   session_id   - key of the session in the session manager.
    #   events_count - number of events stored on the session.
    #   created_at   - the session's creation time, ISO-8601 formatted
    #                  via datetime.isoformat().
    session_id: str
    events_count: int
    created_at: str
| # API Endpoints | |
async def root():
    """Serve the chat interface (static/index.html next to this module)."""
    index_page = Path(__file__).parent.joinpath("static", "index.html")
    return FileResponse(index_page)
async def get_tools() -> List[ToolInfo]:
    """Return the name and a short description of every available tool."""
    summaries = []
    for tool in TOOLS:
        blurb = tool.description
        if len(blurb) > 100:
            # Keep the listing compact: first 100 characters plus an ellipsis.
            blurb = blurb[:100] + "..."
        summaries.append(ToolInfo(name=tool.name, description=blurb))
    return summaries
async def chat(request: ChatRequest) -> ChatResponse:
    """Send a message to the agent.

    Runs a newly created agent on ``request.message`` and returns its reply
    together with the session id in effect, the number of recorded events,
    the distinct tools the agent called (in first-use order) and a formatted
    execution trace. Any error raised while running the agent or building
    the response is reported as HTTP 500 with the error text as detail.
    """
    # Generate or use provided session ID
    session_id = request.session_id or str(uuid.uuid4())

    agent = create_agent(use_session=request.use_session)

    try:
        # Only pass a session id when conversation history is wanted.
        if request.use_session:
            result = await agent.run(request.message, session_id=session_id)
        else:
            result = await agent.run(request.message)

        # Collect tool names; dict keys de-duplicate while keeping the order
        # in which each tool was first called.
        seen_tools = {}
        for event in result.context.events:
            for item in event.content:
                # getattr with a default: content items that carry no `type`
                # attribute are skipped instead of raising AttributeError.
                if hasattr(item, "name") and getattr(item, "type", None) == "tool_call":
                    seen_tools.setdefault(item.name, None)

        trace_text = format_trace(result.context)

        return ChatResponse(
            response=str(result.output) if result.output else "I couldn't generate a response.",
            session_id=session_id,
            events_count=len(result.context.events),
            tools_used=list(seen_tools),
            trace_text=trace_text,
        )
    except Exception as e:
        # Chain the cause so server logs keep the original traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e
async def upload_file(file: UploadFile = File(...)):
    """Upload a file for the agent to access.

    The file is stored directly inside the upload directory under the final
    path component of the client-supplied name. Responds with HTTP 400 when
    no usable file name is supplied and HTTP 500 when the write fails.
    """
    # The file name comes from the client and must not be trusted: keep only
    # the last path component (treating "\" as a separator too) so a name
    # such as "../../app.py" cannot write outside UPLOAD_DIR.
    raw_name = (file.filename or "").replace("\\", "/")
    safe_name = Path(raw_name).name
    if not safe_name or safe_name in {".", ".."}:
        raise HTTPException(status_code=400, detail="Invalid filename")

    # Save file to uploads directory
    file_path = UPLOAD_DIR / safe_name
    try:
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        return {
            "filename": safe_name,
            "path": str(file_path),
            "size": file_path.stat().st_size,
            "message": f"File uploaded successfully. You can reference it at: {file_path}"
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def list_uploads():
    """List uploaded files.

    Returns one dict (name, path, size in bytes) per regular file in the
    upload directory; entries whose name starts with a dot are skipped.
    """
    return [
        {
            "name": entry.name,
            "path": str(entry),
            "size": entry.stat().st_size,
        }
        for entry in UPLOAD_DIR.iterdir()
        if entry.is_file() and not entry.name.startswith('.')
    ]
async def delete_upload(filename: str):
    """Delete an uploaded file.

    ``filename`` must be a bare file name inside the upload directory.
    Responds with HTTP 404 when it is not, or when no such regular file
    exists.
    """
    # Reject anything that is not a plain file name ("..", "a/b", "") so the
    # delete can never reach outside UPLOAD_DIR.
    is_bare_name = Path(filename).name == filename and filename not in {"", ".", ".."}
    if not is_bare_name:
        raise HTTPException(status_code=404, detail="File not found")

    file_path = UPLOAD_DIR / filename
    # is_file() rather than exists(): unlink() on a directory would raise.
    if not file_path.is_file():
        raise HTTPException(status_code=404, detail="File not found")

    file_path.unlink()
    return {"message": f"Deleted {filename}"}
async def list_sessions() -> List[SessionInfo]:
    """List all active sessions.

    Reads the session manager's internal ``_sessions`` mapping and reports
    each session's id, event count and creation time (ISO-8601).
    """
    return [
        SessionInfo(
            session_id=sid,
            events_count=len(session.events),
            created_at=session.created_at.isoformat(),
        )
        for sid, session in session_manager._sessions.items()
    ]
async def delete_session(session_id: str):
    """Delete a session to clear conversation history.

    Responds with HTTP 404 when the session manager holds no such session.
    """
    active_sessions = session_manager._sessions
    if session_id not in active_sessions:
        raise HTTPException(status_code=404, detail="Session not found")

    del active_sessions[session_id]
    return {"message": f"Session {session_id} cleared"}
# Mount static files: the "static" folder next to this module is created if
# missing and then served under /static
static_dir = Path(__file__).parent.joinpath("static")
static_dir.mkdir(exist_ok=True)
app.mount("/static", StaticFiles(directory=static_dir), name="static")
if __name__ == "__main__":
    import uvicorn

    # Listen on the port given by the PORT environment variable, falling
    # back to 7860, on all network interfaces.
    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)