"""FastAPI front end for the FinSolve role-based RAG system.

Exposes login, query, feedback, document-listing and system-status
endpoints. Callers authenticate with a JWT bearer token obtained from
``/auth/login``; the role stored in the token is passed to the RAG system
on every request.
"""

from datetime import datetime, timedelta, timezone
import os
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException, Depends, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import jwt
import uvicorn

# Import our existing modules
from enhanced_rag_system import EnhancedRAGSystem
from auth_system import AuthSystem

app = FastAPI(
    title="FinSolve RAG API",
    description="Role-Based Access Control RAG System API",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# CORS middleware for web app integration.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; restrict allow_origins before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify your Streamlit URL
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# JWT configuration
# NOTE(review): the fallback secret is a placeholder; set JWT_SECRET_KEY in the
# environment for any real deployment.
JWT_SECRET_KEY = os.getenv("JWT_SECRET_KEY", "your-secret-key-change-in-production")
JWT_ALGORITHM = "HS256"
JWT_EXPIRATION_DELTA = timedelta(hours=24)

# Security
security = HTTPBearer()

# Global instances, populated by startup_event(); None until startup has run.
rag_system = None
auth_system = None


@app.on_event("startup")
async def startup_event():
    """Initialize systems on startup"""
    global rag_system, auth_system

    print("🚀 Starting FastAPI RAG System...")

    # Initialize authentication system
    auth_system = AuthSystem()
    print("✅ Authentication system initialized")

    # Initialize RAG system
    rag_system = EnhancedRAGSystem()
    rag_system.initialize_system()
    print("✅ RAG system initialized")


# Pydantic models

# Body of POST /auth/login.
class LoginRequest(BaseModel):
    username: str
    password: str


# Body of POST /chat/query.
class QueryRequest(BaseModel):
    query: str


# Body of POST /chat/feedback; the endpoint rejects ratings outside 1..5.
class FeedbackRequest(BaseModel):
    query: str
    response: str
    rating: int


# Response of POST /auth/login.
class TokenResponse(BaseModel):
    access_token: str
    token_type: str
    user_info: Dict


# Response of POST /chat/query.
class QueryResponse(BaseModel):
    response: str
    sources: List[str]
    visualization: Optional[str] = None
    table: Optional[str] = None
    processing_time: float
    query_intent: str


# Response of GET /auth/me.
class UserInfo(BaseModel):
    username: str
    role: str
    full_name: str
    department: str


# Response of GET /health.
class HealthResponse(BaseModel):
    status: str
    timestamp: str
    components: Dict[str, bool]
    version: str


# Utility functions
def create_access_token(username: str, role: str) -> str:
    """Create a signed JWT access token.

    Args:
        username: Stored in the token's ``username`` claim.
        role: Stored in the token's ``role`` claim.

    Returns:
        The token produced by ``jwt.encode`` using ``JWT_SECRET_KEY`` and
        ``JWT_ALGORITHM``. It expires ``JWT_EXPIRATION_DELTA`` after issue.
    """
    # Take one timezone-aware timestamp so "iat" and "exp" are derived from
    # the same instant (datetime.utcnow() is naive and deprecated).
    issued_at = datetime.now(timezone.utc)
    payload = {
        "username": username,
        "role": role,
        "exp": issued_at + JWT_EXPIRATION_DELTA,
        "iat": issued_at,
    }
    return jwt.encode(payload, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)


def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> Dict:
    """Verify the bearer JWT and return the caller's identity.

    Args:
        credentials: Bearer credentials extracted by the ``HTTPBearer``
            dependency.

    Returns:
        ``{"username": ..., "role": ...}`` taken from the token claims.

    Raises:
        HTTPException: 401 with detail "Token expired" when the token is
            expired, or "Invalid token" when decoding fails or either claim
            is missing.
    """
    auth_headers = {"WWW-Authenticate": "Bearer"}

    # Keep the try body to the one call that can raise JWT errors.
    try:
        payload = jwt.decode(
            credentials.credentials, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM]
        )
    except jwt.ExpiredSignatureError as exc:
        # Must be listed first: ExpiredSignatureError subclasses
        # InvalidTokenError.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token expired",
            headers=auth_headers,
        ) from exc
    except jwt.InvalidTokenError as exc:
        # PyJWT has no ``jwt.JWTError``; InvalidTokenError is the base class
        # for decode/validation failures.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
            headers=auth_headers,
        ) from exc

    username = payload.get("username")
    role = payload.get("role")
    if username is None or role is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
            headers=auth_headers,
        )

    return {"username": username, "role": role}


