Spaces:
Running
Running
File size: 1,473 Bytes
e22f65c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 | """
Model registry for weather forecasting architectures.

Usage:
    from models import create_model, get_model_defaults, MODEL_REGISTRY
    model = create_model("cnn_baseline", n_input_channels=42)
    defaults = get_model_defaults("cnn_baseline")  # n_frames / stack_mode
"""
from .cnn_baseline import BaselineCNN
from .cnn_multi_frame import MultiFrameCNN
from .cnn_3d import CNN3D
from .vit import WeatherViT
from .resnet_baseline import ResNet18Baseline
from .convnext_baseline import ConvNeXtBaseline
# Public model name -> model class. ``create_model`` instantiates classes
# looked up here, so adding an architecture means importing it above and
# adding one entry below. Models missing from MODEL_DEFAULTS fall back to
# n_frames=1 / stack_mode="channel" in get_model_defaults.
MODEL_REGISTRY = dict(
    cnn_baseline=BaselineCNN,
    cnn_multi_frame=MultiFrameCNN,
    cnn_3d=CNN3D,
    vit=WeatherViT,
    resnet18=ResNet18Baseline,
    convnext_tiny=ConvNeXtBaseline,
)
# Default model-specific settings: the number of input frames (``n_frames``)
# and the ``stack_mode`` each registered architecture expects. Presumably
# stack_mode selects how frames are combined ("channel" vs "temporal") —
# confirm against the data pipeline that consumes it.
MODEL_DEFAULTS = {
    model_name: {"n_frames": n_frames, "stack_mode": stack_mode}
    for model_name, n_frames, stack_mode in (
        ("cnn_baseline", 1, "channel"),
        ("cnn_multi_frame", 4, "channel"),
        ("cnn_3d", 4, "temporal"),
        ("vit", 1, "channel"),
        ("resnet18", 1, "channel"),
        ("convnext_tiny", 1, "channel"),
    )
}
def create_model(name, **kwargs):
    """Instantiate the model registered under ``name``.

    Args:
        name: Key into ``MODEL_REGISTRY`` selecting the architecture.
        **kwargs: Passed through unchanged to the model class constructor.

    Returns:
        A new instance of ``MODEL_REGISTRY[name]``.

    Raises:
        ValueError: If ``name`` is not a registered model; the message lists
            the registered names.
    """
    if name in MODEL_REGISTRY:
        model_cls = MODEL_REGISTRY[name]
        return model_cls(**kwargs)
    available = list(MODEL_REGISTRY.keys())
    raise ValueError(f"Unknown model: {name}. Available: {available}")
def get_model_defaults(name):
    """Return default n_frames and stack_mode for a model.

    Args:
        name: Model name, normally a key of ``MODEL_DEFAULTS``.

    Returns:
        A new dict with ``"n_frames"`` and ``"stack_mode"`` keys. Names not in
        ``MODEL_DEFAULTS`` get ``{"n_frames": 1, "stack_mode": "channel"}``.
        The result is a copy, so callers may modify it without affecting
        ``MODEL_DEFAULTS``.
    """
    fallback = {"n_frames": 1, "stack_mode": "channel"}
    # Copy: returning the shared inner dict would let one caller's edits
    # (e.g. overriding n_frames) leak into every later lookup. A shallow copy
    # is enough because the values are ints and strings.
    return dict(MODEL_DEFAULTS.get(name, fallback))
|