# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Public model registry for the early-access public release.

External users choose `model_id` in {1..6}. This registry maps model_id to:
  - the underlying architecture parameters (num_filters, kernel_size)
  - the model receptive field R (in rounds / distance units)

Receptive field convention matches `compare_receptive_field_with_window_data`
in `code/training/utils.py`:

    R = 1 + sum_i (k_i - 1)

for kernel sizes k_i (assumed odd, with same-padding).
"""

from __future__ import annotations

from dataclasses import dataclass, replace
from typing import Dict, Iterable, List


def compute_receptive_field(kernel_sizes: Iterable[int]) -> int:
    """Compute the receptive field R from a collection of kernel sizes.

    Uses R = 1 + sum_i (k_i - 1): every kernel of size k widens the field
    by k - 1.

    Args:
        kernel_sizes: Kernel sizes, one per layer. Any finite iterable of
            positive ints is accepted (list, tuple, generator); it is
            materialized exactly once.

    Returns:
        The receptive field R.

    Raises:
        ValueError: if ``kernel_sizes`` is None or empty, contains a value
            that is not an int (``bool`` is rejected too), or contains a
            non-positive value.
    """
    # Materialize once so one-shot iterables are not exhausted by the first
    # validation pass.
    sizes = [] if kernel_sizes is None else list(kernel_sizes)
    if not sizes:
        raise ValueError("kernel_sizes must be non-empty")
    # ``bool`` subclasses ``int``; reject it so True/False are not silently
    # treated as kernel sizes 1/0.
    if any(isinstance(k, bool) or not isinstance(k, int) for k in sizes):
        raise ValueError(f"kernel_sizes must be ints, got: {kernel_sizes!r}")
    if any(k <= 0 for k in sizes):
        raise ValueError(f"kernel_sizes must be positive, got: {kernel_sizes!r}")
    return 1 + sum(k - 1 for k in sizes)


@dataclass(frozen=True)
class PublicModelSpec:
    """Architecture description for one public ``model_id``.

    ``frozen=True`` prevents rebinding fields, but ``num_filters`` and
    ``kernel_size`` are lists and stay mutable in place. ``get_model_spec``
    therefore hands out copies, so callers cannot alter the shared registry.
    """

    # Public identifier chosen by external users.
    model_id: int
    # Filter counts; in every registered spec this has one entry per
    # ``kernel_size`` entry (presumably one per layer).
    num_filters: List[int]
    # Kernel sizes; ``receptive_field`` is derived from these.
    kernel_size: List[int]
    # R = 1 + sum(k - 1) over ``kernel_size`` (see compute_receptive_field).
    receptive_field: int
    # Version tag for the model; "predecoder_memory_v1" unless overridden.
    model_version: str = "predecoder_memory_v1"


def _make_spec(
    model_id: int,
    num_filters: List[int],
    kernel_size: List[int],
    **extra: str,
) -> PublicModelSpec:
    """Build a registry entry whose receptive field is derived from its kernels.

    Computing ``receptive_field`` from ``kernel_size`` here, instead of
    repeating the kernel list literal, keeps the two fields from drifting
    apart. ``extra`` is forwarded to ``PublicModelSpec`` (e.g. model_version).
    """
    return PublicModelSpec(
        model_id=model_id,
        num_filters=list(num_filters),
        kernel_size=list(kernel_size),
        receptive_field=compute_receptive_field(kernel_size),
        **extra,
    )


# Registry keyed by each spec's own model_id, so key and field cannot disagree.
_MODEL_SPECS: Dict[int, PublicModelSpec] = {
    spec.model_id: spec
    for spec in (
        _make_spec(
            1,
            num_filters=[128, 128, 128, 4],
            kernel_size=[3, 3, 3, 3],
        ),
        _make_spec(
            2,
            num_filters=[256, 256, 256, 4],
            kernel_size=[3, 3, 3, 3],
        ),
        _make_spec(
            3,
            num_filters=[128, 128, 128, 4],
            kernel_size=[5, 5, 5, 5],
        ),
        _make_spec(
            4,
            num_filters=[128, 128, 128, 128, 128, 4],
            kernel_size=[3, 3, 3, 3, 3, 3],
        ),
        _make_spec(
            5,
            num_filters=[256, 256, 256, 256, 256, 4],
            kernel_size=[3, 3, 3, 3, 3, 3],
        ),
        _make_spec(
            6,
            num_filters=[96, 96, 96, 96, 96, 4],
            kernel_size=[3, 3, 3, 3, 3, 3],
            model_version="predecoder_fasthyper_rf13_v1",
        ),
    )
}

# Human-readable id range for error messages, derived from the registry so it
# cannot go stale when entries are added ("[1..6]" for the current registry).
_VALID_ID_TEXT = f"[{min(_MODEL_SPECS)}..{max(_MODEL_SPECS)}]"


def get_model_spec(model_id: int) -> PublicModelSpec:
    """Return the public model spec for a given model_id (1..6).

    ``model_id`` may be anything ``int()`` converts exactly: an int, an
    integral float such as ``2.0``, or a numeric string such as ``"3"``.

    Returns:
        A copy of the registered spec with fresh lists; mutating it does not
        affect the registry or later calls.

    Raises:
        ValueError: if ``model_id`` cannot be converted to an int, is a
            non-integral float (``int(2.7)`` would silently truncate to 2),
            is 0, or is not a registered id.
    """
    try:
        mid = int(model_id)
    except (TypeError, ValueError, OverflowError) as e:
        raise ValueError(
            f"model_id must be an int in {_VALID_ID_TEXT}, got: {model_id!r}"
        ) from e
    # Refuse to silently select a different model than the one requested.
    if isinstance(model_id, float) and mid != model_id:
        raise ValueError(
            f"model_id must be an int in {_VALID_ID_TEXT}, got: {model_id!r}"
        )
    if mid == 0:
        raise ValueError("model_id=0 is not supported in the public release")
    spec = _MODEL_SPECS.get(mid)
    if spec is None:
        raise ValueError(f"model_id must be in {_VALID_ID_TEXT}, got: {mid}")
    # Fresh lists so callers cannot mutate the shared registry entry.
    return replace(
        spec,
        num_filters=list(spec.num_filters),
        kernel_size=list(spec.kernel_size),
    )