Section 43-45: dataset.py per-trace shift fix, orchestrator DELETE FK fix, queue rebuild, clean MLP/CNN resubmission
8c414b1 | """ | |
| Artifact Management | |
| =================== | |
| Handles uploading trained models and results to HuggingFace Hub, | |
| and downloading them for evaluation or transfer learning. | |
| Supports both single-task models (MLP, CNN) and multi-task models (MTL). | |
| The unified V1 repo (ascad-v1-models) uses the following structure: | |
| - Single-byte: desync{N}/{model_type}/byte{X}/ (model.h5, results.json, rank_curve.npy) | |
| - Multi-task: desync{N}/{variant}/ (model.h5, results.json, rank_curve_byte{0..15}.npy) | |
| """ | |
| import json | |
| import logging | |
| import os | |
| from typing import Dict, List, Optional | |
| from huggingface_hub import HfApi, hf_hub_download | |
| from .constants import HF_MLP_REPO, HF_CNN_REPO, HF_MTAN_REPO, HF_V1_REPO | |
| logger = logging.getLogger(__name__) | |
| # --------------------------------------------------------------------------- | |
| # Repository mapping (legacy) | |
| # --------------------------------------------------------------------------- | |
def _get_legacy_repo_id(model_type: str) -> str:
    """Return the legacy HuggingFace repo ID registered for *model_type*.

    Args:
        model_type: One of 'mlp', 'cnn' or 'mtan'.

    Returns:
        The matching legacy repository constant.

    Raises:
        ValueError: If *model_type* is not one of the known keys.
    """
    legacy_repos = {
        "mlp": HF_MLP_REPO,
        "cnn": HF_CNN_REPO,
        "mtan": HF_MTAN_REPO,
    }
    if model_type in legacy_repos:
        return legacy_repos[model_type]
    raise ValueError(f"Unknown model type: {model_type}")
| def _build_hf_path(model_type: str, target_byte: int, desync: int) -> str: | |
| """Build the HuggingFace path prefix for a single-task model.""" | |
| return f"desync{desync}/{model_type}/byte{target_byte}" | |
| # --------------------------------------------------------------------------- | |
| # Unified V1 upload helpers | |
| # --------------------------------------------------------------------------- | |
def _ensure_v1_repo(api: HfApi) -> None:
    """Ask the Hub to create the unified V1 model repo, tolerating failure.

    The request is made with ``exist_ok=True`` and ``private=False``. Any
    exception raised by the call is logged as a warning and swallowed so
    that callers can still attempt their uploads.
    """
    repo_settings = {
        "repo_id": HF_V1_REPO,
        "repo_type": "model",
        "exist_ok": True,
        "private": False,
    }
    try:
        api.create_repo(**repo_settings)
    except Exception as err:
        # Best effort: the repo may already exist or be temporarily unreachable.
        logger.warning("Could not create/verify repo %s: %s", HF_V1_REPO, err)
| def _upload_files( | |
| api: HfApi, | |
| repo_id: str, | |
| local_dir: str, | |
| hf_prefix: str, | |
| filenames: List[str], | |
| ) -> int: | |
| """Upload a list of files to a HuggingFace repo. Returns count of uploaded files.""" | |
| uploaded = 0 | |
| for filename in filenames: | |
| local_path = os.path.join(local_dir, filename) | |
| if not os.path.isfile(local_path): | |
| logger.warning("File not found, skipping: %s", local_path) | |
| continue | |
| hf_path = f"{hf_prefix}/{filename}" | |
| try: | |
| logger.info("Uploading %s -> %s/%s", local_path, repo_id, hf_path) | |
| api.upload_file( | |
| path_or_fileobj=local_path, | |
| path_in_repo=hf_path, | |
| repo_id=repo_id, | |
| repo_type="model", | |
| ) | |
| uploaded += 1 | |
| except Exception as e: | |
| logger.error("Failed to upload %s: %s", filename, e) | |
| return uploaded | |
| # --------------------------------------------------------------------------- | |
| # Single-task upload (MLP, CNN) | |
| # --------------------------------------------------------------------------- | |
def upload_model(
    model_dir: str,
    model_type: str,
    target_byte: int,
    desync: int,
    repo_id: Optional[str] = None,
    failed: bool = False,
) -> None:
    """
    Upload a trained single-task model directory to HuggingFace Hub.

    The artifacts are always uploaded to the unified V1 repo first, then to a
    second repo: *repo_id* when given, otherwise the legacy repo registered
    for *model_type*.

    Args:
        model_dir: Local directory containing model.h5, results.json, rank_curve.npy.
        model_type: 'mlp' or 'cnn'.
        target_byte: Target byte index (0-15).
        desync: Desynchronization level (0, 50, or 100).
        repo_id: Repo used for the second upload instead of the legacy repo.
            It does not affect the V1 upload.
        failed: When true, the destination path is nested under ``failed/``.

    Raises:
        ValueError: If *repo_id* is None and *model_type* has no legacy repo.
            This is raised after the V1 upload has already been attempted.
    """
    artifact_names = ["model.h5", "results.json", "rank_curve.npy"]
    base_prefix = _build_hf_path(model_type, target_byte, desync)
    hf_prefix = f"failed/{base_prefix}" if failed else base_prefix
    hub = HfApi()

    # First destination: the unified V1 repo, regardless of repo_id.
    _ensure_v1_repo(hub)
    v1_count = _upload_files(hub, HF_V1_REPO, model_dir, hf_prefix, artifact_names)
    logger.info(
        "V1 upload complete: %s byte=%d desync=%d -> %s (%d files)",
        model_type, target_byte, desync, HF_V1_REPO, v1_count,
    )

    # Second destination: explicit override, else the per-type legacy repo.
    secondary_repo = _get_legacy_repo_id(model_type) if repo_id is None else repo_id
    secondary_count = _upload_files(
        hub, secondary_repo, model_dir, hf_prefix, artifact_names
    )
    logger.info(
        "Legacy upload complete: %s byte=%d desync=%d -> %s (%d files)",
        model_type, target_byte, desync, secondary_repo, secondary_count,
    )
