""" Artifact Management =================== Handles uploading trained models and results to HuggingFace Hub, and downloading them for evaluation or transfer learning. Supports both single-task models (MLP, CNN) and multi-task models (MTL). The unified V1 repo (ascad-v1-models) uses the following structure: - Single-byte: desync{N}/{model_type}/byte{X}/ (model.h5, results.json, rank_curve.npy) - Multi-task: desync{N}/{variant}/ (model.h5, results.json, rank_curve_byte{0..15}.npy) """ import json import logging import os from typing import Dict, List, Optional from huggingface_hub import HfApi, hf_hub_download from .constants import HF_MLP_REPO, HF_CNN_REPO, HF_MTAN_REPO, HF_V1_REPO logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Repository mapping (legacy) # --------------------------------------------------------------------------- def _get_legacy_repo_id(model_type: str) -> str: """Map model type to legacy HuggingFace repository ID.""" repos = {"mlp": HF_MLP_REPO, "cnn": HF_CNN_REPO, "mtan": HF_MTAN_REPO} if model_type not in repos: raise ValueError(f"Unknown model type: {model_type}") return repos[model_type] def _build_hf_path(model_type: str, target_byte: int, desync: int) -> str: """Build the HuggingFace path prefix for a single-task model.""" return f"desync{desync}/{model_type}/byte{target_byte}" # --------------------------------------------------------------------------- # Unified V1 upload helpers # --------------------------------------------------------------------------- def _ensure_v1_repo(api: HfApi) -> None: """Ensure the V1 repo exists.""" try: api.create_repo( repo_id=HF_V1_REPO, repo_type="model", exist_ok=True, private=False, ) except Exception as e: logger.warning("Could not create/verify repo %s: %s", HF_V1_REPO, e) def _upload_files( api: HfApi, repo_id: str, local_dir: str, hf_prefix: str, filenames: List[str], ) -> int: """Upload a list of files to a HuggingFace repo. Returns count of uploaded files.""" uploaded = 0 for filename in filenames: local_path = os.path.join(local_dir, filename) if not os.path.isfile(local_path): logger.warning("File not found, skipping: %s", local_path) continue hf_path = f"{hf_prefix}/{filename}" try: logger.info("Uploading %s -> %s/%s", local_path, repo_id, hf_path) api.upload_file( path_or_fileobj=local_path, path_in_repo=hf_path, repo_id=repo_id, repo_type="model", ) uploaded += 1 except Exception as e: logger.error("Failed to upload %s: %s", filename, e) return uploaded # --------------------------------------------------------------------------- # Single-task upload (MLP, CNN) # --------------------------------------------------------------------------- def upload_model( model_dir: str, model_type: str, target_byte: int, desync: int, repo_id: Optional[str] = None, failed: bool = False, ) -> None: """ Upload a trained single-task model directory to HuggingFace Hub. Uploads to both the unified V1 repo and the legacy repo (if no override). Args: model_dir: Local directory containing model.h5, results.json, rank_curve.npy. model_type: 'mlp' or 'cnn'. target_byte: Target byte index (0-15). desync: Desynchronization level (0, 50, or 100). repo_id: Override the default HuggingFace repo ID (skips V1 upload). """ hf_prefix = _build_hf_path(model_type, target_byte, desync) if failed: hf_prefix = f"failed/{hf_prefix}" files = ["model.h5", "results.json", "rank_curve.npy"] api = HfApi() # Always upload to V1 repo _ensure_v1_repo(api) count = _upload_files(api, HF_V1_REPO, model_dir, hf_prefix, files) logger.info( "V1 upload complete: %s byte=%d desync=%d -> %s (%d files)", model_type, target_byte, desync, HF_V1_REPO, count, ) # Also upload to legacy repo if no override if repo_id is None: legacy_repo = _get_legacy_repo_id(model_type) else: legacy_repo = repo_id count = _upload_files(api, legacy_repo, model_dir, hf_prefix, files) logger.info( "Legacy upload complete: %s byte=%d desync=%d -> %s (%d files)", model_type, target_byte, desync, legacy_repo, count, ) # --------------------------------------------------------------------------- # Multi-task upload (MTL) # --------------------------------------------------------------------------- def upload_mtan_model( model_dir: str, variant: str, desync: int, repo_id: Optional[str] = None, hf_prefix_override: Optional[str] = None, ) -> None: """ Upload a trained MTL multi-task model directory to HuggingFace Hub. Uploads to both the unified V1 repo and the legacy MTAN repo (if no override). Args: model_dir: Local directory containing model artifacts. variant: Model variant name (e.g., 'lmic_tsbn_v7b'). desync: Desynchronization level (0, 50, or 100). repo_id: Override the default HuggingFace repo ID (skips V1 upload). """ hf_prefix = hf_prefix_override if hf_prefix_override else f"desync{desync}/{variant}" api = HfApi() # Collect all files to upload files: List[str] = [] for filename in ["model.h5", "results.json"]: if os.path.isfile(os.path.join(model_dir, filename)): files.append(filename) for byte_idx in range(16): filename = f"rank_curve_byte{byte_idx}.npy" if os.path.isfile(os.path.join(model_dir, filename)): files.append(filename) if not files: logger.error("No files found in %s to upload", model_dir) return # Always upload to V1 repo _ensure_v1_repo(api) count = _upload_files(api, HF_V1_REPO, model_dir, hf_prefix, files) logger.info( "V1 MTAN upload: variant=%s desync=%d -> %s (%d/%d files)", variant, desync, HF_V1_REPO, count, len(files), ) # Also upload to legacy repo legacy_repo = repo_id if repo_id else HF_MTAN_REPO try: api.create_repo( repo_id=legacy_repo, repo_type="model", exist_ok=True, private=False, ) except Exception as e: logger.warning("Could not create/verify repo %s: %s", legacy_repo, e) count = _upload_files(api, legacy_repo, model_dir, hf_prefix, files) logger.info( "Legacy MTAN upload: variant=%s desync=%d -> %s (%d/%d files)", variant, desync, legacy_repo, count, len(files), ) # --------------------------------------------------------------------------- # Download functions # --------------------------------------------------------------------------- def download_mtan_model( variant: str, desync: int, local_dir: str, repo_id: Optional[str] = None, ) -> str: """ Download a trained MTAN model from HuggingFace Hub. Tries V1 repo first, falls back to legacy repo. Args: variant: Model variant name (e.g., 'lmic_tsbn_v7b'). desync: Desynchronization level (0, 50, or 100). local_dir: Local directory to save downloaded files. repo_id: Override the default HuggingFace repo ID. Returns: Path to the downloaded model directory. """ if repo_id is None: repo_id = HF_V1_REPO hf_prefix = hf_prefix_override if hf_prefix_override else f"desync{desync}/{variant}" os.makedirs(local_dir, exist_ok=True) files_to_download = ["model.h5", "results.json"] for byte_idx in range(16): files_to_download.append(f"rank_curve_byte{byte_idx}.npy") for filename in files_to_download: hf_path = f"{hf_prefix}/{filename}" try: downloaded = hf_hub_download( repo_id=repo_id, filename=hf_path, local_dir=local_dir, ) logger.info("Downloaded %s -> %s", hf_path, downloaded) except Exception as e: logger.warning("Could not download %s: %s", hf_path, e) return local_dir def download_model( model_type: str, target_byte: int, desync: int, local_dir: str, repo_id: Optional[str] = None, ) -> str: """ Download a trained single-task model from HuggingFace Hub. Tries V1 repo first, falls back to legacy repo. Args: model_type: 'mlp' or 'cnn'. target_byte: Target byte index (0-15). desync: Desynchronization level (0, 50, or 100). local_dir: Local directory to save downloaded files. repo_id: Override the default HuggingFace repo ID. Returns: Path to the downloaded model directory. """ if repo_id is None: repo_id = HF_V1_REPO hf_prefix = _build_hf_path(model_type, target_byte, desync) os.makedirs(local_dir, exist_ok=True) for filename in ["model.h5", "results.json", "rank_curve.npy"]: hf_path = f"{hf_prefix}/{filename}" try: downloaded = hf_hub_download( repo_id=repo_id, filename=hf_path, local_dir=local_dir, ) logger.info("Downloaded %s -> %s", hf_path, downloaded) except Exception as e: logger.warning("Could not download %s: %s", hf_path, e) return local_dir # --------------------------------------------------------------------------- # Audit / count (V1 repo) # --------------------------------------------------------------------------- def audit_repository( model_type: str, repo_id: Optional[str] = None, ) -> Dict[str, Dict[int, bool]]: """ Audit a HuggingFace model repository to check which models exist. Returns: Nested dict: {desync_key: {byte_idx: has_model}}. Example: {"desync0": {0: True, 1: False, ...}, ...} """ if repo_id is None: repo_id = HF_V1_REPO api = HfApi() try: files = api.list_repo_files(repo_id=repo_id, repo_type="model") except Exception as e: logger.error("Could not list files in %s: %s", repo_id, e) return {} result = {} for desync in [0, 50, 100]: desync_key = f"desync{desync}" result[desync_key] = {} for byte_idx in range(16): hf_path = f"{_build_hf_path(model_type, byte_idx, desync)}/model.h5" result[desync_key][byte_idx] = hf_path in files return result def count_models( model_type: str, repo_id: Optional[str] = None, ) -> int: """Count the total number of models uploaded for a given type.""" audit = audit_repository(model_type, repo_id) return sum( 1 for desync_dict in audit.values() for exists in desync_dict.values() if exists )