Spaces:
Paused
Paused
import os
import time
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional, Any

import jwt
import uvicorn
from fastapi import FastAPI, HTTPException, Depends, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel

# Import our existing modules
from enhanced_rag_system import EnhancedRAGSystem
from auth_system import AuthSystem
app = FastAPI(
    title="FinSolve RAG API",
    description="Role-Based Access Control RAG System API",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# CORS middleware for web app integration.
# Allowed origins are read from the comma-separated CORS_ALLOW_ORIGINS env var,
# e.g. "https://my-app.streamlit.app,http://localhost:8501". When it is unset
# or blank, the previous wildcard default ("*") is kept. In production, set it
# to your Streamlit URL(s): a wildcard combined with allow_credentials=True is
# not safe to expose publicly.
CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# JWT configuration.
# WARNING: the fallback secret is visible in source, so anyone can mint valid
# HS256 tokens unless JWT_SECRET_KEY is set in the environment.
JWT_SECRET_KEY = os.getenv("JWT_SECRET_KEY", "your-secret-key-change-in-production")
JWT_ALGORITHM = "HS256"
JWT_EXPIRATION_DELTA = timedelta(hours=24)  # lifetime of tokens from create_access_token

# Security: HTTP Bearer scheme; verify_token receives its extracted credentials.
security = HTTPBearer()

# Global instances, assigned by startup_event(); None until it has run.
rag_system = None  # EnhancedRAGSystem
auth_system = None  # AuthSystem
async def startup_event():
    """Build the shared service objects every endpoint relies on.

    Assigns the module-level ``auth_system`` (an ``AuthSystem``) and then
    ``rag_system`` (an ``EnhancedRAGSystem`` on which ``initialize_system()``
    is called), printing a progress line after each step.

    NOTE(review): no ``@app.on_event("startup")`` (or lifespan) registration
    is visible for this function; unless it is wired up elsewhere it never
    runs and both globals stay ``None``.
    NOTE(review): the "π"/"β" message prefixes look like mis-decoded emoji.
    """
    global rag_system, auth_system

    print("π Starting FastAPI RAG System...")

    auth_system = AuthSystem()
    print("β Authentication system initialized")

    engine = EnhancedRAGSystem()
    # Published to the global before initialization, exactly as the
    # endpoints' ``if not rag_system`` checks have always observed it.
    rag_system = engine
    engine.initialize_system()
    print("β RAG system initialized")
# Pydantic models

class LoginRequest(BaseModel):
    """Request body for the login endpoint: plain username/password credentials."""

    username: str  # checked by auth_system.authenticate
    password: str  # checked by auth_system.authenticate
class QueryRequest(BaseModel):
    """Request body for process_query."""

    query: str  # free-text question handed to rag_system.query
class FeedbackRequest(BaseModel):
    """Request body for submit_feedback: a rated query/response pair."""

    query: str  # query text the rating refers to
    response: str  # response text the rating refers to
    rating: int  # expected 1-5; enforced by submit_feedback (HTTP 400), not by this model
class TokenResponse(BaseModel):
    """Successful login response: the signed JWT plus the user's profile."""

    access_token: str  # JWT produced by create_access_token
    token_type: str  # login always sends "bearer"
    user_info: Dict  # dict returned by auth_system.get_user_info
class QueryResponse(BaseModel):
    """Response body returned by process_query."""

    response: str  # answer text from rag_system.query
    sources: List[str]  # sources reported by rag_system.query
    visualization: Optional[str] = None  # optional extra output from rag_system.query
    table: Optional[str] = None  # optional extra output from rag_system.query
    processing_time: float  # seconds measured around rag_system.query
    query_intent: str  # label from rag_system._classify_query_intent
class UserInfo(BaseModel):
    """Profile returned by get_current_user."""

    username: str  # taken from the verified token
    role: str
    full_name: str
    department: str
class HealthResponse(BaseModel):
    """Response body of health_check."""

    status: str  # "healthy" or "unhealthy"
    timestamp: str  # ISO-8601 string from datetime.isoformat()
    components: Dict[str, bool]  # component name -> availability flag
    version: str
# Utility functions

def create_access_token(username: str, role: str) -> str:
    """Create a signed JWT access token.

    The payload carries ``username`` and ``role`` claims plus ``iat``
    (issued-at) and ``exp`` (``iat + JWT_EXPIRATION_DELTA``), signed with
    ``JWT_SECRET_KEY`` using ``JWT_ALGORITHM``.

    Args:
        username: Value of the ``username`` claim.
        role: Value of the ``role`` claim (verify_token rejects tokens without it).

    Returns:
        The encoded token as returned by ``jwt.encode``.
    """
    # Read the clock once so exp - iat is exactly JWT_EXPIRATION_DELTA, and use
    # an aware UTC datetime: datetime.utcnow() is deprecated since Python 3.12.
    issued_at = datetime.now(timezone.utc)
    payload = {
        "username": username,
        "role": role,
        "exp": issued_at + JWT_EXPIRATION_DELTA,
        "iat": issued_at,
    }
    return jwt.encode(payload, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 HTTPException carrying the Bearer challenge header."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> Dict:
    """Decode and validate the bearer JWT and return the caller's identity.

    Used as a dependency (``Depends(verify_token)``) by the protected endpoints.

    Args:
        credentials: Bearer credentials extracted by the ``HTTPBearer`` scheme.

    Returns:
        ``{"username": ..., "role": ...}`` taken from the token claims.

    Raises:
        HTTPException: 401 "Token expired" when ``exp`` has passed; 401
            "Invalid token" for any other decode/signature failure or when
            either claim is missing.
    """
    try:
        payload = jwt.decode(
            credentials.credentials, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM]
        )
    except jwt.ExpiredSignatureError as exc:
        raise _unauthorized("Token expired") from exc
    # PyJWT (``import jwt``) has no ``JWTError`` -- that name belongs to
    # python-jose. Referencing it raised AttributeError for every malformed or
    # forged token, turning an expected 401 into a 500. InvalidTokenError is
    # PyJWT's base class for decode failures (expiry is handled above).
    except jwt.InvalidTokenError as exc:
        raise _unauthorized("Invalid token") from exc

    username = payload.get("username")
    role = payload.get("role")
    if username is None or role is None:
        raise _unauthorized("Invalid token")
    return {"username": username, "role": role}
# API Endpoints

# NOTE(review): no @app route decorator precedes this function, so as written
# it is not registered on ``app``; confirm the decorator line was not lost.
async def root():
    """Describe the service: name, version, status and where the docs live."""
    info = dict(
        message="FinSolve RAG API",
        version="1.0.0",
        status="active",
        docs="/docs",
    )
    return info
# NOTE(review): no @app route decorator precedes this function, so as written
# it is not registered on ``app``; confirm the decorator line was not lost.
async def health_check():
    """Report overall health and per-component availability.

    The overall status is ``"healthy"`` only when the RAG system's status dict
    reports ``system_initialized``. Keys missing from that dict, or a missing
    RAG system, count as ``False``.

    Returns:
        HealthResponse stamped with ``datetime.now().isoformat()`` (local,
        naive time) and version ``"1.0.0"``.
    """
    snapshot = rag_system.get_system_status() if rag_system else {}
    initialized = snapshot.get("system_initialized", False)

    return HealthResponse(
        status="healthy" if initialized else "unhealthy",
        timestamp=datetime.now().isoformat(),
        components={
            "rag_system": initialized,
            "auth_system": auth_system is not None,
            "vector_store": snapshot.get("vector_store_available", False),
            "llm": snapshot.get("llm_available", False),
        },
        version="1.0.0",
    )
# Authentication endpoints

# NOTE(review): no @app route decorator precedes this function, so as written
# it is not registered on ``app``; confirm the decorator line was not lost.
async def login(login_data: LoginRequest):
    """Authenticate a user and issue a bearer access token.

    Args:
        login_data: Username/password credentials.

    Returns:
        TokenResponse with the JWT from ``create_access_token``, token type
        ``"bearer"`` and the dict from ``auth_system.get_user_info``.

    Raises:
        HTTPException: 503 if the auth system has not been initialized;
            401 if ``auth_system.authenticate`` rejects the credentials.
    """
    # Same "not initialized" guard the RAG endpoints use. Without it, a request
    # arriving before startup dies with AttributeError on None (an opaque 500).
    if auth_system is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Authentication system not initialized",
        )

    if not auth_system.authenticate(login_data.username, login_data.password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid credentials",
        )

    user_info = auth_system.get_user_info(login_data.username)
    access_token = create_access_token(login_data.username, user_info["role"])
    return TokenResponse(
        access_token=access_token,
        token_type="bearer",
        user_info=user_info,
    )