# API Endpoints
@app.get("/", response_model=Dict)
async def root():
    """Root endpoint"""
    return {
        "message": "FinSolve RAG API",
        "version": "1.0.0",
        "status": "active",
        "docs": "/docs",
    }


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint"""
    # An uninitialized RAG system yields an empty status, i.e. "unhealthy".
    system_status = rag_system.get_system_status() if rag_system else {}
    rag_ready = system_status.get("system_initialized", False)

    return HealthResponse(
        status="healthy" if rag_ready else "unhealthy",
        timestamp=datetime.now().isoformat(),
        components={
            "rag_system": rag_ready,
            "auth_system": auth_system is not None,
            "vector_store": system_status.get("vector_store_available", False),
            "llm": system_status.get("llm_available", False),
        },
        version="1.0.0",
    )


# Authentication endpoints
@app.post("/auth/login", response_model=TokenResponse)
async def login(login_data: LoginRequest):
    """Authenticate user and return access token"""
    # Guard against requests arriving before startup_event() has run, in the
    # same way the RAG endpoints guard rag_system.
    if auth_system is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Authentication system not initialized",
        )

    if not auth_system.authenticate(login_data.username, login_data.password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid credentials",
        )

    user_info = auth_system.get_user_info(login_data.username)
    if not user_info:
        # Same response /auth/me gives when no user record is returned.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found",
        )

    access_token = create_access_token(login_data.username, user_info["role"])

    return TokenResponse(
        access_token=access_token,
        token_type="bearer",
        user_info=user_info,
    )


@app.get("/auth/me", response_model=UserInfo)
async def get_current_user(current_user: Dict = Depends(verify_token)):
    """Get current user information"""
    if auth_system is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Authentication system not initialized",
        )

    user_info = auth_system.get_user_info(current_user["username"])
    if not user_info:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found",
        )

    # Merge into one dict so a "username" key already present in user_info
    # cannot trigger a duplicate-keyword TypeError; the token's value wins.
    return UserInfo(**{**user_info, "username": current_user["username"]})


# Chat endpoints
@app.post("/chat/query", response_model=QueryResponse)
async def process_query(query_data: QueryRequest, current_user: Dict = Depends(verify_token)):
    """Process RAG query"""
    if not rag_system:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="RAG system not initialized",
        )

    start_time = datetime.now()

    try:
        response, sources, visualization, table = rag_system.query(
            query_data.query,
            current_user["role"],
        )

        # Timing covers only the query call, not intent classification.
        processing_time = (datetime.now() - start_time).total_seconds()
        query_intent = rag_system._classify_query_intent(query_data.query)

        return QueryResponse(
            response=response,
            sources=sources,
            visualization=visualization,
            table=table,
            processing_time=processing_time,
            query_intent=query_intent,
        )
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error processing query: {str(e)}",
        )


@app.post("/chat/feedback")
async def submit_feedback(feedback_data: FeedbackRequest, current_user: Dict = Depends(verify_token)):
    """Submit feedback for a response"""
    if not rag_system:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="RAG system not initialized",
        )

    if not 1 <= feedback_data.rating <= 5:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Rating must be between 1 and 5",
        )

    try:
        rag_system.store_feedback(
            feedback_data.query,
            feedback_data.response,
            feedback_data.rating,
            current_user["role"],
        )
        return {"message": "Feedback submitted successfully"}
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error storing feedback: {str(e)}",
        )


# Document endpoints
@app.get("/documents/accessible")
async def get_accessible_documents(current_user: Dict = Depends(verify_token)):
    """Get documents accessible to current user"""
    if not rag_system:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="RAG system not initialized",
        )

    try:
        accessible_docs = auth_system.get_accessible_documents(current_user["role"])
        doc_info = rag_system.get_available_documents_for_role(current_user["role"])

        return {
            "role": current_user["role"],
            "accessible_documents": accessible_docs,
            "document_details": doc_info,
        }
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error retrieving documents: {str(e)}",
        )


# System endpoints
@app.get("/system/status")
async def get_system_status(current_user: Dict = Depends(verify_token)):
    """Get system status (admin only)"""
    if current_user["role"] not in ["C-Level", "Engineering"]:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Insufficient privileges",
        )

    if not rag_system:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="RAG system not initialized",
        )

    try:
        # Must not be named ``status``: a local of that name would shadow the
        # imported fastapi ``status`` module for the whole function and make
        # the checks above fail with UnboundLocalError.
        system_status = rag_system.get_system_status()

        # Add API-specific metrics
        system_status["api_metrics"] = {
            "fastapi_version": "Real FastAPI Instance",
            "endpoints_available": len(app.routes),
            # user_middleware is the list of registered middleware;
            # middleware_stack is the built ASGI callable and has no len().
            "middleware_count": len(app.user_middleware),
            "startup_complete": True,
        }

        return system_status
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error retrieving system status: {str(e)}",
        )


@app.get("/system/metrics")
async def get_system_metrics(current_user: Dict = Depends(verify_token)):
    """Get detailed system metrics (admin only)"""
    if current_user["role"] not in ["C-Level", "Engineering"]:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Insufficient privileges",
        )

    endpoints = [
        {
            "path": route.path,
            # Routes without a usable ``methods`` attribute are reported
            # as GET.
            "methods": list(getattr(route, "methods", None) or ["GET"]),
            "name": route.name,
        }
        for route in app.routes
        if hasattr(route, "path")
    ]

    return {
        "api_info": {
            "title": app.title,
            "version": app.version,
            "description": app.description,
            "total_routes": len(app.routes),
            # Count of registered middleware (see get_system_status).
            "middleware_stack": len(app.user_middleware),
        },
        "system_health": {
            "rag_system": rag_system is not None,
            "auth_system": auth_system is not None,
            "timestamp": datetime.now().isoformat(),
        },
        "endpoints": endpoints,
    }


# Error handlers
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Render an HTTPException as a JSON error body.

    The body carries the exception detail, status code and a timestamp.
    Headers attached to the exception (e.g. ``WWW-Authenticate`` on 401
    responses) are forwarded to the client.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": exc.detail,
            "status_code": exc.status_code,
            "timestamp": datetime.now().isoformat(),
        },
        headers=getattr(exc, "headers", None),
    )


@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    """Render any otherwise-unhandled exception as a JSON 500 response."""
    # NOTE(review): "detail" exposes str(exc) to the client; consider logging
    # it server-side instead in production.
    return JSONResponse(
        status_code=500,
        content={
            "error": "Internal server error",
            "detail": str(exc),
            "timestamp": datetime.now().isoformat(),
        },
    )


if __name__ == "__main__":
    # For development - in production, use a proper ASGI server
    uvicorn.run(
        "api:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
        log_level="info",
    )