| # --------------------------------------------------------------------------- | |
| # Multi-task upload (MTL) | |
| # --------------------------------------------------------------------------- | |
def upload_mtan_model(
    model_dir: str,
    variant: str,
    desync: int,
    repo_id: Optional[str] = None,
    hf_prefix_override: Optional[str] = None,
) -> None:
    """
    Upload a trained MTL multi-task model directory to HuggingFace Hub.

    Looks for model.h5, results.json and rank_curve_byte{0..15}.npy in
    *model_dir* and uploads whichever of them exist. The files go to the
    unified V1 repo first and then to a second repo (*repo_id* when given,
    otherwise the legacy MTAN repo). If none of the expected files exists,
    an error is logged and nothing is uploaded.

    Args:
        model_dir: Local directory containing model artifacts.
        variant: Model variant name (e.g., 'lmic_tsbn_v7b').
        desync: Desynchronization level (0, 50, or 100).
        repo_id: Repo used for the second upload instead of the legacy MTAN
            repo. It does not affect the V1 upload.
        hf_prefix_override: Repo-relative directory to upload into; when
            empty or None, ``desync{desync}/{variant}`` is used.
    """
    hf_prefix = hf_prefix_override or f"desync{desync}/{variant}"
    hub = HfApi()

    # Expected artifact names, in upload order; keep only those on disk.
    expected = ["model.h5", "results.json"]
    expected += [f"rank_curve_byte{byte_idx}.npy" for byte_idx in range(16)]
    files: List[str] = [
        name for name in expected
        if os.path.isfile(os.path.join(model_dir, name))
    ]
    if not files:
        logger.error("No files found in %s to upload", model_dir)
        return

    # First destination: the unified V1 repo, regardless of repo_id.
    _ensure_v1_repo(hub)
    v1_count = _upload_files(hub, HF_V1_REPO, model_dir, hf_prefix, files)
    logger.info(
        "V1 MTAN upload: variant=%s desync=%d -> %s (%d/%d files)",
        variant, desync, HF_V1_REPO, v1_count, len(files),
    )

    # Second destination: explicit override, else the legacy MTAN repo.
    secondary_repo = repo_id or HF_MTAN_REPO
    try:
        hub.create_repo(
            repo_id=secondary_repo,
            repo_type="model",
            exist_ok=True,
            private=False,
        )
    except Exception as err:
        # Best effort: still attempt the uploads below.
        logger.warning("Could not create/verify repo %s: %s", secondary_repo, err)
    secondary_count = _upload_files(hub, secondary_repo, model_dir, hf_prefix, files)
    logger.info(
        "Legacy MTAN upload: variant=%s desync=%d -> %s (%d/%d files)",
        variant, desync, secondary_repo, secondary_count, len(files),
    )
