# cot-anc / app / core / config.py
# Last commit (BART-ender): "chore: switch default model to Qwen/Qwen2.5-0.5B-Instruct" [b68ad1b, verified]
from __future__ import annotations
import os
from functools import lru_cache
from dataclasses import dataclass
from typing import Literal
@dataclass(frozen=True, slots=True)
class Settings:
model_name: str = "Qwen/Qwen2.5-0.5B-Instruct"
max_trace_tokens: int = 1024
max_sentences: int = 40
take_log: bool = True
device_preference: Literal["auto", "cuda", "cpu", "mps"] = "auto"
dtype_preference: Literal["auto", "float32", "float16", "bfloat16"] = "auto"
attn_implementation: str = "eager"
trust_remote_code: bool = True
low_cpu_mem_usage: bool = True
preload_model: bool = False
api_host: str = "0.0.0.0"
api_port: int = 7860
database_path: str = "data/app.db"
job_workers: int = 1
max_queued_jobs: int = 8
max_active_jobs_per_user: int = 2
require_auth: bool = True
public_api_enabled: bool = True
# Built-in configuration: a Settings instance constructed with no arguments,
# so every field holds its declared default. get_settings() falls back to
# values from this instance when the matching environment variable is unset.
DEFAULT_SETTINGS = Settings()
# Strings (after strip + lowercase) that switch a boolean environment flag on.
_TRUTHY_ENV_VALUES = frozenset({"1", "true", "yes", "on"})


def _env_bool(name: str, default: bool) -> bool:
    """Return the boolean flag in environment variable *name*, or *default* if unset."""
    raw = os.getenv(name)
    if raw is None:
        return default
    # Anything outside the truthy set -- including a blank value -- is False.
    return raw.strip().lower() in _TRUTHY_ENV_VALUES


def _env_int(name: str, default: int) -> int:
    """Return environment variable *name* parsed as an int, or *default* if unset."""
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError as err:
        # Same exception type as a bare int() call, but naming the variable
        # so a misconfigured deployment is easy to diagnose.
        raise ValueError(
            f"Environment variable {name} must be an integer, got {raw!r}"
        ) from err


@lru_cache(maxsize=1)
def get_settings() -> Settings:
    """Build the application settings from environment variables.

    Each field is read from the upper-cased environment variable of the same
    name (e.g. ``MODEL_NAME``, ``API_PORT``). A variable that is not set falls
    back to the corresponding value in ``DEFAULT_SETTINGS``.

    Boolean variables are true when their value, stripped and lower-cased, is
    one of ``1``, ``true``, ``yes`` or ``on``; any other value that is set
    (including an empty string) is false.

    The result is memoised by ``lru_cache``, so the environment is read once;
    call ``get_settings.cache_clear()`` to force a re-read.

    Returns:
        The populated, immutable ``Settings`` instance.

    Raises:
        ValueError: If an integer-valued variable is set to a non-integer.
    """
    defaults = DEFAULT_SETTINGS
    # NOTE(review): REQUIRE_AUTH set to any non-truthy string (even blank)
    # yields False -- confirm that fail-open behaviour is intended.
    return Settings(
        model_name=os.getenv("MODEL_NAME", defaults.model_name),
        max_trace_tokens=_env_int("MAX_TRACE_TOKENS", defaults.max_trace_tokens),
        max_sentences=_env_int("MAX_SENTENCES", defaults.max_sentences),
        take_log=_env_bool("TAKE_LOG", defaults.take_log),
        # The two *_preference values are passed through unvalidated, hence
        # the type-checker suppressions for the Literal-typed fields.
        device_preference=os.getenv("DEVICE_PREFERENCE", defaults.device_preference),  # type: ignore[arg-type]
        dtype_preference=os.getenv("DTYPE_PREFERENCE", defaults.dtype_preference),  # type: ignore[arg-type]
        attn_implementation=os.getenv("ATTN_IMPLEMENTATION", defaults.attn_implementation),
        trust_remote_code=_env_bool("TRUST_REMOTE_CODE", defaults.trust_remote_code),
        low_cpu_mem_usage=_env_bool("LOW_CPU_MEM_USAGE", defaults.low_cpu_mem_usage),
        preload_model=_env_bool("PRELOAD_MODEL", defaults.preload_model),
        api_host=os.getenv("API_HOST", defaults.api_host),
        api_port=_env_int("API_PORT", defaults.api_port),
        database_path=os.getenv("DATABASE_PATH", defaults.database_path),
        job_workers=_env_int("JOB_WORKERS", defaults.job_workers),
        max_queued_jobs=_env_int("MAX_QUEUED_JOBS", defaults.max_queued_jobs),
        max_active_jobs_per_user=_env_int(
            "MAX_ACTIVE_JOBS_PER_USER", defaults.max_active_jobs_per_user
        ),
        require_auth=_env_bool("REQUIRE_AUTH", defaults.require_auth),
        public_api_enabled=_env_bool("PUBLIC_API_ENABLED", defaults.public_api_enabled),
    )