File size: 2,597 Bytes
e9462cd
 
 
 
 
 
 
 
 
 
 
7834040
 
 
 
e9462cd
 
7834040
e9462cd
 
 
 
7834040
 
 
 
 
e9462cd
 
 
7834040
e9462cd
7834040
 
 
 
 
 
e9462cd
 
 
 
 
 
7834040
e9462cd
7834040
e9462cd
7834040
 
 
 
 
 
e9462cd
 
 
 
7834040
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from __future__ import annotations

import os
from dataclasses import dataclass, field
from pathlib import Path

# Absolute, symlink-resolved directory that contains this module.
BASE_DIR = Path(__file__).resolve().parents[0]
# Sub-directories of BASE_DIR for local data files and local log files.
DATA_DIR = BASE_DIR.joinpath("data")
LOCAL_LOG_DIR = BASE_DIR.joinpath("logs")


def env_bool(name: str, default: bool = False) -> bool:
    """Read the environment variable *name* as a boolean flag.

    The value is stripped of surrounding whitespace and compared
    case-insensitively against ``1``, ``true``, ``yes`` and ``on``; any
    other value (including an empty string) yields ``False``. When the
    variable is not set at all, the truthiness of *default* is returned.
    """
    raw_value = os.environ.get(name)
    if raw_value is None:
        # Unset variable: fall back to the caller-supplied default.
        return bool(default)

    truthy_tokens = ("1", "true", "yes", "on")
    normalized = raw_value.strip().lower()
    return normalized in truthy_tokens


@dataclass(frozen=True)
class Settings:
    """Immutable application configuration sourced from environment variables.

    Every field default is computed once, when this class body is executed,
    by reading the corresponding environment variable and falling back to the
    literal given in the ``os.getenv`` / ``env_bool`` call. Changes made to
    the environment afterwards do not affect these defaults. Numeric fields
    are converted with ``int`` / ``float``, so a non-numeric value in the
    environment raises ``ValueError`` while the class is being defined.

    The secret fields are declared with ``repr=False`` so that printing or
    logging a ``Settings`` instance does not expose their values. They still
    take part in equality comparison and keep their position in ``__init__``.
    """

    # App
    app_name: str = os.getenv("APP_NAME", "Trading Game Study AI")
    app_version: str = os.getenv("APP_VERSION", "2.0.0")
    port: int = int(os.getenv("PORT", "7860"))

    # Models
    embedding_model: str = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
    cross_encoder_model: str = os.getenv("CROSS_ENCODER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2")
    generator_model: str = os.getenv("GENERATOR_MODEL", "google/flan-t5-small")
    generator_task: str = os.getenv("GENERATOR_TASK", "text2text-generation")
    generator_max_new_tokens: int = int(os.getenv("GENERATOR_MAX_NEW_TOKENS", "220"))
    generator_temperature: float = float(os.getenv("GENERATOR_TEMPERATURE", "0.6"))
    generator_top_p: float = float(os.getenv("GENERATOR_TOP_P", "0.9"))
    generator_do_sample: bool = env_bool("GENERATOR_DO_SAMPLE", True)

    # Local data (defaults point into DATA_DIR)
    local_chunks_path: str = os.getenv("LOCAL_CHUNKS_PATH", str(DATA_DIR / "gmat_hf_chunks.jsonl"))
    question_seed_path: str = os.getenv("QUESTION_SEED_PATH", str(DATA_DIR / "gmat_question_seed.jsonl"))
    topic_index_path: str = os.getenv("TOPIC_INDEX_PATH", str(DATA_DIR / "gmat_topic_index.json"))

    # Retrieval
    dataset_repo_id: str = os.getenv("DATASET_REPO_ID", "j-js/gmat-quant-corpus")
    dataset_split: str = os.getenv("DATASET_SPLIT", "train")
    retrieval_k: int = int(os.getenv("RETRIEVAL_K", "8"))
    rerank_k: int = int(os.getenv("RERANK_K", "4"))
    max_chunks_to_show: int = int(os.getenv("MAX_CHUNKS_TO_SHOW", "3"))
    max_reply_chars: int = int(os.getenv("MAX_REPLY_CHARS", "1600"))
    enable_remote_dataset_fallback: bool = env_bool("ENABLE_REMOTE_DATASET_FALLBACK", True)

    # Logging
    local_log_dir: str = os.getenv("LOCAL_LOG_DIR", str(LOCAL_LOG_DIR))
    push_logs_to_hub: bool = env_bool("PUSH_LOGS_TO_HUB", False)
    log_dataset_repo_id: str = os.getenv("LOG_DATASET_REPO_ID", "")
    log_dataset_private: bool = env_bool("LOG_DATASET_PRIVATE", True)

    # Secrets: kept out of the generated __repr__ so tokens and API keys are
    # not leaked when the settings object is printed or logged.
    hf_token: str = field(default=os.getenv("HF_TOKEN", ""), repr=False)
    ingest_api_key: str = field(default=os.getenv("INGEST_API_KEY", ""), repr=False)
    research_api_key: str = field(default=os.getenv("RESEARCH_API_KEY", ""), repr=False)


# Module-level Settings instance, created at import time with every field
# left at its default.
settings = Settings()