# NOTE(review): no @app route decorator precedes this function, so as written
# it is not registered on ``app``; confirm the decorator line was not lost.
async def get_current_user(current_user: Dict = Depends(verify_token)):
    """Return the profile of the user identified by the bearer token.

    Args:
        current_user: Identity dict produced by ``verify_token``.

    Returns:
        UserInfo built from ``auth_system.get_user_info``, with ``username``
        taken from the token.

    Raises:
        HTTPException: 503 if the auth system is not initialized; 404 if
            ``get_user_info`` returns nothing for the token's username.
    """
    if auth_system is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Authentication system not initialized",
        )

    username = current_user["username"]
    user_info = auth_system.get_user_info(username)
    if not user_info:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found",
        )

    # Merge rather than ``UserInfo(**user_info, username=...)``: that form
    # raises TypeError ("got multiple values for keyword argument 'username'")
    # whenever user_info already contains a "username" key. The token wins.
    return UserInfo(**{**user_info, "username": username})
# Chat endpoints

# NOTE(review): no @app route decorator precedes this function, so as written
# it is not registered on ``app``; confirm the decorator line was not lost.
async def process_query(query_data: QueryRequest, current_user: Dict = Depends(verify_token)):
    """Answer a chat query with the RAG system, scoped to the caller's role.

    Args:
        query_data: Request body holding the query text.
        current_user: Identity from ``verify_token``; its ``role`` is passed
            to ``rag_system.query``.

    Returns:
        QueryResponse with the answer, sources, optional visualization/table,
        the seconds spent in ``rag_system.query`` and the classified intent.

    Raises:
        HTTPException: 503 if the RAG system is not initialized; 500 wrapping
            any error raised while querying or building the response.
    """
    if not rag_system:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="RAG system not initialized",
        )

    # perf_counter is monotonic. Unlike datetime.now(), it cannot jump (NTP
    # sync, DST change) mid-request and yield a wrong or negative duration.
    start_time = time.perf_counter()
    try:
        response, sources, visualization, table = rag_system.query(
            query_data.query,
            current_user["role"],
        )
        # Measured before intent classification: times rag_system.query only.
        processing_time = time.perf_counter() - start_time
        query_intent = rag_system._classify_query_intent(query_data.query)

        return QueryResponse(
            response=response,
            sources=sources,
            visualization=visualization,
            table=table,
            processing_time=processing_time,
            query_intent=query_intent,
        )
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error processing query: {str(e)}",
        ) from e
