"""Download hosted Warbler packs selectively from a Hugging Face dataset repo."""

import fnmatch
import json
import logging
import math
import os
from pathlib import Path
from typing import Dict, List, Optional

from huggingface_hub import HfApi, hf_hub_download, snapshot_download

logger = logging.getLogger(__name__)


class RemotePackLoader:
    """Download only the pack files needed for hosted-safe startup.

    Packs are looked up under ``packs/<pack_name>/`` in a dataset repo. For
    each selected pack the loader picks ``package.json`` and
    ``pack/templates.json`` (when listed) plus a bounded subset of its
    ``.jsonl`` files, then downloads exactly those paths with
    ``snapshot_download(allow_patterns=...)``.
    """

    def __init__(
        self,
        repo_id: str,
        cache_dir: Optional[str] = None,
        include_packs: Optional[List[str]] = None,
        exclude_packs: Optional[List[str]] = None,
        max_documents_per_pack: Optional[int] = None,
        token: Optional[str] = None,
    ):
        """Configure the loader.

        Args:
            repo_id: Hugging Face dataset repo id that holds the packs.
            cache_dir: Download cache directory. Falls back to the
                ``HF_PACK_CACHE`` environment variable, then ``.hf_pack_cache``.
            include_packs: fnmatch patterns; when non-empty, only pack names
                matching at least one pattern are loaded.
            exclude_packs: fnmatch patterns; pack names matching any pattern
                are skipped (checked after ``include_packs``).
            max_documents_per_pack: Document budget used to cap how many chunk
                files of a chunked pack are downloaded. ``None`` or ``0``
                disables the cap.
            token: Hugging Face token used for the repo listing AND every
                download. Falls back to the ``HF_TOKEN`` environment variable.
        """
        self.repo_id = repo_id
        self.cache_dir = cache_dir or os.getenv("HF_PACK_CACHE", ".hf_pack_cache")
        self.include_packs = include_packs or []
        self.exclude_packs = exclude_packs or []
        self.max_documents_per_pack = max_documents_per_pack
        # Keep one token for all hub calls; previously downloads re-read
        # HF_TOKEN and silently ignored an explicitly passed token.
        self.token: Optional[str] = token or os.getenv("HF_TOKEN")
        self.local_dir: Optional[Path] = None
        self.api = HfApi(token=self.token)

    @classmethod
    def from_environment(cls, repo_id: str) -> "RemotePackLoader":
        """Create a remote loader configured with the same hosted-safe defaults as PackLoader.

        Reads ``WARBLER_INCLUDE_PACKS`` / ``WARBLER_EXCLUDE_PACKS`` (comma
        separated) and ``WARBLER_MAX_DOCUMENTS_PER_PACK``. In a hosted
        environment, unset values default to excluding
        ``warbler-pack-hf-tinystories`` and a 5000-document budget.
        """
        include_packs = cls._split_csv_env("WARBLER_INCLUDE_PACKS")
        exclude_packs = cls._split_csv_env("WARBLER_EXCLUDE_PACKS")
        max_documents_per_pack = cls._parse_int_env("WARBLER_MAX_DOCUMENTS_PER_PACK")

        if cls._is_hosted_environment():
            if not exclude_packs:
                exclude_packs = ["warbler-pack-hf-tinystories"]
            if max_documents_per_pack is None:
                max_documents_per_pack = 5000

        return cls(
            repo_id=repo_id,
            include_packs=include_packs,
            exclude_packs=exclude_packs,
            max_documents_per_pack=max_documents_per_pack,
            token=os.getenv("HF_TOKEN"),
        )

    @staticmethod
    def _is_hosted_environment() -> bool:
        """Return True when WARBLER_HOSTED_MODE is truthy or a Space id is set."""
        hosted_flag = os.getenv("WARBLER_HOSTED_MODE", "").strip().lower()
        return hosted_flag in {"1", "true", "yes", "on"} or bool(
            os.getenv("SPACE_ID") or os.getenv("HF_SPACE_ID")
        )

    @staticmethod
    def _split_csv_env(name: str) -> List[str]:
        """Split a comma-separated environment variable into trimmed, non-empty parts."""
        raw_value = os.getenv(name, "")
        return [part.strip() for part in raw_value.split(",") if part.strip()]

    @staticmethod
    def _parse_int_env(name: str) -> Optional[int]:
        """Parse an integer environment variable; return None if unset, blank or invalid."""
        raw_value = os.getenv(name)
        if raw_value is None or not raw_value.strip():
            return None
        try:
            return int(raw_value)
        except ValueError:
            logger.warning("Ignoring invalid integer for %s: %s", name, raw_value)
            return None

    def _should_load_pack(self, pack_name: str) -> bool:
        """Apply the include (allow-list) then exclude (deny-list) fnmatch filters."""
        if self.include_packs and not any(
            fnmatch.fnmatch(pack_name, pattern) for pattern in self.include_packs
        ):
            return False
        return not any(fnmatch.fnmatch(pack_name, pattern) for pattern in self.exclude_packs)

    def _list_repo_files(self) -> List[str]:
        """List every file path in the remote dataset repo."""
        return self.api.list_repo_files(repo_id=self.repo_id, repo_type="dataset")

    def _load_pack_metadata(self, pack_name: str) -> Dict[str, object]:
        """Download and parse ``packs/<pack_name>/package.json``.

        Best-effort: any download or parse failure, or a JSON document that is
        not an object, is logged and yields an empty dict.
        """
        metadata_path = f"packs/{pack_name}/package.json"
        try:
            downloaded_path = hf_hub_download(
                repo_id=self.repo_id,
                repo_type="dataset",
                filename=metadata_path,
                cache_dir=self.cache_dir,
                token=self.token,
            )
        except Exception as exc:  # best-effort: missing metadata means defaults
            logger.warning("Failed to download remote metadata for %s: %s", pack_name, exc)
            return {}

        try:
            metadata = json.loads(Path(downloaded_path).read_text(encoding="utf-8"))
        except (OSError, ValueError):  # ValueError covers JSON and Unicode decode errors
            logger.warning("Failed to parse remote metadata for %s", pack_name)
            return {}

        if not isinstance(metadata, dict):
            # A JSON array/scalar would otherwise crash later `.get` calls.
            logger.warning("Ignoring non-object remote metadata for %s", pack_name)
            return {}
        return metadata

    def _group_pack_files(self, repo_files: List[str]) -> Dict[str, List[str]]:
        """Group ``packs/<pack>/...`` repo paths by pack name, keeping only selected packs."""
        pack_files: Dict[str, List[str]] = {}
        for repo_file in repo_files:
            parts = Path(repo_file).parts
            # Require at least packs/<pack>/<file>; skip files outside packs/.
            if len(parts) < 3 or parts[0] != "packs":
                continue
            pack_name = parts[1]
            if self._should_load_pack(pack_name):
                pack_files.setdefault(pack_name, []).append(repo_file)
        return pack_files

    def _select_jsonl_files(
        self,
        pack_name: str,
        jsonl_files: List[str],
        metadata: Dict[str, object],
    ) -> List[str]:
        """Choose which of a pack's (sorted) ``.jsonl`` files to download.

        Non-chunked packs get ``packs/<pack>/<pack>.jsonl`` if listed, else the
        first ``.jsonl``. Chunked packs get every chunk when there is no
        document budget; otherwise ``ceil(budget / docs_per_chunk)`` chunks
        (at least one), or a single chunk when ``docs_per_chunk`` is unusable.
        """
        if not metadata.get("chunked"):
            preferred_jsonl = f"packs/{pack_name}/{pack_name}.jsonl"
            return [preferred_jsonl] if preferred_jsonl in jsonl_files else jsonl_files[:1]

        if not self.max_documents_per_pack:
            return jsonl_files

        docs_per_chunk = metadata.get("docs_per_chunk")
        # bool is an int subclass; `true` is not a meaningful chunk size.
        if (
            isinstance(docs_per_chunk, int)
            and not isinstance(docs_per_chunk, bool)
            and docs_per_chunk > 0
        ):
            chunk_limit = max(1, math.ceil(self.max_documents_per_pack / docs_per_chunk))
        else:
            chunk_limit = 1
        return jsonl_files[:chunk_limit]

    def build_allow_patterns(self) -> List[str]:
        """Build a minimal file set for the selected remote packs.

        Returns repo-relative paths, grouped by pack in sorted pack order:
        ``package.json`` and ``pack/templates.json`` when listed, followed by
        the selected ``.jsonl`` files.
        """
        pack_files = self._group_pack_files(self._list_repo_files())

        allow_patterns: List[str] = []
        for pack_name in sorted(pack_files):
            files = sorted(pack_files[pack_name])
            listed = set(files)
            package_json = f"packs/{pack_name}/package.json"
            templates_path = f"packs/{pack_name}/pack/templates.json"

            # Only hit the network for metadata that the listing says exists.
            if package_json in listed:
                metadata = self._load_pack_metadata(pack_name)
                allow_patterns.append(package_json)
            else:
                metadata = {}

            if templates_path in listed:
                allow_patterns.append(templates_path)

            jsonl_files = [path for path in files if path.endswith(".jsonl")]
            allow_patterns.extend(self._select_jsonl_files(pack_name, jsonl_files, metadata))

        return allow_patterns

    def fetch_packs(self) -> Path:
        """Download only the selected pack files from the remote dataset repo.

        Returns:
            The ``packs`` directory inside the downloaded snapshot.

        Raises:
            RuntimeError: If no remote pack files were selected.
        """
        allow_patterns = self.build_allow_patterns()
        if not allow_patterns:
            raise RuntimeError(f"No remote pack files selected for repo {self.repo_id}")

        self.local_dir = Path(
            snapshot_download(
                repo_id=self.repo_id,
                repo_type="dataset",
                cache_dir=self.cache_dir,
                allow_patterns=allow_patterns,
                local_files_only=False,
                token=self.token,
            )
        )
        return self.local_dir / "packs"

    def get_local_packs_dir(self) -> Path:
        """Return the ``packs`` directory of the last snapshot fetched by fetch_packs()."""
        if self.local_dir is None:
            raise RuntimeError("fetch_packs() must be called first")
        return self.local_dir / "packs"