ascad-training-pipeline / src /artifacts.py
lemousehunter's picture
Section 43-45: dataset.py per-trace shift fix, orchestrator DELETE FK fix, queue rebuild, clean MLP/CNN resubmission
8c414b1
"""
Artifact Management
===================
Handles uploading trained models and results to HuggingFace Hub,
and downloading them for evaluation or transfer learning.
Supports both single-task models (MLP, CNN) and multi-task models (MTL).
The unified V1 repo (ascad-v1-models) uses the following structure:
- Single-byte: desync{N}/{model_type}/byte{X}/ (model.h5, results.json, rank_curve.npy)
- Multi-task: desync{N}/{variant}/ (model.h5, results.json, rank_curve_byte{0..15}.npy)
"""
import json
import logging
import os
from typing import Dict, List, Optional
from huggingface_hub import HfApi, hf_hub_download
from .constants import HF_MLP_REPO, HF_CNN_REPO, HF_MTAN_REPO, HF_V1_REPO
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Repository mapping (legacy)
# ---------------------------------------------------------------------------
def _get_legacy_repo_id(model_type: str) -> str:
    """Return the legacy HuggingFace repository ID for ``model_type``.

    Args:
        model_type: One of 'mlp', 'cnn' or 'mtan'.

    Returns:
        The repository ID constant registered for that model type.

    Raises:
        ValueError: If ``model_type`` is not one of the known keys.
    """
    legacy_repos = dict(mlp=HF_MLP_REPO, cnn=HF_CNN_REPO, mtan=HF_MTAN_REPO)
    if model_type in legacy_repos:
        return legacy_repos[model_type]
    raise ValueError(f"Unknown model type: {model_type}")
def _build_hf_path(model_type: str, target_byte: int, desync: int) -> str:
"""Build the HuggingFace path prefix for a single-task model."""
return f"desync{desync}/{model_type}/byte{target_byte}"
# ---------------------------------------------------------------------------
# Unified V1 upload helpers
# ---------------------------------------------------------------------------
def _ensure_v1_repo(api: HfApi) -> None:
    """Create the unified V1 model repo if it is missing (best effort).

    Any exception raised by the Hub call is logged as a warning and
    swallowed, so the caller proceeds regardless of the outcome.

    Args:
        api: Hub client used to issue the ``create_repo`` request.
    """
    try:
        repo_settings = {
            "repo_id": HF_V1_REPO,
            "repo_type": "model",
            "exist_ok": True,  # an already-existing repo is not an error
            "private": False,
        }
        api.create_repo(**repo_settings)
    except Exception as err:
        logger.warning("Could not create/verify repo %s: %s", HF_V1_REPO, err)
def _upload_files(
    api: HfApi,
    repo_id: str,
    local_dir: str,
    hf_prefix: str,
    filenames: List[str],
) -> int:
    """Push the named files from ``local_dir`` to a Hub model repo.

    Files that do not exist locally are skipped with a warning. A failed
    upload is logged as an error and does not stop the remaining files.

    Args:
        api: Hub client used for the uploads.
        repo_id: Destination repository ID.
        local_dir: Directory the files are read from.
        hf_prefix: Repo-relative directory the files are stored under.
        filenames: Base names of the files to upload.

    Returns:
        Number of files that were uploaded without raising.
    """
    success_count = 0
    for name in filenames:
        source = os.path.join(local_dir, name)
        if os.path.isfile(source):
            destination = f"{hf_prefix}/{name}"
            try:
                logger.info("Uploading %s -> %s/%s", source, repo_id, destination)
                api.upload_file(
                    path_or_fileobj=source,
                    path_in_repo=destination,
                    repo_id=repo_id,
                    repo_type="model",
                )
            except Exception as err:
                logger.error("Failed to upload %s: %s", name, err)
            else:
                success_count += 1
        else:
            logger.warning("File not found, skipping: %s", source)
    return success_count
