Junhoee commited on
Commit
3fafd7c
·
verified ·
1 Parent(s): c40a1c4

Update megumin_agent/bootstrap.py

Browse files
Files changed (1) hide show
  1. megumin_agent/bootstrap.py +36 -10
megumin_agent/bootstrap.py CHANGED
@@ -12,12 +12,26 @@ PROJECT_ROOT = Path(__file__).resolve().parents[1]
12
  ADK_SRC = PROJECT_ROOT / "adk-python" / "src"
13
  LOCAL_DATASET_DIR = PROJECT_ROOT / "data" / "processed"
14
  RUNTIME_DATASET_DIR = PROJECT_ROOT / "data" / "_runtime_processed"
15
- HF_DATASET_REPO_ID = os.getenv("MEGUMIN_HF_DATASET_REPO_ID", "Junhoee/megumin-chat")
16
- HF_DATASET_FILENAME = os.getenv("MEGUMIN_HF_DATASET_FILENAME", "megumin_qa_dataset.json")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
 
19
  def bootstrap_environment() -> None:
20
- load_dotenv(PROJECT_ROOT / ".env")
21
  if ADK_SRC.exists():
22
  adk_src = str(ADK_SRC)
23
  if adk_src not in sys.path:
@@ -29,14 +43,26 @@ def resolve_dataset_dir() -> Path:
29
 
30
  try:
31
  hf_token = os.getenv("HF_TOKEN") or None
32
- downloaded_path = hf_hub_download(
33
- repo_id=os.getenv("MEGUMIN_HF_DATASET_REPO_ID", HF_DATASET_REPO_ID),
34
- repo_type="dataset",
35
- filename=os.getenv("MEGUMIN_HF_DATASET_FILENAME", HF_DATASET_FILENAME),
36
- token=hf_token,
37
- local_dir=str(RUNTIME_DATASET_DIR),
38
  )
39
- return Path(downloaded_path).parent
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  except Exception:
41
  if LOCAL_DATASET_DIR.exists() and any(LOCAL_DATASET_DIR.glob("*.json")):
42
  return LOCAL_DATASET_DIR
 
12
  ADK_SRC = PROJECT_ROOT / "adk-python" / "src"
13
  LOCAL_DATASET_DIR = PROJECT_ROOT / "data" / "processed"
14
  RUNTIME_DATASET_DIR = PROJECT_ROOT / "data" / "_runtime_processed"
15
+
16
+
17
+ def _dataset_repo_id() -> str:
18
+ return os.getenv("MEGUMIN_HF_DATASET_REPO_ID", "Junhoee/megumin-chat")
19
+
20
+
21
+ def _dataset_filename() -> str:
22
+ return os.getenv("MEGUMIN_HF_DATASET_FILENAME", "megumin_qa_dataset.json")
23
+
24
+
25
+ def _index_filename() -> str:
26
+ return os.getenv("MEGUMIN_FAISS_INDEX_FILENAME", "megumin_questions.faiss")
27
+
28
+
29
+ def _metadata_filename() -> str:
30
+ return os.getenv("MEGUMIN_FAISS_METADATA_FILENAME", "megumin_questions_meta.json")
31
 
32
 
33
  def bootstrap_environment() -> None:
34
+ load_dotenv(PROJECT_ROOT / ".env", override=True)
35
  if ADK_SRC.exists():
36
  adk_src = str(ADK_SRC)
37
  if adk_src not in sys.path:
 
43
 
44
  try:
45
  hf_token = os.getenv("HF_TOKEN") or None
46
+ repo_id = _dataset_repo_id()
47
+ artifact_names = (
48
+ _dataset_filename(),
49
+ _index_filename(),
50
+ _metadata_filename(),
 
51
  )
52
+ for artifact_name in artifact_names:
53
+ try:
54
+ hf_hub_download(
55
+ repo_id=repo_id,
56
+ repo_type="dataset",
57
+ filename=artifact_name,
58
+ token=hf_token,
59
+ local_dir=str(RUNTIME_DATASET_DIR),
60
+ )
61
+ except Exception:
62
+ if artifact_name != _dataset_filename():
63
+ continue
64
+ raise
65
+ return RUNTIME_DATASET_DIR
66
  except Exception:
67
  if LOCAL_DATASET_DIR.exists() and any(LOCAL_DATASET_DIR.glob("*.json")):
68
  return LOCAL_DATASET_DIR