# weather_predict/models/__init__.py
# jeffliulab's picture
# Initial deploy: real-time weather forecast demo
# e22f65c
"""
Model registry for weather forecasting architectures.

Usage:
    from models import create_model, MODEL_REGISTRY
    model = create_model("cnn_baseline", n_input_channels=42)
"""
from .cnn_baseline import BaselineCNN
from .cnn_multi_frame import MultiFrameCNN
from .cnn_3d import CNN3D
from .vit import WeatherViT
from .resnet_baseline import ResNet18Baseline
from .convnext_baseline import ConvNeXtBaseline
# Maps each public model name to the class that implements it.
# create_model() looks names up here, and its error message lists these
# keys in insertion order.
MODEL_REGISTRY = {
    # Convolutional models
    "cnn_baseline": BaselineCNN,
    "cnn_multi_frame": MultiFrameCNN,
    "cnn_3d": CNN3D,
    # Transformer model
    "vit": WeatherViT,
    # Baseline backbones
    "resnet18": ResNet18Baseline,
    "convnext_tiny": ConvNeXtBaseline,
}
# Default model-specific settings, keyed by the same names as the registry.
# Each entry gives the number of input frames ("n_frames") and how those
# frames are stacked ("stack_mode").
MODEL_DEFAULTS = {
    # Channel stacking
    "cnn_baseline": {"n_frames": 1, "stack_mode": "channel"},
    "cnn_multi_frame": {"n_frames": 4, "stack_mode": "channel"},
    # Temporal stacking
    "cnn_3d": {"n_frames": 4, "stack_mode": "temporal"},
    # Single-frame, channel-stacked models
    "vit": {"n_frames": 1, "stack_mode": "channel"},
    "resnet18": {"n_frames": 1, "stack_mode": "channel"},
    "convnext_tiny": {"n_frames": 1, "stack_mode": "channel"},
}
def create_model(name, **kwargs):
    """Build the model registered under ``name``.

    Args:
        name: Key identifying the model class in ``MODEL_REGISTRY``.
        **kwargs: Passed unchanged to the model class constructor.

    Returns:
        A new instance of the registered model class.

    Raises:
        ValueError: If ``name`` is not a key of ``MODEL_REGISTRY``.
    """
    model_cls = MODEL_REGISTRY.get(name)
    if model_cls is None:
        raise ValueError(f"Unknown model: {name}. Available: {list(MODEL_REGISTRY.keys())}")
    return model_cls(**kwargs)
def get_model_defaults(name):
    """Return the default ``n_frames`` and ``stack_mode`` settings for a model.

    Args:
        name: Model name, normally one of the keys of ``MODEL_DEFAULTS``.

    Returns:
        A new dict with the keys ``"n_frames"`` and ``"stack_mode"``. Names
        missing from ``MODEL_DEFAULTS`` get the single-frame,
        channel-stacked fallback.
    """
    fallback = {"n_frames": 1, "stack_mode": "channel"}
    # Return a copy: handing out the dict stored in MODEL_DEFAULTS would let
    # a caller that modifies its result silently change the module-level
    # defaults for every later caller.
    return dict(MODEL_DEFAULTS.get(name, fallback))