# ---------------------------------------------------------------------------
# Single-task upload (MLP, CNN)
# ---------------------------------------------------------------------------
def upload_model(
    model_dir: str,
    model_type: str,
    target_byte: int,
    desync: int,
    repo_id: Optional[str] = None,
    failed: bool = False,
) -> None:
    """
    Publish a trained single-task model directory to HuggingFace Hub.

    The artifacts are pushed to the unified V1 repo first and then to a
    second repo: ``repo_id`` when it is given, otherwise the legacy repo
    registered for ``model_type``.

    Args:
        model_dir: Local directory containing model.h5, results.json, rank_curve.npy.
        model_type: 'mlp' or 'cnn'.
        target_byte: Target byte index (0-15).
        desync: Desynchronization level (0, 50, or 100).
        repo_id: Destination of the second upload, replacing the legacy
            repo. The V1 upload is performed in either case.
        failed: When true, files are stored under a ``failed/`` prefix.

    Raises:
        ValueError: If ``repo_id`` is None and ``model_type`` has no legacy
            repo. This is raised after the V1 upload has already run.
    """
    base_prefix = _build_hf_path(model_type, target_byte, desync)
    remote_prefix = f"failed/{base_prefix}" if failed else base_prefix
    artifact_names = ["model.h5", "results.json", "rank_curve.npy"]
    hub = HfApi()

    # First destination: the unified V1 repo, unconditionally.
    _ensure_v1_repo(hub)
    v1_count = _upload_files(hub, HF_V1_REPO, model_dir, remote_prefix, artifact_names)
    logger.info(
        "V1 upload complete: %s byte=%d desync=%d -> %s (%d files)",
        model_type, target_byte, desync, HF_V1_REPO, v1_count,
    )

    # Second destination: the caller's override, or the legacy repo.
    second_repo = _get_legacy_repo_id(model_type) if repo_id is None else repo_id
    second_count = _upload_files(
        hub, second_repo, model_dir, remote_prefix, artifact_names
    )
    logger.info(
        "Legacy upload complete: %s byte=%d desync=%d -> %s (%d files)",
        model_type, target_byte, desync, second_repo, second_count,
    )
# ---------------------------------------------------------------------------
# Multi-task upload (MTL)
# ---------------------------------------------------------------------------
def upload_mtan_model(
    model_dir: str,
    variant: str,
    desync: int,
    repo_id: Optional[str] = None,
    hf_prefix_override: Optional[str] = None,
) -> None:
    """
    Publish a trained multi-task model directory to HuggingFace Hub.

    Only the artifacts present in ``model_dir`` are uploaded: model.h5,
    results.json and rank_curve_byte{0..15}.npy. They go to the unified V1
    repo first and then to a second repo (``repo_id`` if given, otherwise
    the legacy MTAN repo). If none of the files exist, an error is logged
    and nothing is uploaded.

    Args:
        model_dir: Local directory containing model artifacts.
        variant: Model variant name (e.g., 'lmic_tsbn_v7b').
        desync: Desynchronization level (0, 50, or 100).
        repo_id: Destination of the second upload, replacing the legacy
            MTAN repo. The V1 upload is performed in either case.
        hf_prefix_override: Repo-relative directory to use instead of the
            default ``desync{desync}/{variant}``.
    """
    remote_prefix = hf_prefix_override or f"desync{desync}/{variant}"
    hub = HfApi()

    # Candidate artifacts in upload order; keep only those present on disk.
    candidates = ["model.h5", "results.json"]
    candidates += [f"rank_curve_byte{byte_idx}.npy" for byte_idx in range(16)]
    files: List[str] = [
        name for name in candidates
        if os.path.isfile(os.path.join(model_dir, name))
    ]
    if len(files) == 0:
        logger.error("No files found in %s to upload", model_dir)
        return

    # First destination: the unified V1 repo, unconditionally.
    _ensure_v1_repo(hub)
    v1_count = _upload_files(hub, HF_V1_REPO, model_dir, remote_prefix, files)
    logger.info(
        "V1 MTAN upload: variant=%s desync=%d -> %s (%d/%d files)",
        variant, desync, HF_V1_REPO, v1_count, len(files),
    )

    # Second destination: create it if needed (best effort), then upload.
    second_repo = repo_id or HF_MTAN_REPO
    try:
        hub.create_repo(
            repo_id=second_repo,
            repo_type="model",
            exist_ok=True,
            private=False,
        )
    except Exception as err:
        logger.warning("Could not create/verify repo %s: %s", second_repo, err)
    second_count = _upload_files(hub, second_repo, model_dir, remote_prefix, files)
    logger.info(
        "Legacy MTAN upload: variant=%s desync=%d -> %s (%d/%d files)",
        variant, desync, second_repo, second_count, len(files),
    )
