Spaces:
Sleeping
Sleeping
| import argparse | |
| import asyncio | |
| import concurrent.futures | |
| import time | |
| from typing import Annotated | |
| import structlog | |
| from fastapi import Depends, FastAPI, HTTPException, Response, status | |
| from fastapi.encoders import jsonable_encoder | |
| from fastapi.exceptions import RequestValidationError | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from fastapi.security import ( | |
| HTTPAuthorizationCredentials, | |
| HTTPBasic, | |
| HTTPBasicCredentials, | |
| HTTPBearer, | |
| ) | |
| from opentelemetry import metrics | |
| from prometheus_client import CONTENT_TYPE_LATEST, REGISTRY, generate_latest | |
| from slowapi import Limiter, _rate_limit_exceeded_handler | |
| from slowapi.errors import RateLimitExceeded | |
| from slowapi.middleware import SlowAPIMiddleware | |
| from slowapi.util import get_remote_address | |
| from starlette.exceptions import HTTPException as StarletteHTTPException | |
| from llm_guard import scan_output, scan_prompt | |
| from llm_guard.vault import Vault | |
| from .cache import InMemoryCache | |
| from .config import AuthConfig, get_config | |
| from .otel import configure_otel, instrument_app | |
| from .scanner import get_input_scanners, get_output_scanners | |
| from .schemas import ( | |
| AnalyzeOutputRequest, | |
| AnalyzeOutputResponse, | |
| AnalyzePromptRequest, | |
| AnalyzePromptResponse, | |
| ) | |
| from .util import configure_logger | |
| from .version import __version__ | |
# One vault instance, passed to both the input- and output-scanner factories below.
vault = Vault()

# Command line: a single required positional argument, the configuration file path.
# Parsing happens at import time, so the module can only be loaded with that argument present.
parser = argparse.ArgumentParser(description="LLM Guard API")
parser.add_argument("config", type=str, help="Path to the configuration file")
args = parser.parse_args()
scanners_config_file = args.config

config = get_config(scanners_config_file)

# Logging: the configured level also decides whether the app runs in debug mode.
LOGGER = structlog.getLogger(__name__)
log_level = config.app.log_level
is_debug = log_level == "DEBUG"
configure_logger(log_level)

# Telemetry is configured before the scanners are built and before the meter is requested.
configure_otel(config.app.name, config.tracing, config.metrics)

input_scanners = get_input_scanners(config.input_scanners, vault)
output_scanners = get_output_scanners(config.output_scanners, vault)

# Counter incremented once per scanner verdict by the analyze endpoints.
meter = metrics.get_meter_provider().get_meter(__name__)
scanners_valid_counter = meter.create_counter(
    name="scanners.valid",
    unit="1",
    description="measures the number of valid scanners",
)
def create_app() -> FastAPI:
    """Build the FastAPI application and register every route on it.

    Reads the module-level ``config``, ``is_debug``, ``input_scanners`` and
    ``output_scanners``; a fresh in-memory cache sized from ``config.cache``
    is created for each application.

    Returns:
        The configured FastAPI application.
    """
    prompt_cache = InMemoryCache(
        max_size=config.cache.max_size,
        expiration_time=config.cache.ttl,
    )

    if config.app.scan_fail_fast:
        LOGGER.debug("Scan fail_fast mode is enabled")

    # The OpenAPI schema is only exposed in debug mode (hide docs in production).
    openapi_url = "/openapi.json" if is_debug else None

    application = FastAPI(
        title=config.app.name,
        description="API to run LLM Guard scanners.",
        debug=is_debug,
        version=__version__,
        openapi_url=openapi_url,
    )

    register_routes(application, prompt_cache, input_scanners, output_scanners)

    return application
