| from __future__ import annotations |
|
|
import os
from dataclasses import dataclass, field
from pathlib import Path
|
|
|
|
# Absolute, symlink-resolved directory that contains this module; the data
# and log directories below are anchored to it rather than to the CWD.
BASE_DIR = Path(__file__).resolve().parent
DATA_DIR = BASE_DIR.joinpath("data")
LOCAL_LOG_DIR = BASE_DIR.joinpath("logs")
|
|
|
|
# Lower-cased values that an environment flag may use to mean "enabled".
_TRUTHY_ENV_VALUES = frozenset({"1", "true", "yes", "on"})


def _env_bool(name: str, default: bool) -> bool:
    """Read a boolean flag from environment variable *name*.

    An unset variable yields *default*. A set value is True when, after
    stripping whitespace and lower-casing, it is one of ``1``, ``true``,
    ``yes`` or ``on``; any other set value (including ``0`` or an empty
    string) is False.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in _TRUTHY_ENV_VALUES


def _env_int(name: str, default: int) -> int:
    """Read an integer from environment variable *name*.

    An unset or blank (whitespace-only) variable yields *default*.

    Raises:
        ValueError: if the variable is set to a value ``int()`` rejects;
            the message names the offending variable and value.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    try:
        return int(raw)
    except ValueError:
        raise ValueError(
            f"Environment variable {name} must be an integer, got {raw!r}"
        ) from None


def _env_float(name: str, default: float) -> float:
    """Read a float from environment variable *name*.

    An unset or blank (whitespace-only) variable yields *default*.

    Raises:
        ValueError: if the variable is set to a value ``float()`` rejects;
            the message names the offending variable and value.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    try:
        return float(raw)
    except ValueError:
        raise ValueError(
            f"Environment variable {name} must be a number, got {raw!r}"
        ) from None


@dataclass(frozen=True)
class Settings:
    """Immutable application configuration sourced from environment variables.

    Every default below is read from the environment when this class body
    executes (i.e. when the module is imported), so later changes to
    ``os.environ`` are not reflected in new ``Settings()`` instances; pass
    explicit keyword arguments to override individual values.

    Boolean flags accept ``1``/``true``/``yes``/``on`` (case-insensitive) as
    true. Numeric settings fall back to their default when the variable is
    blank and raise ``ValueError`` naming the variable when it is malformed.
    Credential fields are excluded from ``repr`` so the object can be
    printed or logged without exposing them.
    """

    # Application identity and port.
    app_name: str = os.getenv("APP_NAME", "Trading Game Study AI")
    app_version: str = os.getenv("APP_VERSION", "2.0.0")
    port: int = _env_int("PORT", 7860)

    # Model identifiers and text-generation parameters.
    embedding_model: str = os.getenv(
        "EMBEDDING_MODEL",
        "sentence-transformers/all-MiniLM-L6-v2",
    )
    cross_encoder_model: str = os.getenv(
        "CROSS_ENCODER_MODEL",
        "cross-encoder/ms-marco-MiniLM-L-6-v2",
    )
    generator_model: str = os.getenv(
        "GENERATOR_MODEL",
        "google/flan-t5-small",
    )
    generator_task: str = os.getenv(
        "GENERATOR_TASK",
        "text2text-generation",
    )
    generator_max_new_tokens: int = _env_int("GENERATOR_MAX_NEW_TOKENS", 220)
    generator_temperature: float = _env_float("GENERATOR_TEMPERATURE", 0.6)
    generator_top_p: float = _env_float("GENERATOR_TOP_P", 0.9)
    generator_do_sample: bool = _env_bool("GENERATOR_DO_SAMPLE", True)

    # Local data files, dataset source, and retrieval/output limits.
    local_chunks_path: str = os.getenv(
        "LOCAL_CHUNKS_PATH",
        str(DATA_DIR / "gmat_hf_chunks.jsonl"),
    )
    question_seed_path: str = os.getenv(
        "QUESTION_SEED_PATH",
        str(DATA_DIR / "gmat_question_seed.jsonl"),
    )
    topic_index_path: str = os.getenv(
        "TOPIC_INDEX_PATH",
        str(DATA_DIR / "gmat_topic_index.json"),
    )
    dataset_repo_id: str = os.getenv("DATASET_REPO_ID", "j-js/gmat-quant-corpus")
    dataset_split: str = os.getenv("DATASET_SPLIT", "train")
    retrieval_k: int = _env_int("RETRIEVAL_K", 8)
    rerank_k: int = _env_int("RERANK_K", 4)
    max_chunks_to_show: int = _env_int("MAX_CHUNKS_TO_SHOW", 3)
    max_reply_chars: int = _env_int("MAX_REPLY_CHARS", 1600)
    enable_remote_dataset_fallback: bool = _env_bool(
        "ENABLE_REMOTE_DATASET_FALLBACK", True
    )

    # Logging destinations and credentials. Secrets use repr=False so they
    # never appear in repr()/str() output (equality is unaffected).
    local_log_dir: str = os.getenv("LOCAL_LOG_DIR", str(LOCAL_LOG_DIR))
    ingest_api_key: str = field(
        default=os.getenv("INGEST_API_KEY", ""), repr=False
    )
    research_api_key: str = field(
        default=os.getenv("RESEARCH_API_KEY", ""), repr=False
    )
    hf_token: str = field(default=os.getenv("HF_TOKEN", ""), repr=False)
    log_dataset_repo_id: str = os.getenv("LOG_DATASET_REPO_ID", "")
    log_dataset_private: bool = _env_bool("LOG_DATASET_PRIVATE", True)
    push_logs_to_hub: bool = _env_bool("PUSH_LOGS_TO_HUB", False)
|
|
|
|
# Shared configuration instance, constructed once when this module is imported.
settings: Settings = Settings()
|
|