# ---------------------------------------------------------------------------
# Download functions
# ---------------------------------------------------------------------------
def download_mtan_model(
    variant: str,
    desync: int,
    local_dir: str,
    repo_id: Optional[str] = None,
    hf_prefix_override: Optional[str] = None,
) -> str:
    """
    Download a trained MTAN model from HuggingFace Hub.

    Files are fetched from ``repo_id`` (the unified V1 repo by default).
    No other repo is tried; a file that cannot be downloaded is logged as
    a warning and skipped.

    Args:
        variant: Model variant name (e.g., 'lmic_tsbn_v7b').
        desync: Desynchronization level (0, 50, or 100).
        local_dir: Local directory to save downloaded files (created if missing).
        repo_id: Override the default HuggingFace repo ID.
        hf_prefix_override: Repo-relative directory to read from instead of
            the default ``desync{desync}/{variant}``.

    Returns:
        ``local_dir``, unchanged.
        NOTE(review): hf_hub_download receives the repo-relative path, so
        files presumably land under ``local_dir/<prefix>/`` rather than
        directly in ``local_dir`` -- confirm against callers.
    """
    if repo_id is None:
        repo_id = HF_V1_REPO
    # The prefix override used to be read here without being a parameter,
    # which raised NameError on every call; it is now an optional argument.
    hf_prefix = hf_prefix_override if hf_prefix_override else f"desync{desync}/{variant}"
    os.makedirs(local_dir, exist_ok=True)

    files_to_download = ["model.h5", "results.json"]
    files_to_download.extend(
        f"rank_curve_byte{byte_idx}.npy" for byte_idx in range(16)
    )
    for filename in files_to_download:
        hf_path = f"{hf_prefix}/{filename}"
        try:
            downloaded = hf_hub_download(
                repo_id=repo_id,
                filename=hf_path,
                local_dir=local_dir,
            )
            logger.info("Downloaded %s -> %s", hf_path, downloaded)
        except Exception as e:
            # Best effort: not every rank curve is guaranteed to exist.
            logger.warning("Could not download %s: %s", hf_path, e)
    return local_dir
def download_model(
    model_type: str,
    target_byte: int,
    desync: int,
    local_dir: str,
    repo_id: Optional[str] = None,
) -> str:
    """
    Download a trained single-task model from HuggingFace Hub.

    Files are fetched from ``repo_id`` (the unified V1 repo by default).
    No other repo is tried; a file that cannot be downloaded is logged as
    a warning and skipped.

    Args:
        model_type: 'mlp' or 'cnn'.
        target_byte: Target byte index (0-15).
        desync: Desynchronization level (0, 50, or 100).
        local_dir: Local directory to save downloaded files (created if missing).
        repo_id: Override the default HuggingFace repo ID.

    Returns:
        ``local_dir``, unchanged.
        NOTE(review): hf_hub_download receives the repo-relative path, so
        files presumably land under ``local_dir/<prefix>/`` rather than
        directly in ``local_dir`` -- confirm against callers.
    """
    source_repo = HF_V1_REPO if repo_id is None else repo_id
    remote_dir = _build_hf_path(model_type, target_byte, desync)
    os.makedirs(local_dir, exist_ok=True)

    artifact_names = ("model.h5", "results.json", "rank_curve.npy")
    for name in artifact_names:
        remote_path = f"{remote_dir}/{name}"
        try:
            saved_to = hf_hub_download(
                repo_id=source_repo,
                filename=remote_path,
                local_dir=local_dir,
            )
            logger.info("Downloaded %s -> %s", remote_path, saved_to)
        except Exception as err:
            logger.warning("Could not download %s: %s", remote_path, err)
    return local_dir
# ---------------------------------------------------------------------------
# Audit / count (V1 repo)
# ---------------------------------------------------------------------------
def audit_repository(
    model_type: str,
    repo_id: Optional[str] = None,
) -> Dict[str, Dict[int, bool]]:
    """
    Audit a HuggingFace model repository to check which models exist.

    A model counts as present when its ``model.h5`` is listed at the
    single-task path for the given type, byte and desync level. Desync
    levels 0, 50 and 100 and byte indices 0-15 are checked.

    Args:
        model_type: Model type segment used in the repo path (e.g. 'mlp').
        repo_id: Repository to inspect; defaults to the unified V1 repo.

    Returns:
        Nested dict: {desync_key: {byte_idx: has_model}}.
        Example: {"desync0": {0: True, 1: False, ...}, ...}
        An empty dict is returned if the repo files cannot be listed.
    """
    if repo_id is None:
        repo_id = HF_V1_REPO
    api = HfApi()
    try:
        files = api.list_repo_files(repo_id=repo_id, repo_type="model")
    except Exception as e:
        logger.error("Could not list files in %s: %s", repo_id, e)
        return {}

    # Build the lookup set once: 48 membership tests follow, and a repo
    # listing can be long.
    known_paths = set(files)

    result: Dict[str, Dict[int, bool]] = {}
    for desync in (0, 50, 100):
        result[f"desync{desync}"] = {
            byte_idx: (
                f"{_build_hf_path(model_type, byte_idx, desync)}/model.h5"
                in known_paths
            )
            for byte_idx in range(16)
        }
    return result
def count_models(
    model_type: str,
    repo_id: Optional[str] = None,
) -> int:
    """Return how many models of ``model_type`` the audit reports as present.

    Args:
        model_type: Model type passed through to ``audit_repository``.
        repo_id: Repository to inspect, passed through to ``audit_repository``.

    Returns:
        Number of (desync, byte) entries whose audit flag is truthy; 0 when
        the audit result is empty.
    """
    total = 0
    per_desync = audit_repository(model_type, repo_id)
    for byte_flags in per_desync.values():
        for present in byte_flags.values():
            if present:
                total += 1
    return total