"""Path resolution for project data artifacts.

Resolves dataset file locations, preferring local ``data/processed`` copies
and falling back to downloads from the Hugging Face Hub dataset repo.
"""

import os
from pathlib import Path
from typing import Dict, Union

from huggingface_hub import hf_hub_download

# --- Constants ---
HF_REPO = "mickey1976/mayankc-amazon_beauty_subset"
# Memoizes hf_hub_download results per repo-relative filename for this process.
CACHE: Dict[str, Path] = {}

# --- project roots ---
# parents[2]: this file is assumed to live two directories below the project
# root (e.g. <root>/src/utils/paths.py) — TODO confirm against repo layout.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
DATA_DIR = PROJECT_ROOT / "data"
RAW_DIR = DATA_DIR / "raw"
PROCESSED_DIR = DATA_DIR / "processed"
CACHE_DIR = DATA_DIR / "cache"
LOGS_DIR = PROJECT_ROOT / "logs"
MODELS_DIR = PROJECT_ROOT / "src" / "models"


def ensure_dir(path: Union[str, Path]) -> Path:
    """Create *path* (and parents) if missing and return it as a Path."""
    # Path() is idempotent on Path instances, so no isinstance check is needed.
    p = Path(path)
    p.mkdir(parents=True, exist_ok=True)
    return p


def get_raw_path(dataset: str) -> Path:
    """Return (and create) the raw-data directory for *dataset*."""
    return ensure_dir(RAW_DIR / dataset)


def _hf_download(filename: str) -> Path:
    """Download *filename* from HF_REPO (dataset repo), memoized in CACHE."""
    if filename in CACHE:
        return CACHE[filename]
    path = Path(
        hf_hub_download(repo_id=HF_REPO, filename=filename, repo_type="dataset")
    )
    CACHE[filename] = path
    return path


def get_processed_path(dataset: str) -> Path:
    """Return the processed-data directory for *dataset*.

    Prefers the local ``data/processed/<dataset>`` directory; otherwise
    downloads a known file from the Hub and returns its parent directory.

    NOTE(review): the fallback ignores *dataset* entirely — it always resolves
    the parent of ``parquet/user_text_emb.parquet``, so for any dataset other
    than the one hosted in HF_REPO this may point at the wrong data. Confirm
    whether per-dataset subfolders exist in the Hub repo.
    """
    local_path = PROCESSED_DIR / dataset
    if local_path.exists():
        return local_path
    # fallback: download any known file to get a valid parent path
    # (plain string — the original used an f-string with no placeholders)
    fallback_file = "parquet/user_text_emb.parquet"
    fallback_path = _hf_download(fallback_file)
    return fallback_path.parent


def get_logs_path() -> Path:
    """Return (and create) the project logs directory."""
    return ensure_dir(LOGS_DIR)


def get_dataset_paths(dataset: str) -> Dict[str, Path]:
    """Resolve every artifact path for *dataset*.

    Returns a mapping from artifact key to a concrete Path. Local files under
    ``data/processed/<dataset>`` win; anything missing locally is downloaded
    from the Hub. NOTE: building this dict eagerly downloads every missing
    artifact, which can be slow on first call.
    """
    dataset = dataset.lower()

    def resolve_or_download(subfolder: str, name: str) -> Path:
        # Prefer the local processed copy; fall back to the Hub file at
        # "<subfolder>/<name>" within HF_REPO.
        local = PROCESSED_DIR / dataset / name
        if local.exists():
            return local
        return _hf_download(f"{subfolder}/{name}")

    return {
        "raw": get_raw_path(dataset),
        "processed": get_processed_path(dataset),
        "cache": ensure_dir(CACHE_DIR / dataset),
        "logs": get_logs_path(),
        # JSON and config files
        "defaults": resolve_or_download("json", "defaults.json"),
        "item_ids": resolve_or_download("json", "item_ids.json"),
        "user_seq": resolve_or_download("json", "user_seq.json"),
        # Parquet files
        "item_meta_emb": resolve_or_download("parquet", "item_meta_emb.parquet"),
        "item_image_emb": resolve_or_download("parquet", "item_image_emb.parquet"),
        "item_text_emb": resolve_or_download("parquet", "item_text_emb.parquet"),
        "user_text_emb": resolve_or_download("parquet", "user_text_emb.parquet"),
        # NPY files
        "text": resolve_or_download("npy", "text.npy"),
        "image": resolve_or_download("npy", "image.npy"),
        "meta": resolve_or_download("npy", "meta.npy"),
        "cove": resolve_or_download("npy", "cove.npy"),
        # FAISS files
        "faiss_concat": resolve_or_download("faiss", "items_beauty_concat.faiss"),
        "faiss_weighted": resolve_or_download("faiss", "items_beauty_weighted.faiss"),
        # Model
        "adapter_model": resolve_or_download("model", "adapter_model.safetensors"),
        "full_model": resolve_or_download("model", "model.safetensors"),
    }