def _check_auth_function(auth_config: AuthConfig) -> callable:
    """Build the FastAPI dependency that authenticates incoming requests.

    Args:
        auth_config: Authentication settings. A falsy value disables
            authentication entirely.

    Returns:
        An async callable suitable for ``Depends``. It returns ``True`` when
        the request is authorised and raises ``HTTPException`` (401)
        otherwise. With authentication disabled it always returns ``True``.

    Raises:
        ValueError: If ``auth_config.type`` is neither ``"http_bearer"`` nor
            ``"http_basic"``.
    """
    # Function-scope import keeps this block self-contained.
    import secrets

    def _matches(supplied, expected) -> bool:
        """Compare two secrets without leaking timing information."""
        # A missing / non-string configured value can never match, which keeps
        # the "always 401" outcome a plain ``!=`` comparison would give.
        if not isinstance(supplied, str) or not isinstance(expected, str):
            return False
        # compare_digest only accepts non-ASCII text as bytes, so encode first.
        return secrets.compare_digest(supplied.encode("utf-8"), expected.encode("utf-8"))

    async def check_auth_noop() -> bool:
        return True

    if not auth_config:
        return check_auth_noop

    if auth_config.type == "http_bearer":
        credentials_type = Annotated[HTTPAuthorizationCredentials, Depends(HTTPBearer())]
    elif auth_config.type == "http_basic":
        credentials_type = Annotated[HTTPBasicCredentials, Depends(HTTPBasic())]
    else:
        raise ValueError(f"Invalid auth type: {auth_config.type}")

    async def check_auth(credentials: credentials_type) -> bool:
        if auth_config.type == "http_bearer":
            if not _matches(credentials.credentials, auth_config.token):
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
                )
        elif auth_config.type == "http_basic":
            # Evaluate both fields unconditionally so the response time does not
            # reveal whether the username alone was correct.
            username_ok = _matches(credentials.username, auth_config.username)
            password_ok = _matches(credentials.password, auth_config.password)
            if not (username_ok and password_ok):
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Username or Password"
                )

        return True

    return check_auth
