from __future__ import annotations import os from functools import lru_cache from dataclasses import dataclass from typing import Literal @dataclass(frozen=True, slots=True) class Settings: model_name: str = "Qwen/Qwen2.5-0.5B-Instruct" max_trace_tokens: int = 1024 max_sentences: int = 40 take_log: bool = True device_preference: Literal["auto", "cuda", "cpu", "mps"] = "auto" dtype_preference: Literal["auto", "float32", "float16", "bfloat16"] = "auto" attn_implementation: str = "eager" trust_remote_code: bool = True low_cpu_mem_usage: bool = True preload_model: bool = False api_host: str = "0.0.0.0" api_port: int = 7860 database_path: str = "data/app.db" job_workers: int = 1 max_queued_jobs: int = 8 max_active_jobs_per_user: int = 2 require_auth: bool = True public_api_enabled: bool = True DEFAULT_SETTINGS = Settings() @lru_cache(maxsize=1) def get_settings() -> Settings: take_log = os.getenv("TAKE_LOG", "true").strip().lower() in {"1", "true", "yes", "on"} trust_remote_code = os.getenv("TRUST_REMOTE_CODE", "true").strip().lower() in {"1", "true", "yes", "on"} low_cpu_mem_usage = os.getenv("LOW_CPU_MEM_USAGE", "true").strip().lower() in {"1", "true", "yes", "on"} require_auth = os.getenv("REQUIRE_AUTH", "true").strip().lower() in {"1", "true", "yes", "on"} public_api_enabled = os.getenv("PUBLIC_API_ENABLED", "true").strip().lower() in {"1", "true", "yes", "on"} return Settings( model_name=os.getenv("MODEL_NAME", DEFAULT_SETTINGS.model_name), max_trace_tokens=int(os.getenv("MAX_TRACE_TOKENS", DEFAULT_SETTINGS.max_trace_tokens)), max_sentences=int(os.getenv("MAX_SENTENCES", DEFAULT_SETTINGS.max_sentences)), take_log=take_log, device_preference=os.getenv("DEVICE_PREFERENCE", DEFAULT_SETTINGS.device_preference), # type: ignore[arg-type] dtype_preference=os.getenv("DTYPE_PREFERENCE", DEFAULT_SETTINGS.dtype_preference), # type: ignore[arg-type] attn_implementation=os.getenv("ATTN_IMPLEMENTATION", DEFAULT_SETTINGS.attn_implementation), trust_remote_code=trust_remote_code, low_cpu_mem_usage=low_cpu_mem_usage, preload_model=os.getenv("PRELOAD_MODEL", "false").strip().lower() in {"1", "true", "yes", "on"}, api_host=os.getenv("API_HOST", DEFAULT_SETTINGS.api_host), api_port=int(os.getenv("API_PORT", DEFAULT_SETTINGS.api_port)), database_path=os.getenv("DATABASE_PATH", DEFAULT_SETTINGS.database_path), job_workers=int(os.getenv("JOB_WORKERS", DEFAULT_SETTINGS.job_workers)), max_queued_jobs=int(os.getenv("MAX_QUEUED_JOBS", DEFAULT_SETTINGS.max_queued_jobs)), max_active_jobs_per_user=int( os.getenv("MAX_ACTIVE_JOBS_PER_USER", DEFAULT_SETTINGS.max_active_jobs_per_user) ), require_auth=require_auth, public_api_enabled=public_api_enabled, )