Spaces:
Sleeping
Sleeping
Bellok
feat: Add support for remote pack loading with environment configuration and enhanced metadata handling
1c68bde | """Download hosted Warbler packs selectively from a Hugging Face dataset repo.""" | |
import fnmatch
import json
import logging
import math
import os
import re
from pathlib import Path
from typing import Dict, List, Optional

from huggingface_hub import HfApi, hf_hub_download, snapshot_download
# Module-level logger named after this module; this module attaches no
# handlers of its own.
logger = logging.getLogger(name=__name__)
class RemotePackLoader:
    """Download only the pack files needed for hosted-safe startup.

    Packs live under ``packs/<pack_name>/`` in a Hugging Face *dataset*
    repository. The loader lists the repo, filters pack names with glob
    patterns, picks a minimal set of files per pack and downloads just those
    files with ``snapshot_download``.

    Args:
        repo_id: Dataset repository id on the Hugging Face Hub.
        cache_dir: Download cache directory. Defaults to ``$HF_PACK_CACHE``,
            or ``.hf_pack_cache`` when that is unset.
        include_packs: Glob patterns. When non-empty, only matching pack names
            are loaded.
        exclude_packs: Glob patterns. Matching pack names are always skipped.
        max_documents_per_pack: Limits how many chunk files of a *chunked* pack
            are downloaded (rounded up to whole chunks). ``None`` or ``0``
            disables the limit.
        token: Hub access token. Defaults to ``$HF_TOKEN`` when omitted. It is
            used for listing and for every download.
    """

    def __init__(
        self,
        repo_id: str,
        cache_dir: Optional[str] = None,
        include_packs: Optional[List[str]] = None,
        exclude_packs: Optional[List[str]] = None,
        max_documents_per_pack: Optional[int] = None,
        token: Optional[str] = None,
    ):
        self.repo_id = repo_id
        self.cache_dir = cache_dir or os.getenv("HF_PACK_CACHE", ".hf_pack_cache")
        self.include_packs = include_packs or []
        self.exclude_packs = exclude_packs or []
        self.max_documents_per_pack = max_documents_per_pack
        # Resolve the token once so repo listing, metadata download and the
        # snapshot download all authenticate identically. Previously an
        # explicit ``token`` reached HfApi but was ignored by the downloads.
        self.token = token if token is not None else os.getenv("HF_TOKEN")
        self.local_dir: Optional[Path] = None
        self.api = HfApi(token=self.token)

    @classmethod
    def from_environment(cls, repo_id: str) -> "RemotePackLoader":
        """Create a remote loader configured with the same hosted-safe defaults as PackLoader.

        Reads ``WARBLER_INCLUDE_PACKS`` / ``WARBLER_EXCLUDE_PACKS``
        (comma-separated glob patterns), ``WARBLER_MAX_DOCUMENTS_PER_PACK``
        and ``HF_TOKEN``. In a hosted environment, it excludes
        ``warbler-pack-hf-tinystories`` unless an exclude list was configured.
        It also applies a 5000-document cap unless a cap was configured.
        """
        include_packs = cls._split_csv_env("WARBLER_INCLUDE_PACKS")
        exclude_packs = cls._split_csv_env("WARBLER_EXCLUDE_PACKS")
        max_documents_per_pack = cls._parse_int_env("WARBLER_MAX_DOCUMENTS_PER_PACK")
        if cls._is_hosted_environment():
            # Hosted defaults only fill in settings the operator left unset.
            if not exclude_packs:
                exclude_packs = ["warbler-pack-hf-tinystories"]
            if max_documents_per_pack is None:
                max_documents_per_pack = 5000
        return cls(
            repo_id=repo_id,
            include_packs=include_packs,
            exclude_packs=exclude_packs,
            max_documents_per_pack=max_documents_per_pack,
            token=os.getenv("HF_TOKEN"),
        )

    @staticmethod
    def _is_hosted_environment() -> bool:
        """Return True when running in a hosted deployment.

        This holds when ``WARBLER_HOSTED_MODE`` is a truthy word
        (``1/true/yes/on``, case- and whitespace-insensitive), or when
        ``SPACE_ID`` / ``HF_SPACE_ID`` is non-empty.
        """
        hosted_flag = os.getenv("WARBLER_HOSTED_MODE", "").strip().lower()
        if hosted_flag in {"1", "true", "yes", "on"}:
            return True
        return bool(os.getenv("SPACE_ID") or os.getenv("HF_SPACE_ID"))

    @staticmethod
    def _split_csv_env(name: str) -> List[str]:
        """Return the non-empty, stripped comma-separated items of env var *name*."""
        raw_value = os.getenv(name, "")
        return [part.strip() for part in raw_value.split(",") if part.strip()]

    @staticmethod
    def _parse_int_env(name: str) -> Optional[int]:
        """Return env var *name* as an int.

        Returns None when the variable is unset or blank. Also returns None,
        after logging a warning, when the value is not a valid integer.
        """
        raw_value = os.getenv(name)
        if raw_value is None or not raw_value.strip():
            return None
        try:
            return int(raw_value)
        except ValueError:
            logger.warning("Ignoring invalid integer for %s: %s", name, raw_value)
            return None

    @staticmethod
    def _natural_sort_key(path: str) -> List[object]:
        """Sort key that orders embedded numbers numerically (``chunk_2`` < ``chunk_10``)."""
        # re.split with a capturing group alternates text and digit-run pieces,
        # so odd indices are always digit runs. Two keys therefore never compare
        # a str against an int at the same position.
        pieces = re.split(r"(\d+)", path)
        return [int(piece) if index % 2 else piece for index, piece in enumerate(pieces)]

    def _should_load_pack(self, pack_name: str) -> bool:
        """Return True if *pack_name* passes the include/exclude glob filters.

        An empty include list admits every pack. Exclusion always wins.
        """
        if self.include_packs and not any(
            fnmatch.fnmatch(pack_name, pattern) for pattern in self.include_packs
        ):
            return False
        return not any(fnmatch.fnmatch(pack_name, pattern) for pattern in self.exclude_packs)

    def _list_repo_files(self) -> List[str]:
        """List every file path in the dataset repository."""
        return self.api.list_repo_files(repo_id=self.repo_id, repo_type="dataset")

    def _load_pack_metadata(self, pack_name: str) -> Dict[str, object]:
        """Download and parse ``packs/<pack_name>/package.json``.

        This is best-effort. It returns ``{}`` and logs a warning when the file
        cannot be downloaded, is not valid UTF-8 JSON, or is not a JSON object.
        """
        metadata_path = f"packs/{pack_name}/package.json"
        try:
            downloaded_path = hf_hub_download(
                repo_id=self.repo_id,
                repo_type="dataset",
                filename=metadata_path,
                cache_dir=self.cache_dir,
                token=self.token,
            )
        except Exception:  # metadata is optional; keep startup going, but say why
            logger.warning(
                "Could not download remote metadata for %s", pack_name, exc_info=True
            )
            return {}
        try:
            metadata = json.loads(Path(downloaded_path).read_text(encoding="utf-8"))
        except (OSError, ValueError):  # ValueError covers JSON and UTF-8 decode errors
            logger.warning("Failed to parse remote metadata for %s", pack_name)
            return {}
        if not isinstance(metadata, dict):
            logger.warning("Ignoring non-object remote metadata for %s", pack_name)
            return {}
        return metadata

    def _select_jsonl_files(
        self, pack_name: str, files: List[str], metadata: Dict[str, object]
    ) -> List[str]:
        """Pick which ``.jsonl`` files of one pack to download.

        Chunked packs (``metadata["chunked"]``) get their leading chunks in
        natural order, limited by ``max_documents_per_pack``. Other packs get
        ``packs/<pack>/<pack>.jsonl`` if present, otherwise the first JSONL file.
        """
        jsonl_files = sorted(
            (path for path in files if path.endswith(".jsonl")),
            key=self._natural_sort_key,
        )
        if not metadata.get("chunked"):
            preferred_jsonl = f"packs/{pack_name}/{pack_name}.jsonl"
            if preferred_jsonl in jsonl_files:
                return [preferred_jsonl]
            return jsonl_files[:1]

        if not self.max_documents_per_pack:
            return jsonl_files
        docs_per_chunk = metadata.get("docs_per_chunk")
        if isinstance(docs_per_chunk, int) and docs_per_chunk > 0:
            chunk_limit = max(1, math.ceil(self.max_documents_per_pack / docs_per_chunk))
        else:
            # Unknown chunk size: take a single chunk rather than the whole pack.
            chunk_limit = 1
        return jsonl_files[:chunk_limit]

    def build_allow_patterns(self) -> List[str]:
        """Build a minimal file set for the selected remote packs.

        For each pack under ``packs/`` that passes the filters, this selects:
        - ``package.json`` and ``pack/templates.json``, when present;
        - the JSONL files chosen by ``_select_jsonl_files``.

        Returns:
            Repo-relative file paths, grouped by pack in sorted pack-name
            order. They are intended for ``snapshot_download(allow_patterns=...)``.
        """
        pack_files: Dict[str, List[str]] = {}
        for repo_file in self._list_repo_files():
            parts = Path(repo_file).parts
            # Keep only paths shaped like packs/<pack_name>/<...>.
            if len(parts) < 3 or parts[0] != "packs":
                continue
            pack_name = parts[1]
            if self._should_load_pack(pack_name):
                pack_files.setdefault(pack_name, []).append(repo_file)

        allow_patterns: List[str] = []
        for pack_name in sorted(pack_files):
            files = sorted(pack_files[pack_name])
            metadata: Dict[str, object] = {}
            package_json = f"packs/{pack_name}/package.json"
            if package_json in files:
                allow_patterns.append(package_json)
                # Only fetch metadata the repo listing says exists.
                metadata = self._load_pack_metadata(pack_name)

            templates_path = f"packs/{pack_name}/pack/templates.json"
            if templates_path in files:
                allow_patterns.append(templates_path)

            allow_patterns.extend(self._select_jsonl_files(pack_name, files, metadata))
        return allow_patterns

    def fetch_packs(self) -> Path:
        """Download only the selected pack files from the remote dataset repo.

        Returns:
            The ``packs`` directory inside the downloaded snapshot.

        Raises:
            RuntimeError: If no pack files matched the current filters.
        """
        allow_patterns = self.build_allow_patterns()
        if not allow_patterns:
            raise RuntimeError(f"No remote pack files selected for repo {self.repo_id}")
        self.local_dir = Path(
            snapshot_download(
                repo_id=self.repo_id,
                repo_type="dataset",
                cache_dir=self.cache_dir,
                allow_patterns=allow_patterns,
                local_files_only=False,
                token=self.token,
            )
        )
        return self.get_local_packs_dir()

    def get_local_packs_dir(self) -> Path:
        """Return the local ``packs`` directory of the last ``fetch_packs()`` call.

        Raises:
            RuntimeError: If ``fetch_packs()`` has not been called yet.
        """
        if self.local_dir is None:
            raise RuntimeError("fetch_packs() must be called first")
        return self.local_dir / "packs"