# NOTE(review): no @app route decorator precedes this function, so as written
# it is not registered on ``app``; confirm the decorator line was not lost.
async def submit_feedback(feedback_data: FeedbackRequest, current_user: Dict = Depends(verify_token)):
    """Store a 1-5 rating for a query/response pair, tagged with the caller's role.

    Args:
        feedback_data: The rated query/response pair.
        current_user: Identity from ``verify_token``; its ``role`` is stored.

    Returns:
        A confirmation message dict.

    Raises:
        HTTPException: 503 if the RAG system is not initialized; 400 if the
            rating is outside 1..5; 500 if ``store_feedback`` raises.
    """
    if not rag_system:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="RAG system not initialized",
        )

    rating = feedback_data.rating
    if rating < 1 or rating > 5:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Rating must be between 1 and 5",
        )

    try:
        rag_system.store_feedback(
            feedback_data.query,
            feedback_data.response,
            rating,
            current_user["role"],
        )
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error storing feedback: {str(e)}",
        )
    return {"message": "Feedback submitted successfully"}
# Document endpoints

# NOTE(review): no @app route decorator precedes this function, so as written
# it is not registered on ``app``; confirm the decorator line was not lost.
async def get_accessible_documents(current_user: Dict = Depends(verify_token)):
    """List the documents the caller's role may access.

    Args:
        current_user: Identity from ``verify_token``.

    Returns:
        Dict with the role, ``auth_system.get_accessible_documents(role)`` and
        ``rag_system.get_available_documents_for_role(role)``.

    Raises:
        HTTPException: 503 if the RAG system is not initialized; 500 if either
            lookup raises (including when ``auth_system`` is still ``None``).
    """
    if not rag_system:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="RAG system not initialized",
        )

    try:
        role = current_user["role"]
        permitted = auth_system.get_accessible_documents(role)
        details = rag_system.get_available_documents_for_role(role)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error retrieving documents: {str(e)}",
        )

    return {
        "role": role,
        "accessible_documents": permitted,
        "document_details": details,
    }
# System endpoints