def register_routes(
    app: FastAPI, cache: InMemoryCache, input_scanners: list, output_scanners: list
):
    """Attach middleware, endpoints and exception handlers to ``app``.

    Args:
        app: Application to configure.
        cache: Cache of prompt-analysis results, keyed by the raw prompt.
        input_scanners: Scanners passed to ``scan_prompt``.
        output_scanners: Scanners passed to ``scan_output``.

    NOTE(review): the reviewed source contained the handler functions without
    any registration decorators, so none of them was reachable. The paths and
    tags used below are assumed -- confirm them against existing API clients.
    """
    # NOTE(review): wildcard origins together with allow_credentials=True is a
    # very permissive CORS policy -- confirm this is intended.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["Authorization", "Content-Type"],
    )

    limiter = Limiter(key_func=get_remote_address, default_limits=[config.rate_limit.limit])
    app.state.limiter = limiter
    app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
    if config.rate_limit.enabled:
        app.add_middleware(SlowAPIMiddleware)

    check_auth = _check_auth_function(config.auth)

    # One worker pool for the lifetime of the app. A per-request
    # ``with ThreadPoolExecutor()`` joins its worker on exit, which blocks the
    # event loop until a timed-out scan finishes and so defeats the timeout.
    scan_executor = concurrent.futures.ThreadPoolExecutor(thread_name_prefix="scan")

    async def run_scan(scan_fn, timeout, *scan_args):
        """Run a blocking scan in the worker pool; raise 408 if it exceeds ``timeout``."""
        loop = asyncio.get_running_loop()
        pending = loop.run_in_executor(scan_executor, scan_fn, *scan_args)
        try:
            return await asyncio.wait_for(pending, timeout=timeout)
        except asyncio.TimeoutError:
            # The worker thread cannot be interrupted and finishes in the
            # background; only the request is abandoned.
            raise HTTPException(
                status_code=status.HTTP_408_REQUEST_TIMEOUT, detail="Request timeout."
            ) from None

    def record_verdicts(source: str, results_valid: dict) -> None:
        """Increment the validity counter once per scanner verdict."""
        for scanner, valid in results_valid.items():
            scanners_valid_counter.add(1, {"source": source, "valid": valid, "scanner": scanner})

    @app.get("/", tags=["Main"])
    async def read_root():
        return {"name": "LLM Guard API"}

    @app.get("/healthz", tags=["Health"])
    async def healthcheck():
        return JSONResponse({"status": "alive"})

    @app.get("/readyz", tags=["Health"])
    async def liveliness():
        return JSONResponse({"status": "ready"})

    @app.post("/analyze/output", tags=["Analyze"])
    async def analyze_output(
        request: AnalyzeOutputRequest, _: Annotated[bool, Depends(check_auth)]
    ) -> AnalyzeOutputResponse:
        LOGGER.debug("Received analyze output request", request=request)

        start_time = time.time()
        sanitized_output, results_valid, results_score = await run_scan(
            scan_output,
            config.app.scan_output_timeout,
            output_scanners,
            request.prompt,
            request.output,
            config.app.scan_fail_fast,
        )
        record_verdicts("output", results_valid)

        result = AnalyzeOutputResponse(
            sanitized_output=sanitized_output,
            is_valid=all(results_valid.values()),
            scanners=results_score,
        )

        elapsed_time = time.time() - start_time
        LOGGER.debug(
            "Sanitized response",
            scores=results_score,
            elapsed_time_seconds=round(elapsed_time, 6),
        )
        return result

    @app.post("/analyze/prompt", tags=["Analyze"])
    async def analyze_prompt(
        request: AnalyzePromptRequest,
        _: Annotated[bool, Depends(check_auth)],
        response: Response,
    ) -> AnalyzePromptResponse:
        LOGGER.debug("Received analyze prompt request", request=request)

        cached_result = cache.get(request.prompt)
        if cached_result:
            LOGGER.debug("Response was found in cache")
            response.headers["X-Cache-Hit"] = "true"
            return AnalyzePromptResponse(**cached_result)

        response.headers["X-Cache-Hit"] = "false"

        start_time = time.time()
        sanitized_prompt, results_valid, results_score = await run_scan(
            scan_prompt,
            config.app.scan_prompt_timeout,
            input_scanners,
            request.prompt,
            config.app.scan_fail_fast,
        )
        record_verdicts("input", results_valid)

        # Kept separate from the injected ``response`` object, which only
        # carries the cache header.
        result = AnalyzePromptResponse(
            sanitized_prompt=sanitized_prompt,
            is_valid=all(results_valid.values()),
            scanners=results_score,
        )
        cache.set(request.prompt, result.dict())

        elapsed_time = time.time() - start_time
        LOGGER.debug(
            "Sanitized prompt response returned",
            scores=results_score,
            elapsed_time_seconds=round(elapsed_time, 6),
        )
        return result

    if config.metrics and config.metrics.exporter == "prometheus":

        @app.get("/metrics", tags=["Metrics"])
        async def metrics():
            return Response(
                content=generate_latest(REGISTRY), headers={"Content-Type": CONTENT_TYPE_LATEST}
            )

    @app.on_event("shutdown")
    async def shutdown_event():
        LOGGER.info("Shutting down app...")
        # Release the worker pool without waiting for scans still in flight.
        scan_executor.shutdown(wait=False)

    @app.exception_handler(StarletteHTTPException)
    async def http_exception_handler(request, exc):
        LOGGER.warning(
            "HTTP exception", exception_status_code=exc.status_code, exception_detail=exc.detail
        )
        return JSONResponse(
            {"message": str(exc.detail), "details": None}, status_code=exc.status_code
        )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(request, exc):
        LOGGER.warning("Invalid request", exception=str(exc))
        payload = {"message": "Validation failed", "details": exc.errors()}
        return JSONResponse(
            jsonable_encoder(payload), status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
        )
# Module-level application object: built once at import time, then handed to the
# project's instrumentation hook before anything serves it.
app = create_app()
instrument_app(app)
def run_app():
    """Serve the module-level ``app`` with uvicorn on the configured port.

    The server binds to all interfaces and logs at the configured level
    (lower-cased for uvicorn). This call blocks until the server stops.
    """
    import uvicorn

    server_options = {
        "host": "0.0.0.0",
        "port": config.app.port,
        "server_header": False,
        "log_level": log_level.lower(),
        # NOTE(review): forwarded headers are accepted from any address --
        # presumably the service always sits behind a proxy; confirm.
        "proxy_headers": True,
        "forwarded_allow_ips": "*",
        "timeout_keep_alive": 2,
    }
    uvicorn.run(app, **server_options)