File size: 2,881 Bytes
7245599
 
60d64f7
06b2015
7245599
 
06b2015
 
7245599
 
06b2015
 
 
 
3fafd7c
 
06b2015
 
3fafd7c
 
06b2015
 
3fafd7c
 
06b2015
 
3fafd7c
 
06b2015
 
4486ffb
 
06b2015
 
7245599
 
06b2015
 
021a05e
 
06b2015
 
021a05e
 
06b2015
 
4486ffb
 
06b2015
 
021a05e
 
06b2015
 
 
 
 
 
60d64f7
 
06b2015
 
60d64f7
 
06b2015
 
 
 
 
 
 
 
 
 
 
60d64f7
06b2015
 
 
 
 
 
 
 
 
 
 
3fafd7c
06b2015
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from __future__ import annotations

import os
import sys
from pathlib import Path

from dotenv import load_dotenv
from huggingface_hub import hf_hub_download


# Root of the project: the parent of the directory that contains this module.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
# Optional local ADK source checkout; prepended to sys.path when it exists.
ADK_SRC = PROJECT_ROOT.joinpath("adk-python", "src")
# Dataset directory shipped with the repository, used as a download fallback.
LOCAL_DATASET_DIR = PROJECT_ROOT.joinpath("data", "processed")
# Directory that Hugging Face Hub downloads are written into at runtime.
RUNTIME_DATASET_DIR = PROJECT_ROOT.joinpath("data", "_runtime_processed")


def _dataset_repo_id() -> str:
    """Return the Hugging Face dataset repo id that hosts the artifacts.

    Read from ``MEGUMIN_HF_DATASET_REPO_ID``. An unset *or empty* variable
    falls back to ``"Junhoee/megumin-chat"``, so a blank ``KEY=`` entry in
    ``.env`` cannot produce an unusable empty repo id (same convention as the
    ``HF_TOKEN`` lookup in this module).
    """
    return os.getenv("MEGUMIN_HF_DATASET_REPO_ID") or "Junhoee/megumin-chat"


def _dataset_filename() -> str:
    """Return the file name of the Megumin QA dataset in the dataset repo.

    Read from ``MEGUMIN_HF_DATASET_FILENAME``. An unset or empty variable
    falls back to ``"megumin_qa_dataset.json"``, so a blank ``.env`` entry
    does not turn this required artifact into an empty file name.
    """
    return os.getenv("MEGUMIN_HF_DATASET_FILENAME") or "megumin_qa_dataset.json"


def _index_filename() -> str:
    """Return the file name of the Megumin questions FAISS index.

    Read from ``MEGUMIN_FAISS_INDEX_FILENAME``. An unset or empty variable
    falls back to ``"megumin_questions.faiss"``.
    """
    return os.getenv("MEGUMIN_FAISS_INDEX_FILENAME") or "megumin_questions.faiss"


def _qa_index_filename() -> str:
    """Return the file name of the Megumin question+answer FAISS index.

    Read from ``MEGUMIN_FAISS_QA_INDEX_FILENAME``. An unset or empty
    variable falls back to ``"megumin_question_answer.faiss"``.
    """
    return (
        os.getenv("MEGUMIN_FAISS_QA_INDEX_FILENAME")
        or "megumin_question_answer.faiss"
    )


def _metadata_filename() -> str:
    """Return the file name of the Megumin FAISS metadata JSON.

    Read from ``MEGUMIN_FAISS_METADATA_FILENAME``. An unset or empty
    variable falls back to ``"megumin_questions_meta.json"``.
    """
    return (
        os.getenv("MEGUMIN_FAISS_METADATA_FILENAME")
        or "megumin_questions_meta.json"
    )


def _fact_dataset_filename() -> str:
    """Return the file name of the fact (namuwiki) QA dataset.

    Read from ``MEGUMIN_HF_FACT_DATASET_FILENAME``. An unset or empty
    variable falls back to ``"namuwiki_qa.json"``, so a blank ``.env`` entry
    does not turn this required artifact into an empty file name.
    """
    return os.getenv("MEGUMIN_HF_FACT_DATASET_FILENAME") or "namuwiki_qa.json"


