# warbler-cda / warbler_cda/remote_pack_loader.py
# Bellok
# feat: Add support for remote pack loading with environment configuration and enhanced metadata handling
# 1c68bde
"""Download hosted Warbler packs selectively from a Hugging Face dataset repo."""
import fnmatch
import json
import logging
import math
import os
from pathlib import Path
from typing import Dict, List, Optional
from huggingface_hub import HfApi, hf_hub_download, snapshot_download
logger = logging.getLogger(__name__)
class RemotePackLoader:
    """Download only the pack files needed for hosted-safe startup.

    Packs are expected under ``packs/<pack_name>/...`` in a Hugging Face
    dataset repo. Pack names are filtered with ``fnmatch`` include/exclude
    patterns, and for packs whose ``package.json`` sets ``chunked`` the number
    of downloaded ``.jsonl`` files is capped using ``max_documents_per_pack``.
    """

    def __init__(
        self,
        repo_id: str,
        cache_dir: Optional[str] = None,
        include_packs: Optional[List[str]] = None,
        exclude_packs: Optional[List[str]] = None,
        max_documents_per_pack: Optional[int] = None,
        token: Optional[str] = None,
    ):
        """Store the selection settings and create the Hub API client.

        Args:
            repo_id: Dataset repo that hosts the packs.
            cache_dir: Download cache directory. Falls back to the
                ``HF_PACK_CACHE`` environment variable, then ``.hf_pack_cache``.
            include_packs: ``fnmatch`` patterns; when non-empty, a pack must
                match at least one of them.
            exclude_packs: ``fnmatch`` patterns; a pack matching any is skipped.
            max_documents_per_pack: Document budget used to limit how many
                chunk files of a chunked pack are selected. ``None`` (or 0)
                means no limit.
            token: Hub access token used for listing and downloading.
        """
        self.repo_id = repo_id
        self.cache_dir = cache_dir or os.getenv("HF_PACK_CACHE", ".hf_pack_cache")
        self.include_packs = include_packs or []
        self.exclude_packs = exclude_packs or []
        self.max_documents_per_pack = max_documents_per_pack
        self.local_dir: Optional[Path] = None
        # Kept so that downloads authenticate with the same token as the API
        # client instead of silently ignoring the constructor argument.
        self.token = token
        self.api = HfApi(token=token)

    @classmethod
    def from_environment(cls, repo_id: str) -> "RemotePackLoader":
        """Create a remote loader configured with the same hosted-safe defaults as PackLoader.

        Reads ``WARBLER_INCLUDE_PACKS``, ``WARBLER_EXCLUDE_PACKS``,
        ``WARBLER_MAX_DOCUMENTS_PER_PACK`` and ``HF_TOKEN``. In a hosted
        environment, an empty exclude list defaults to
        ``["warbler-pack-hf-tinystories"]`` and an unset document limit
        defaults to 5000.
        """
        include_packs = cls._split_csv_env("WARBLER_INCLUDE_PACKS")
        exclude_packs = cls._split_csv_env("WARBLER_EXCLUDE_PACKS")
        max_documents_per_pack = cls._parse_int_env("WARBLER_MAX_DOCUMENTS_PER_PACK")
        if cls._is_hosted_environment():
            if not exclude_packs:
                exclude_packs = ["warbler-pack-hf-tinystories"]
            if max_documents_per_pack is None:
                max_documents_per_pack = 5000
        return cls(
            repo_id=repo_id,
            include_packs=include_packs,
            exclude_packs=exclude_packs,
            max_documents_per_pack=max_documents_per_pack,
            token=os.getenv("HF_TOKEN"),
        )

    @staticmethod
    def _is_hosted_environment() -> bool:
        """Return True when ``WARBLER_HOSTED_MODE`` is truthy or a Space id variable is set."""
        hosted_flag = os.getenv("WARBLER_HOSTED_MODE", "").lower()
        return hosted_flag in {"1", "true", "yes", "on"} or bool(
            os.getenv("SPACE_ID") or os.getenv("HF_SPACE_ID")
        )

    @staticmethod
    def _split_csv_env(name: str) -> List[str]:
        """Return the comma-separated entries of environment variable *name*, stripped, blanks dropped."""
        raw_value = os.getenv(name, "")
        return [part.strip() for part in raw_value.split(",") if part.strip()]

    @staticmethod
    def _parse_int_env(name: str) -> Optional[int]:
        """Return environment variable *name* as an int, or None when unset, empty or invalid."""
        raw_value = os.getenv(name)
        if raw_value in (None, ""):
            return None
        try:
            return int(raw_value)
        except ValueError:
            logger.warning("Ignoring invalid integer for %s: %s", name, raw_value)
            return None

    def _auth_token(self) -> Optional[str]:
        """Return the constructor token, falling back to the ``HF_TOKEN`` environment variable."""
        if self.token is not None:
            return self.token
        return os.getenv("HF_TOKEN")

    def _should_load_pack(self, pack_name: str) -> bool:
        """Return True when *pack_name* passes the include filter and is not excluded."""
        if self.include_packs and not any(
            fnmatch.fnmatch(pack_name, pattern) for pattern in self.include_packs
        ):
            return False
        # Exclusion always wins over inclusion.
        return not any(fnmatch.fnmatch(pack_name, pattern) for pattern in self.exclude_packs)

    def _list_repo_files(self) -> List[str]:
        """Return every file path in the dataset repo."""
        return self.api.list_repo_files(repo_id=self.repo_id, repo_type="dataset")

    def _load_pack_metadata(self, pack_name: str) -> Dict[str, object]:
        """Download and parse ``packs/<pack_name>/package.json``.

        This is best-effort: any download, read, decode or parse failure is
        logged and yields an empty dict, as does a JSON document that is not
        an object.
        """
        metadata_path = f"packs/{pack_name}/package.json"
        try:
            downloaded_path = hf_hub_download(
                repo_id=self.repo_id,
                repo_type="dataset",
                filename=metadata_path,
                cache_dir=self.cache_dir,
                token=self._auth_token(),
            )
        except Exception as exc:  # deliberate best-effort boundary around the network call
            logger.warning("Could not download remote metadata for %s: %s", pack_name, exc)
            return {}
        try:
            metadata = json.loads(Path(downloaded_path).read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
            logger.warning("Failed to parse remote metadata for %s", pack_name)
            return {}
        if not isinstance(metadata, dict):
            # Callers use ``.get``; anything but a JSON object is unusable.
            logger.warning("Ignoring non-object remote metadata for %s", pack_name)
            return {}
        return metadata

    def _group_pack_files(self, repo_files: List[str]) -> Dict[str, List[str]]:
        """Group ``packs/<name>/...`` repo paths by pack name, keeping only selected packs."""
        pack_files: Dict[str, List[str]] = {}
        for repo_file in repo_files:
            parts = Path(repo_file).parts
            if len(parts) < 3 or parts[0] != "packs":
                continue
            pack_name = parts[1]
            if self._should_load_pack(pack_name):
                pack_files.setdefault(pack_name, []).append(repo_file)
        return pack_files

    def _select_jsonl_files(
        self,
        pack_name: str,
        jsonl_files: List[str],
        metadata: Dict[str, object],
    ) -> List[str]:
        """Pick which ``.jsonl`` files of one pack to download."""
        if not metadata.get("chunked"):
            # Unchunked pack: prefer the canonical "<pack>/<pack>.jsonl",
            # otherwise fall back to the first .jsonl file found.
            preferred_jsonl = f"packs/{pack_name}/{pack_name}.jsonl"
            if preferred_jsonl in jsonl_files:
                return [preferred_jsonl]
            return jsonl_files[:1]

        if not self.max_documents_per_pack:
            return list(jsonl_files)

        docs_per_chunk = metadata.get("docs_per_chunk")
        if isinstance(docs_per_chunk, int) and docs_per_chunk > 0:
            chunk_limit = max(1, math.ceil(self.max_documents_per_pack / docs_per_chunk))
        else:
            # Chunk size unknown: a limit was requested, so take one chunk only.
            chunk_limit = 1
        return jsonl_files[:chunk_limit]

    def build_allow_patterns(self) -> List[str]:
        """Build a minimal file set for the selected remote packs.

        For each selected pack (in sorted order) the result contains, when
        present in the repo, ``package.json`` and ``pack/templates.json``,
        followed by the chosen ``.jsonl`` data files.
        """
        pack_files = self._group_pack_files(self._list_repo_files())
        allow_patterns: List[str] = []
        for pack_name in sorted(pack_files):
            files = sorted(pack_files[pack_name])
            package_json = f"packs/{pack_name}/package.json"
            templates_path = f"packs/{pack_name}/pack/templates.json"

            metadata: Dict[str, object] = {}
            if package_json in files:
                allow_patterns.append(package_json)
                # Only fetch metadata the repo listing says exists; this
                # avoids a network round trip that is guaranteed to fail.
                metadata = self._load_pack_metadata(pack_name)
            if templates_path in files:
                allow_patterns.append(templates_path)

            jsonl_files = [path for path in files if path.endswith(".jsonl")]
            allow_patterns.extend(self._select_jsonl_files(pack_name, jsonl_files, metadata))
        return allow_patterns

    def fetch_packs(self) -> Path:
        """Download only the selected pack files from the remote dataset repo.

        Returns:
            The ``packs`` directory inside the downloaded snapshot.

        Raises:
            RuntimeError: If no remote file was selected.
        """
        allow_patterns = self.build_allow_patterns()
        if not allow_patterns:
            raise RuntimeError(f"No remote pack files selected for repo {self.repo_id}")
        self.local_dir = Path(
            snapshot_download(
                repo_id=self.repo_id,
                repo_type="dataset",
                cache_dir=self.cache_dir,
                allow_patterns=allow_patterns,
                local_files_only=False,
                token=self._auth_token(),
            )
        )
        return self.get_local_packs_dir()

    def get_local_packs_dir(self) -> Path:
        """Return the ``packs`` directory of the last snapshot.

        Raises:
            RuntimeError: If ``fetch_packs()`` has not been called yet.
        """
        if self.local_dir is None:
            raise RuntimeError("fetch_packs() must be called first")
        return self.local_dir / "packs"