| # --------------------------------------------------------------------------- | |
| # Download functions | |
| # --------------------------------------------------------------------------- | |
def download_mtan_model(
    variant: str,
    desync: int,
    local_dir: str,
    repo_id: Optional[str] = None,
    hf_prefix_override: Optional[str] = None,
) -> str:
    """
    Download a trained MTAN model from HuggingFace Hub.

    Requests model.h5, results.json and rank_curve_byte{0..15}.npy from a
    single repo (the unified V1 repo unless *repo_id* is given). A file that
    cannot be downloaded is logged as a warning and skipped; no other repo
    is tried.

    Args:
        variant: Model variant name (e.g., 'lmic_tsbn_v7b').
        desync: Desynchronization level (0, 50, or 100).
        local_dir: Local directory to save downloaded files (created if missing).
        repo_id: Override the default HuggingFace repo ID.
        hf_prefix_override: Repo-relative directory to download from; when
            empty or None, ``desync{desync}/{variant}`` is used. Mirrors the
            parameter of the same name on ``upload_mtan_model``.

    Returns:
        *local_dir*, unchanged. NOTE(review): ``hf_hub_download`` is given
        ``local_dir``, so files presumably land beneath it under their
        repo-relative path -- confirm against the huggingface_hub docs.
    """
    if repo_id is None:
        repo_id = HF_V1_REPO
    # hf_prefix_override used to be read here without being a parameter,
    # which raised NameError on every call.
    hf_prefix = hf_prefix_override if hf_prefix_override else f"desync{desync}/{variant}"
    os.makedirs(local_dir, exist_ok=True)

    files_to_download = ["model.h5", "results.json"]
    files_to_download.extend(
        f"rank_curve_byte{byte_idx}.npy" for byte_idx in range(16)
    )

    for filename in files_to_download:
        hf_path = f"{hf_prefix}/{filename}"
        try:
            downloaded = hf_hub_download(
                repo_id=repo_id,
                filename=hf_path,
                local_dir=local_dir,
            )
            logger.info("Downloaded %s -> %s", hf_path, downloaded)
        except Exception as e:
            # Best effort: not every artifact is guaranteed to exist remotely.
            logger.warning("Could not download %s: %s", hf_path, e)
    return local_dir
def download_model(
    model_type: str,
    target_byte: int,
    desync: int,
    local_dir: str,
    repo_id: Optional[str] = None,
) -> str:
    """
    Download a trained single-task model from HuggingFace Hub.

    Requests model.h5, results.json and rank_curve.npy from a single repo
    (the unified V1 repo unless *repo_id* is given). A file that cannot be
    downloaded is logged as a warning and skipped; no other repo is tried.

    Args:
        model_type: 'mlp' or 'cnn'.
        target_byte: Target byte index (0-15).
        desync: Desynchronization level (0, 50, or 100).
        local_dir: Local directory to save downloaded files (created if missing).
        repo_id: Override the default HuggingFace repo ID.

    Returns:
        *local_dir*, unchanged.
    """
    source_repo = HF_V1_REPO if repo_id is None else repo_id
    hf_prefix = _build_hf_path(model_type, target_byte, desync)
    os.makedirs(local_dir, exist_ok=True)

    artifact_names = ("model.h5", "results.json", "rank_curve.npy")
    for remote_path in (f"{hf_prefix}/{name}" for name in artifact_names):
        try:
            saved_to = hf_hub_download(
                repo_id=source_repo,
                filename=remote_path,
                local_dir=local_dir,
            )
            logger.info("Downloaded %s -> %s", remote_path, saved_to)
        except Exception as err:
            # Best effort: keep going so the remaining artifacts are still fetched.
            logger.warning("Could not download %s: %s", remote_path, err)
    return local_dir
| # --------------------------------------------------------------------------- | |
| # Audit / count (V1 repo) | |
| # --------------------------------------------------------------------------- | |
def audit_repository(
    model_type: str,
    repo_id: Optional[str] = None,
) -> Dict[str, Dict[int, bool]]:
    """
    Audit a HuggingFace model repository to check which models exist.

    A model counts as present when the repo file listing contains
    ``desync{N}/{model_type}/byte{X}/model.h5``. Desync levels 0, 50 and 100
    and byte indices 0-15 are checked.

    Args:
        model_type: Model type used as the middle path segment (e.g. 'mlp').
        repo_id: Repository to inspect; defaults to the unified V1 repo.

    Returns:
        Nested dict: {desync_key: {byte_idx: has_model}}.
        Example: {"desync0": {0: True, 1: False, ...}, ...}
        An empty dict is returned when the repo listing cannot be fetched
        (the error is logged).
    """
    if repo_id is None:
        repo_id = HF_V1_REPO
    api = HfApi()
    try:
        # Convert once to a set: 48 membership tests follow, and the listing
        # of a large repo can be long.
        repo_files = set(api.list_repo_files(repo_id=repo_id, repo_type="model"))
    except Exception as e:
        logger.error("Could not list files in %s: %s", repo_id, e)
        return {}

    result: Dict[str, Dict[int, bool]] = {}
    for desync in (0, 50, 100):
        result[f"desync{desync}"] = {
            byte_idx: (
                f"{_build_hf_path(model_type, byte_idx, desync)}/model.h5"
                in repo_files
            )
            for byte_idx in range(16)
        }
    return result
def count_models(
    model_type: str,
    repo_id: Optional[str] = None,
) -> int:
    """Return how many entries of ``audit_repository`` are marked present.

    Args:
        model_type: Forwarded to ``audit_repository``.
        repo_id: Forwarded to ``audit_repository``.

    Returns:
        Number of (desync, byte) slots whose audit value is truthy. This is 0
        when the audit itself returns an empty dict.
    """
    per_desync = audit_repository(model_type, repo_id)
    total = 0
    for per_byte in per_desync.values():
        for present in per_byte.values():
            if present:
                total += 1
    return total