File size: 2,943 Bytes
fda8fb3
 
 
 
 
 
 
 
 
 
b68ad1b
fda8fb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from __future__ import annotations

import os
from functools import lru_cache
from dataclasses import dataclass
from typing import Literal


@dataclass(frozen=True, slots=True)
class Settings:
    """Immutable application configuration with built-in default values.

    The dataclass is frozen and slotted, so instances cannot be mutated after
    construction. ``get_settings()`` builds an instance whose values are taken
    from environment variables named after the upper-cased field names
    (``MODEL_NAME``, ``API_PORT``, ...).
    """
    # --- Model and tracing options ---
    # Identifier of the model to use; presumably a Hugging Face Hub repo id — confirm.
    model_name: str = "Qwen/Qwen2.5-0.5B-Instruct"
    # Upper bounds for trace tokens and sentences; where they are enforced is not visible here.
    max_trace_tokens: int = 1024
    max_sentences: int = 40
    # NOTE(review): exact meaning is not visible in this file — confirm against the consumer.
    take_log: bool = True
    # --- Model loading options ---
    # "auto" presumably lets the loader choose the device / dtype — confirm.
    device_preference: Literal["auto", "cuda", "cpu", "mps"] = "auto"
    dtype_preference: Literal["auto", "float32", "float16", "bfloat16"] = "auto"
    attn_implementation: str = "eager"
    # NOTE(review): if this is forwarded to a Hugging Face loader, True permits running
    # code shipped with the model repository — confirm that this default is intended.
    trust_remote_code: bool = True
    low_cpu_mem_usage: bool = True
    # Presumably controls loading the model at startup instead of on first use — confirm.
    preload_model: bool = False
    # --- API server options ---
    # Presumably the bind address and port of the HTTP API — confirm.
    api_host: str = "0.0.0.0"
    api_port: int = 7860
    # --- Persistence and job-queue options ---
    # Relative path by default; what it resolves against is not visible here.
    database_path: str = "data/app.db"
    job_workers: int = 1
    max_queued_jobs: int = 8
    max_active_jobs_per_user: int = 2
    # --- Access-control options ---
    require_auth: bool = True
    public_api_enabled: bool = True


# Baseline configuration built purely from the dataclass defaults. get_settings()
# reads fields from this instance as fallbacks for environment variables that
# are not set.
DEFAULT_SETTINGS = Settings()


# Spellings (after strip + lower-casing) that switch a boolean flag on.
_TRUTHY_ENV_VALUES = frozenset({"1", "true", "yes", "on"})


def _env_bool(name: str, default: bool) -> bool:
    """Read a boolean flag from the environment variable *name*.

    Returns *default* when the variable is unset. When it is set, the value
    is stripped and lower-cased; only "1", "true", "yes" and "on" yield True.
    Any other value, including the empty string, yields False.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in _TRUTHY_ENV_VALUES


def _env_int(name: str, default: int) -> int:
    """Read an integer from the environment variable *name*.

    Returns *default* when the variable is unset or contains only whitespace
    (e.g. ``API_PORT=`` in an env file).

    Raises:
        ValueError: if the variable holds a non-blank value that ``int()``
            cannot parse. The message names the offending variable.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    try:
        return int(raw)
    except ValueError as err:
        raise ValueError(
            f"Environment variable {name} must be an integer, got {raw!r}"
        ) from err


@lru_cache(maxsize=1)
def get_settings() -> Settings:
    """Build the application settings from environment variables.

    Each field is read from the environment variable named after the
    upper-cased field name. Unset variables fall back to the corresponding
    value of ``DEFAULT_SETTINGS``, so the dataclass defaults are the single
    source of truth.

    The result is memoised by ``lru_cache``: the environment is read on the
    first call only. Call ``get_settings.cache_clear()`` to force a re-read.

    Returns:
        A frozen ``Settings`` instance.

    Raises:
        ValueError: if an integer-valued variable holds a non-integer value.
    """
    defaults = DEFAULT_SETTINGS
    return Settings(
        model_name=os.getenv("MODEL_NAME", defaults.model_name),
        max_trace_tokens=_env_int("MAX_TRACE_TOKENS", defaults.max_trace_tokens),
        max_sentences=_env_int("MAX_SENTENCES", defaults.max_sentences),
        take_log=_env_bool("TAKE_LOG", defaults.take_log),
        # The two preference values are passed through unvalidated, as before;
        # the Literal annotation is not enforced at runtime.
        device_preference=os.getenv("DEVICE_PREFERENCE", defaults.device_preference),  # type: ignore[arg-type]
        dtype_preference=os.getenv("DTYPE_PREFERENCE", defaults.dtype_preference),  # type: ignore[arg-type]
        attn_implementation=os.getenv("ATTN_IMPLEMENTATION", defaults.attn_implementation),
        trust_remote_code=_env_bool("TRUST_REMOTE_CODE", defaults.trust_remote_code),
        low_cpu_mem_usage=_env_bool("LOW_CPU_MEM_USAGE", defaults.low_cpu_mem_usage),
        preload_model=_env_bool("PRELOAD_MODEL", defaults.preload_model),
        api_host=os.getenv("API_HOST", defaults.api_host),
        api_port=_env_int("API_PORT", defaults.api_port),
        database_path=os.getenv("DATABASE_PATH", defaults.database_path),
        job_workers=_env_int("JOB_WORKERS", defaults.job_workers),
        max_queued_jobs=_env_int("MAX_QUEUED_JOBS", defaults.max_queued_jobs),
        max_active_jobs_per_user=_env_int(
            "MAX_ACTIVE_JOBS_PER_USER", defaults.max_active_jobs_per_user
        ),
        require_auth=_env_bool("REQUIRE_AUTH", defaults.require_auth),
        public_api_enabled=_env_bool("PUBLIC_API_ENABLED", defaults.public_api_enabled),
    )