| import os | |
| import re | |
| from typing import Any, Dict, List, Literal, Optional | |
| import structlog | |
| import yaml | |
| from pydantic import BaseModel, Field | |
# Module-level structured logger, named after this module.
LOGGER = structlog.getLogger(__name__)

# Matches a single `${...}` placeholder. Group 1 captures the text between the
# braces: one or more characters, none of which may be `}`, `^` or `{`.
_var_matcher = re.compile(r"\${([^}^{]+)}")

# Matches a scalar that contains a `${...}` placeholder: a run of non-`$`
# characters, the placeholder itself, then any trailing characters.
# load_yaml() registers this pattern as the implicit resolver for "!envvar".
_tag_matcher = re.compile(r"[^$]*\${([^}^{]+)}.*")
class RateLimitConfig(BaseModel):
    """Settings for request rate limiting."""

    # Off unless the configuration turns it on.
    enabled: bool = False
    # Limit expression; the default is "100/minute".
    limit: str = "100/minute"
class CacheConfig(BaseModel):
    """Settings for the cache: entry time-to-live and optional size bound."""

    # Time-to-live for entries; defaults to 60 (presumably seconds — confirm).
    ttl: int = 60
    # Optional size bound; None when not configured.
    max_size: Optional[int] = None
class AuthConfig(BaseModel):
    """Authentication settings; `type` selects the scheme and is required."""

    # Required: either "http_bearer" or "http_basic".
    type: Literal["http_bearer", "http_basic"]
    # Credentials are all optional at the model level and default to None.
    token: Optional[str] = None
    username: Optional[str] = None
    password: Optional[str] = None
class TracingConfig(BaseModel):
    """Tracing settings: which exporter to use and an optional endpoint."""

    # One of "otel_http" or "console"; "console" when not specified.
    exporter: Literal["otel_http", "console"] = "console"
    # Optional exporter endpoint; None when not configured.
    endpoint: Optional[str] = None
class MetricsConfig(BaseModel):
    """Metrics settings: which exporter to use and an optional endpoint."""

    # One of "otel_http", "prometheus" or "console"; "console" by default.
    exporter: Literal["otel_http", "prometheus", "console"] = "console"
    # Optional exporter endpoint; None when not configured.
    endpoint: Optional[str] = None
class AppConfig(BaseModel):
    """General application settings; every field has a default."""

    name: Optional[str] = "LLM Guard API"
    port: Optional[int] = 7860
    log_level: Optional[str] = "INFO"
    # Scan behaviour knobs. Timeout units are not visible here
    # (presumably seconds — confirm against the code that consumes them).
    scan_fail_fast: Optional[bool] = False
    scan_prompt_timeout: Optional[int] = 10
    scan_output_timeout: Optional[int] = 30
class ScannerConfig(BaseModel):
    """Description of one scanner entry: its type name and its parameters."""

    # Required scanner type identifier.
    type: str
    # Free-form parameter mapping; a fresh empty dict is created per instance
    # when the field is omitted.
    params: Optional[Dict] = Field(default_factory=dict)
class Config(BaseModel):
    """Root configuration model built from the parsed YAML mapping."""

    # Both scanner lists are required (no default is supplied).
    input_scanners: List[ScannerConfig]
    output_scanners: List[ScannerConfig]

    # Sections that fall back to a default-constructed sub-model when omitted.
    rate_limit: RateLimitConfig = Field(default_factory=RateLimitConfig)
    cache: CacheConfig = Field(default_factory=CacheConfig)
    auth: Optional[AuthConfig] = None
    app: AppConfig = Field(default_factory=AppConfig)

    # Optional sections; None when omitted.
    tracing: Optional[TracingConfig] = None
    metrics: Optional[MetricsConfig] = None
def _path_constructor(_loader: Any, node: Any) -> str:
    """Expand ``${NAME}`` / ``${NAME:default}`` placeholders in a YAML scalar.

    Each placeholder found by ``_var_matcher`` in ``node.value`` is replaced
    with the value of the environment variable ``NAME``. When the variable is
    not set, the text after the first ``:`` is used; with no ``:`` present the
    placeholder expands to an empty string.

    Args:
        _loader: The YAML loader invoking this constructor (unused).
        node: The scalar node; only its ``value`` attribute is read.

    Returns:
        ``node.value`` with every placeholder substituted.
    """

    def replace_fn(match: "re.Match[str]") -> str:
        # Split on the FIRST colon only, so defaults that themselves contain
        # colons (URLs, host:port pairs) are kept whole instead of being
        # truncated at their first colon.
        name, _, default = match.group(1).partition(":")
        return os.environ.get(name, default)

    return _var_matcher.sub(replace_fn, node.value)
def load_yaml(filename: str) -> dict:
    """Load a YAML file, expanding ``${NAME[:default]}`` placeholders.

    Args:
        filename: Path of the YAML file to read.

    Returns:
        The parsed top-level mapping. An empty dict is returned when the file
        is missing, unreadable, not valid YAML, empty, or when its top-level
        node is not a mapping.
    """
    # Scalars matching _tag_matcher are implicitly tagged "!envvar" and are
    # built by _path_constructor. Both hooks are registered on SafeLoader.
    # NOTE(review): this registration runs on every call — confirm whether
    # PyYAML tolerates repeated registration or whether it should run once.
    yaml.add_implicit_resolver("!envvar", _tag_matcher, None, yaml.SafeLoader)
    yaml.add_constructor("!envvar", _path_constructor, yaml.SafeLoader)

    try:
        # Explicit UTF-8 so parsing does not depend on the platform locale.
        with open(filename, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f.read())
    except (FileNotFoundError, PermissionError, yaml.YAMLError) as exc:
        LOGGER.error("Error loading YAML file", exception=exc)
        return {}

    # An empty document parses to None; honour the declared dict return type.
    if data is None:
        return {}
    if not isinstance(data, dict):
        LOGGER.error(
            "YAML root is not a mapping",
            file_name=filename,
            root_type=type(data).__name__,
        )
        return {}
    return data
def get_config(file_name: str) -> Optional[Config]:
    """Build a ``Config`` from the YAML file at ``file_name``.

    Args:
        file_name: Path of the YAML configuration file.

    Returns:
        The ``Config`` built from the file's contents, or ``None`` when
        ``load_yaml`` yields nothing usable (an empty dict or ``None``).
    """
    LOGGER.debug("Loading config file", file_name=file_name)

    conf = load_yaml(file_name)
    # Truthiness rather than `== {}`: besides the empty dict returned on a
    # load failure, an empty YAML document can come back as None, and
    # `Config(**None)` would raise TypeError.
    if not conf:
        return None

    return Config(**conf)