| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Public model registry for the early-access public release. |
| |
| External users choose `model_id` in {1..6}. This registry maps model_id to: |
| - the underlying architecture parameters (num_filters, kernel_size) |
| - the model receptive field R (in rounds / distance units) |
| |
| Receptive field convention matches `compare_receptive_field_with_window_data` |
| in `code/training/utils.py`: |
| R = 1 + sum_i (k_i - 1) for kernel sizes k_i (assumed odd, with same-padding) |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| from typing import Dict, List |
|
|
|
|
| def compute_receptive_field(kernel_sizes: List[int]) -> int: |
| """Compute receptive field R from a list of kernel sizes.""" |
| if not kernel_sizes: |
| raise ValueError("kernel_sizes must be non-empty") |
| if any(not isinstance(k, int) for k in kernel_sizes): |
| raise ValueError(f"kernel_sizes must be ints, got: {kernel_sizes!r}") |
| if any(k <= 0 for k in kernel_sizes): |
| raise ValueError(f"kernel_sizes must be positive, got: {kernel_sizes!r}") |
| return 1 + sum(kernel_sizes) - len(kernel_sizes) |
|
|
|
|
| @dataclass(frozen=True) |
| class PublicModelSpec: |
| model_id: int |
| num_filters: List[int] |
| kernel_size: List[int] |
| receptive_field: int |
| model_version: str = "predecoder_memory_v1" |
|
|
|
|
| _MODEL_SPECS: Dict[int, PublicModelSpec] = { |
| 1: |
| PublicModelSpec( |
| model_id=1, |
| num_filters=[128, 128, 128, 4], |
| kernel_size=[3, 3, 3, 3], |
| receptive_field=compute_receptive_field([3, 3, 3, 3]), |
| ), |
| 2: |
| PublicModelSpec( |
| model_id=2, |
| num_filters=[256, 256, 256, 4], |
| kernel_size=[3, 3, 3, 3], |
| receptive_field=compute_receptive_field([3, 3, 3, 3]), |
| ), |
| 3: |
| PublicModelSpec( |
| model_id=3, |
| num_filters=[128, 128, 128, 4], |
| kernel_size=[5, 5, 5, 5], |
| receptive_field=compute_receptive_field([5, 5, 5, 5]), |
| ), |
| 4: |
| PublicModelSpec( |
| model_id=4, |
| num_filters=[128, 128, 128, 128, 128, 4], |
| kernel_size=[3, 3, 3, 3, 3, 3], |
| receptive_field=compute_receptive_field([3, 3, 3, 3, 3, 3]), |
| ), |
| 5: |
| PublicModelSpec( |
| model_id=5, |
| num_filters=[256, 256, 256, 256, 256, 4], |
| kernel_size=[3, 3, 3, 3, 3, 3], |
| receptive_field=compute_receptive_field([3, 3, 3, 3, 3, 3]), |
| ), |
| 6: |
| PublicModelSpec( |
| model_id=6, |
| num_filters=[96, 96, 96, 96, 96, 4], |
| kernel_size=[3, 3, 3, 3, 3, 3], |
| receptive_field=compute_receptive_field([3, 3, 3, 3, 3, 3]), |
| model_version="predecoder_fasthyper_rf13_v1", |
| ), |
| } |
|
|
|
|
| def get_model_spec(model_id: int) -> PublicModelSpec: |
| """Return the public model spec for a given model_id (1..6).""" |
| try: |
| mid = int(model_id) |
| except Exception as e: |
| raise ValueError(f"model_id must be an int in [1..6], got: {model_id!r}") from e |
| if mid == 0: |
| raise ValueError("model_id=0 is not supported in the public release") |
| if mid not in _MODEL_SPECS: |
| raise ValueError(f"model_id must be in [1..6], got: {mid}") |
| return _MODEL_SPECS[mid] |
|
|