# NOTE(review): no @app route decorator precedes this function, so as written
# it is not registered on ``app``; confirm the decorator line was not lost.
async def get_system_status(current_user: Dict = Depends(verify_token)):
    """Return the RAG system status plus API metrics (C-Level/Engineering only).

    Args:
        current_user: Identity from ``verify_token``.

    Returns:
        A copy of ``rag_system.get_system_status()`` with an added
        ``"api_metrics"`` entry.

    Raises:
        HTTPException: 403 for other roles; 503 if the RAG system is not
            initialized; 500 if building the status fails.
    """
    if current_user["role"] not in ["C-Level", "Engineering"]:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Insufficient privileges",
        )

    if not rag_system:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="RAG system not initialized",
        )

    try:
        # Named system_status, not status: assigning to ``status`` anywhere in
        # this function made it a local for the whole body, so every
        # ``status.HTTP_*`` reference above raised UnboundLocalError.
        # Copied so api_metrics is not written into the object the RAG system
        # returned (it may be internal state -- unverified).
        system_status = dict(rag_system.get_system_status())

        # Add API-specific metrics
        system_status["api_metrics"] = {
            "fastapi_version": "Real FastAPI Instance",
            "endpoints_available": len(app.routes),
            # user_middleware is the list of registered middleware.
            # middleware_stack is the built ASGI app (None before the first
            # request) and has no len().
            "middleware_count": len(app.user_middleware),
            "startup_complete": True,
        }
        return system_status
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error retrieving system status: {str(e)}",
        ) from e
# NOTE(review): no @app route decorator precedes this function, so as written
# it is not registered on ``app``; confirm the decorator line was not lost.
async def get_system_metrics(current_user: Dict = Depends(verify_token)):
    """Return API metadata, component presence and a route listing (admin only).

    Args:
        current_user: Identity from ``verify_token``.

    Returns:
        Dict with ``api_info``, ``system_health`` and ``endpoints`` sections.

    Raises:
        HTTPException: 403 unless the role is C-Level or Engineering.
    """
    if current_user["role"] not in ["C-Level", "Engineering"]:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Insufficient privileges",
        )

    return {
        "api_info": {
            "title": app.title,
            "version": app.version,
            "description": app.description,
            "total_routes": len(app.routes),
            # Key name kept for compatibility. The count comes from
            # user_middleware because middleware_stack is the built ASGI app
            # (or None) and len() on it raised TypeError.
            "middleware_stack": len(app.user_middleware),
        },
        "system_health": {
            "rag_system": rag_system is not None,
            "auth_system": auth_system is not None,
            "timestamp": datetime.now().isoformat(),
        },
        "endpoints": [
            {
                "path": route.path,
                # Some routes have no ``methods`` attribute, or have it set to
                # None; both fall back to ["GET"]. Sorted so the output is
                # deterministic (it is a set).
                "methods": sorted(getattr(route, "methods", None) or ["GET"]),
                "name": getattr(route, "name", None),
            }
            for route in app.routes
            if hasattr(route, "path")
        ],
    }
# Error handlers

# NOTE(review): no @app.exception_handler(...) decorator precedes this
# function, so as written it is never registered; confirm it was not lost.
async def http_exception_handler(request, exc):
    """Render an HTTPException as a JSON error envelope.

    Args:
        request: The incoming request (unused).
        exc: The raised HTTPException.

    Returns:
        JSONResponse with ``error``, ``status_code`` and an ISO ``timestamp``,
        preserving any headers attached to the exception.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": exc.detail,
            "status_code": exc.status_code,
            "timestamp": datetime.now().isoformat(),
        },
        # Forward exception headers, e.g. the ``WWW-Authenticate: Bearer``
        # that verify_token attaches to its 401s. Previously they were dropped.
        headers=getattr(exc, "headers", None),
    )
# NOTE(review): no @app.exception_handler(...) decorator precedes this
# function, so as written it is never registered; confirm it was not lost.
async def general_exception_handler(request, exc):
    """Fallback handler: turn any unhandled exception into a JSON 500.

    Args:
        request: The incoming request (unused).
        exc: The unhandled exception.

    Returns:
        JSONResponse (status 500) with a generic ``error``, ``str(exc)`` as
        ``detail`` and an ISO ``timestamp``.

    NOTE(review): echoing ``str(exc)`` to clients can leak internal details;
    consider logging it server-side and omitting it from the response.
    """
    payload = {
        "error": "Internal server error",
        "detail": str(exc),
        "timestamp": datetime.now().isoformat(),
    }
    return JSONResponse(status_code=500, content=payload)
if __name__ == "__main__":
    # Development entry point (auto-reload is on). In production, run the ASGI
    # server directly without reload.
    # HOST / PORT env vars override the defaults (0.0.0.0:8000), so the same
    # entry point works on hosts that assign the port.
    # NOTE(review): the "api:app" import string assumes this file is api.py.
    uvicorn.run(
        "api:app",
        host=os.getenv("HOST", "0.0.0.0"),
        port=int(os.getenv("PORT", "8000")),
        reload=True,
        log_level="info",
    )