# Megumin-chat / megumin_agent/bootstrap.py
# Uploaded by Junhoee — commit 06b2015 ("Upload 6 files", verified).
from __future__ import annotations
import os
import sys
from pathlib import Path
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download
# Project root: the parent of the directory that contains this file.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Optional adk-python source checkout; bootstrap_environment() prepends it to
# sys.path when the directory exists.
ADK_SRC = PROJECT_ROOT.joinpath("adk-python", "src")
# Bundled dataset directory, used by resolve_dataset_dir() as a fallback when
# the Hugging Face download path fails.
LOCAL_DATASET_DIR = PROJECT_ROOT.joinpath("data", "processed")
# Download target for artifacts fetched from the Hugging Face Hub.
RUNTIME_DATASET_DIR = PROJECT_ROOT.joinpath("data", "_runtime_processed")
def _env_or_default(name: str, default: str) -> str:
    """Return env var *name* stripped of whitespace, or *default* if unset/blank.

    A variable that is set but empty or whitespace-only is treated as unset.
    One example is a blank ``KEY=`` entry in ``.env``, which ``load_dotenv``
    loads as an empty string. Without this, that entry would yield an empty
    repo id or filename. This mirrors how ``HF_TOKEN`` is read
    (``os.getenv("HF_TOKEN") or None``).
    """
    value = os.environ.get(name, "").strip()
    return value or default


def _dataset_repo_id() -> str:
    """Hugging Face dataset repo id that all artifacts are downloaded from."""
    return _env_or_default("MEGUMIN_HF_DATASET_REPO_ID", "Junhoee/megumin-chat")


def _dataset_filename() -> str:
    """Filename of the Megumin QA dataset JSON (a required artifact)."""
    return _env_or_default("MEGUMIN_HF_DATASET_FILENAME", "megumin_qa_dataset.json")


def _index_filename() -> str:
    """Filename of the FAISS index for the Megumin questions (optional artifact)."""
    return _env_or_default("MEGUMIN_FAISS_INDEX_FILENAME", "megumin_questions.faiss")


def _qa_index_filename() -> str:
    """Filename of the Megumin question+answer FAISS index (optional artifact)."""
    return _env_or_default(
        "MEGUMIN_FAISS_QA_INDEX_FILENAME", "megumin_question_answer.faiss"
    )


def _metadata_filename() -> str:
    """Filename of the metadata JSON for the Megumin FAISS index (optional)."""
    return _env_or_default(
        "MEGUMIN_FAISS_METADATA_FILENAME", "megumin_questions_meta.json"
    )


def _fact_dataset_filename() -> str:
    """Filename of the fact QA dataset JSON (a required artifact)."""
    return _env_or_default("MEGUMIN_HF_FACT_DATASET_FILENAME", "namuwiki_qa.json")


def _fact_index_filename() -> str:
    """Filename of the FAISS index for the fact questions (optional artifact)."""
    return _env_or_default(
        "MEGUMIN_HF_FACT_INDEX_FILENAME", "namuwiki_questions.faiss"
    )


def _fact_qa_index_filename() -> str:
    """Filename of the fact question+answer FAISS index (optional artifact)."""
    return _env_or_default(
        "MEGUMIN_HF_FACT_QA_INDEX_FILENAME", "namuwiki_question_answer.faiss"
    )


def _fact_metadata_filename() -> str:
    """Filename of the metadata JSON for the fact FAISS index (optional)."""
    return _env_or_default(
        "MEGUMIN_HF_FACT_METADATA_FILENAME", "namuwiki_questions_meta.json"
    )
def bootstrap_environment() -> None:
    """Prepare the process environment for the agent.

    Loads ``PROJECT_ROOT/.env``. Because ``override=True`` is passed, values
    from the file replace variables that are already set. Then, if an
    ``adk-python/src`` checkout exists under the project root and is not
    already on ``sys.path``, it is inserted at the front so it takes import
    precedence.
    """
    load_dotenv(PROJECT_ROOT / ".env", override=True)

    if not ADK_SRC.exists():
        return

    adk_path = str(ADK_SRC)
    if adk_path in sys.path:
        # Already importable; avoid duplicate sys.path entries.
        return
    sys.path.insert(0, adk_path)
def resolve_dataset_dir() -> Path:
    """Return the directory holding the dataset JSONs and FAISS artifacts.

    Every artifact is downloaded from the Hugging Face dataset repo
    (``_dataset_repo_id()``) into ``RUNTIME_DATASET_DIR``. ``HF_TOKEN`` is
    used for authentication when it is set and non-empty.

    - The two dataset JSONs (``_dataset_filename()`` and
      ``_fact_dataset_filename()``) are required.
    - The FAISS index and metadata files are optional. A failed optional
      download is logged as a warning and skipped.

    The function falls back to ``LOCAL_DATASET_DIR`` when either of these
    fails, provided that directory contains at least one ``*.json`` file:

    - preparing ``RUNTIME_DATASET_DIR`` (e.g. on a read-only filesystem);
    - downloading a required dataset.

    Returns:
        ``RUNTIME_DATASET_DIR`` after a successful download pass, otherwise
        ``LOCAL_DATASET_DIR``.

    Raises:
        Exception: the original error from creating the runtime directory or
            downloading a required artifact, re-raised unchanged when no usable
            local fallback exists.
    """
    import logging

    logger = logging.getLogger(__name__)
    try:
        # Inside the try so that a failure to create the runtime directory
        # still falls back to the bundled local datasets.
        RUNTIME_DATASET_DIR.mkdir(parents=True, exist_ok=True)

        hf_token = os.getenv("HF_TOKEN") or None  # treat empty token as unset
        repo_id = _dataset_repo_id()
        required = {_dataset_filename(), _fact_dataset_filename()}
        artifact_names = (
            _dataset_filename(),
            _index_filename(),
            _qa_index_filename(),
            _metadata_filename(),
            _fact_dataset_filename(),
            _fact_index_filename(),
            _fact_qa_index_filename(),
            _fact_metadata_filename(),
        )
        for artifact_name in artifact_names:
            try:
                hf_hub_download(
                    repo_id=repo_id,
                    repo_type="dataset",
                    filename=artifact_name,
                    token=hf_token,
                    local_dir=str(RUNTIME_DATASET_DIR),
                )
            except Exception as exc:
                if artifact_name in required:
                    raise
                # Optional artifact: keep going, but leave a trace so a missing
                # index is diagnosable instead of silently absent.
                logger.warning(
                    "Skipping optional artifact %r from %s: %s",
                    artifact_name,
                    repo_id,
                    exc,
                )
        return RUNTIME_DATASET_DIR
    except Exception:
        if LOCAL_DATASET_DIR.exists() and any(LOCAL_DATASET_DIR.glob("*.json")):
            logger.warning(
                "Dataset download failed; falling back to local directory %s",
                LOCAL_DATASET_DIR,
                exc_info=True,
            )
            return LOCAL_DATASET_DIR
        raise