def _fact_index_filename() -> str:
    """Return the file name of the fact (namuwiki) questions FAISS index.

    Read from ``MEGUMIN_HF_FACT_INDEX_FILENAME``. An unset or empty variable
    falls back to ``"namuwiki_questions.faiss"``.

    NOTE: the fact-index variables use the ``MEGUMIN_HF_FACT_*`` prefix while
    the Megumin index variables use ``MEGUMIN_FAISS_*``. The names are kept
    as-is so existing configuration keeps working.
    """
    return (
        os.getenv("MEGUMIN_HF_FACT_INDEX_FILENAME") or "namuwiki_questions.faiss"
    )


def _fact_qa_index_filename() -> str:
    """Return the file name of the fact (namuwiki) question+answer FAISS index.

    Read from ``MEGUMIN_HF_FACT_QA_INDEX_FILENAME``. An unset or empty
    variable falls back to ``"namuwiki_question_answer.faiss"``.
    """
    return (
        os.getenv("MEGUMIN_HF_FACT_QA_INDEX_FILENAME")
        or "namuwiki_question_answer.faiss"
    )


def _fact_metadata_filename() -> str:
    """Return the file name of the fact (namuwiki) FAISS metadata JSON.

    Read from ``MEGUMIN_HF_FACT_METADATA_FILENAME``. An unset or empty
    variable falls back to ``"namuwiki_questions_meta.json"``.
    """
    return (
        os.getenv("MEGUMIN_HF_FACT_METADATA_FILENAME")
        or "namuwiki_questions_meta.json"
    )


def bootstrap_environment() -> None:
    """Prepare the process environment for the rest of the application.

    Loads ``PROJECT_ROOT / ".env"`` with ``override=True``, so values from the
    file take precedence over variables already set in the process. Then,
    if the ADK source checkout at ``ADK_SRC`` exists, puts it at the front of
    ``sys.path`` (once) so it is imported ahead of other installations.
    """
    load_dotenv(PROJECT_ROOT / ".env", override=True)

    if not ADK_SRC.exists():
        return

    adk_path = str(ADK_SRC)
    # Guard against duplicate entries when bootstrap runs more than once.
    if adk_path in sys.path:
        return
    sys.path.insert(0, adk_path)


def resolve_dataset_dir() -> Path:
    """Return the directory that holds the dataset and FAISS artifacts.

    Downloads every configured artifact from the Hugging Face dataset repo
    (``_dataset_repo_id()``) into ``RUNTIME_DATASET_DIR``. The two QA
    datasets (``_dataset_filename()`` and ``_fact_dataset_filename()``) are
    required. The index and metadata files are best-effort: a failed
    download is logged and skipped.

    If preparing the runtime directory or downloading a required artifact
    fails, falls back to ``LOCAL_DATASET_DIR``, provided it exists and
    contains at least one ``*.json`` file.

    Returns:
        ``RUNTIME_DATASET_DIR`` when all required downloads succeed,
        otherwise ``LOCAL_DATASET_DIR``.

    Raises:
        Exception: the error from the failing step, re-raised unchanged when
        no usable local fallback exists.
    """
    import logging

    logger = logging.getLogger(__name__)

    try:
        # Inside the try so an unwritable checkout can still use local data.
        RUNTIME_DATASET_DIR.mkdir(parents=True, exist_ok=True)

        hf_token = os.getenv("HF_TOKEN") or None
        repo_id = _dataset_repo_id()
        required_names = {_dataset_filename(), _fact_dataset_filename()}
        artifact_names = (
            _dataset_filename(),
            _index_filename(),
            _qa_index_filename(),
            _metadata_filename(),
            _fact_dataset_filename(),
            _fact_index_filename(),
            _fact_qa_index_filename(),
            _fact_metadata_filename(),
        )
        for artifact_name in artifact_names:
            try:
                hf_hub_download(
                    repo_id=repo_id,
                    repo_type="dataset",
                    filename=artifact_name,
                    token=hf_token,
                    local_dir=str(RUNTIME_DATASET_DIR),
                )
            except Exception:
                if artifact_name in required_names:
                    raise
                # Optional artifact: keep going, but leave a trace.
                logger.warning(
                    "Skipping optional artifact %r from %s: download failed",
                    artifact_name,
                    repo_id,
                    exc_info=True,
                )
        return RUNTIME_DATASET_DIR
    except Exception:
        if LOCAL_DATASET_DIR.exists() and any(LOCAL_DATASET_DIR.glob("*.json")):
            logger.warning(
                "Falling back to local dataset directory %s",
                LOCAL_DATASET_DIR,
                exc_info=True,
            )
            return LOCAL_DATASET_DIR
        raise