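# GameAI / config.py
# Uploaded by j-js to Hugging Face ("Upload 42 files"), commit e9462cd (verified), 2.55 kB.
# (Web-view residue preserved from the original paste: "j-js's picture", "raw", "history blame".)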
from __future__ import annotations

import os
from dataclasses import dataclass, field
from pathlib import Path
# Absolute directory containing this module; default file locations are
# built relative to it so they do not depend on the current working directory.
BASE_DIR = Path(__file__).resolve().parents[0]
DATA_DIR = BASE_DIR.joinpath("data")  # default location of data files
LOCAL_LOG_DIR = BASE_DIR.joinpath("logs")  # default local log directory
# Values (after strip + lower-case) that an environment flag treats as "on".
_TRUTHY_VALUES = frozenset({"1", "true", "yes", "on"})


def _env_bool(name: str, default: bool) -> bool:
    """Return environment variable *name* interpreted as a boolean flag.

    An unset variable yields *default*. A set variable is true when, after
    stripping whitespace and lower-casing, it is one of ``1``, ``true``,
    ``yes`` or ``on``; any other value (including ``""`` and ``0``) is false.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in _TRUTHY_VALUES


def _env_int(name: str, default: int) -> int:
    """Return environment variable *name* parsed as an ``int``.

    An unset or blank variable yields *default*.

    Raises:
        ValueError: the variable is set to a non-integer value (the message
            names the variable so misconfiguration is easy to locate).
    """
    raw = (os.getenv(name) or "").strip()
    if not raw:
        return default
    try:
        return int(raw)
    except ValueError as err:
        raise ValueError(
            f"Environment variable {name} must be an integer, got {raw!r}"
        ) from err


def _env_float(name: str, default: float) -> float:
    """Return environment variable *name* parsed as a ``float``.

    An unset or blank variable yields *default*.

    Raises:
        ValueError: the variable is set to a non-numeric value (the message
            names the variable so misconfiguration is easy to locate).
    """
    raw = (os.getenv(name) or "").strip()
    if not raw:
        return default
    try:
        return float(raw)
    except ValueError as err:
        raise ValueError(
            f"Environment variable {name} must be a number, got {raw!r}"
        ) from err


@dataclass(frozen=True)
class Settings:
    """Immutable application configuration sourced from environment variables.

    Each field's default is computed when this class body executes, i.e. once
    at import time. Changing ``os.environ`` afterwards does not affect new
    ``Settings()`` instances; pass keyword arguments to override values.

    Credential fields are excluded from ``repr()`` so that printing or
    logging a ``Settings`` instance does not expose secrets.
    """

    # --- Application / server ---
    app_name: str = os.getenv("APP_NAME", "Trading Game Study AI")
    app_version: str = os.getenv("APP_VERSION", "2.0.0")
    port: int = _env_int("PORT", 7860)

    # --- Model identifiers ---
    embedding_model: str = os.getenv(
        "EMBEDDING_MODEL",
        "sentence-transformers/all-MiniLM-L6-v2",
    )
    cross_encoder_model: str = os.getenv(
        "CROSS_ENCODER_MODEL",
        "cross-encoder/ms-marco-MiniLM-L-6-v2",
    )
    generator_model: str = os.getenv(
        "GENERATOR_MODEL",
        "google/flan-t5-small",
    )
    generator_task: str = os.getenv(
        "GENERATOR_TASK",
        "text2text-generation",
    )

    # --- Generation parameters ---
    generator_max_new_tokens: int = _env_int("GENERATOR_MAX_NEW_TOKENS", 220)
    generator_temperature: float = _env_float("GENERATOR_TEMPERATURE", 0.6)
    generator_top_p: float = _env_float("GENERATOR_TOP_P", 0.9)
    generator_do_sample: bool = _env_bool("GENERATOR_DO_SAMPLE", True)

    # --- Local data files (defaults live under DATA_DIR) ---
    local_chunks_path: str = os.getenv(
        "LOCAL_CHUNKS_PATH",
        str(DATA_DIR / "gmat_hf_chunks.jsonl"),
    )
    question_seed_path: str = os.getenv(
        "QUESTION_SEED_PATH",
        str(DATA_DIR / "gmat_question_seed.jsonl"),
    )
    topic_index_path: str = os.getenv(
        "TOPIC_INDEX_PATH",
        str(DATA_DIR / "gmat_topic_index.json"),
    )

    # --- Remote dataset ---
    dataset_repo_id: str = os.getenv("DATASET_REPO_ID", "j-js/gmat-quant-corpus")
    dataset_split: str = os.getenv("DATASET_SPLIT", "train")

    # --- Retrieval / reply limits ---
    retrieval_k: int = _env_int("RETRIEVAL_K", 8)
    rerank_k: int = _env_int("RERANK_K", 4)
    max_chunks_to_show: int = _env_int("MAX_CHUNKS_TO_SHOW", 3)
    max_reply_chars: int = _env_int("MAX_REPLY_CHARS", 1600)
    enable_remote_dataset_fallback: bool = _env_bool(
        "ENABLE_REMOTE_DATASET_FALLBACK", True
    )

    # --- Local logging ---
    local_log_dir: str = os.getenv("LOCAL_LOG_DIR", str(LOCAL_LOG_DIR))

    # --- Credentials: repr=False keeps them out of repr()/log output ---
    ingest_api_key: str = field(
        default=os.getenv("INGEST_API_KEY", ""), repr=False
    )
    research_api_key: str = field(
        default=os.getenv("RESEARCH_API_KEY", ""), repr=False
    )
    hf_token: str = field(default=os.getenv("HF_TOKEN", ""), repr=False)

    # --- Hub log upload ---
    log_dataset_repo_id: str = os.getenv("LOG_DATASET_REPO_ID", "")
    log_dataset_private: bool = _env_bool("LOG_DATASET_PRIVATE", True)
    push_logs_to_hub: bool = _env_bool("PUSH_LOGS_TO_HUB", False)


# Module-level settings instance, built from the environment when this module
# is imported.
settings = Settings()