Spaces:
Runtime error
Runtime error
| from core.device import FRAMEWORK | |
| import torch.nn as nn | |
| import torch | |
| import inspect | |
| from typing import Callable, Dict, Any, Optional, Type, List | |
# Benchmark identifiers considered two-dimensional (each carries a `_2d` suffix).
# Keep this list in sync with BENCHMARKS_2D in autorun.py.
# NOTE(review): the registries below do not read this set; presumably it is
# imported elsewhere — confirm before removing or renaming it.
_2D_BENCHMARKS = frozenset(
    (
        "allen_cahn_2d",
        "darcy_2d",
        "elasticity_2d",
        "ellipse_2d",
        "mhd_2d",
        "multiphysics_2d",
        "ns_2d",
        "ns_hre_2d",
        "pdebench_2d",
        "poisson_2d",
        "radiative_2d",
        "swe_2d",
        "wavebench_2d",
    )
)
| # ββ Model Registry ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class ModelRegistry: | |
| def __init__(self): | |
| self._registry: Dict[str, Callable] = {} | |
| self._lazy_imports: Dict[str, tuple[str, str]] = {} | |
| def register_lazy(self, name: str, module_name: str, class_name: str) -> None: | |
| self._lazy_imports[name] = (module_name, class_name) | |
| def register(self, name: str) -> Callable: | |
| def _decorator(fn: Callable) -> Callable: | |
| self._registry[name] = fn | |
| return fn | |
| return _decorator | |
| def build(self, name: str, benchmark: str = "", **kwargs) -> nn.Module: | |
| if name not in self._registry and name in self._lazy_imports: | |
| mod_name, cls_name = self._lazy_imports[name] | |
| import importlib | |
| module = importlib.import_module(f"models.{mod_name}") | |
| cls = getattr(module, cls_name) | |
| self.register_class(name, cls) | |
| if name not in self._registry: | |
| available = ", ".join(sorted(set(self._registry) | set(self._lazy_imports))) | |
| raise ValueError(f"Unknown model {name!r}. Available: {available}") | |
| # Map n_modes -> (n_modes1, n_modes2) for 2D models | |
| if (name.endswith("2D") or name.endswith("2d")) and "n_modes" in kwargs: | |
| m = kwargs.pop("n_modes") | |
| kwargs["n_modes1"] = m | |
| kwargs["n_modes2"] = m | |
| return self._registry[name](**kwargs) | |
| def register_class(self, name: str, cls: Type[nn.Module], **fixed_kwargs) -> None: | |
| def _factory(**kwargs): | |
| kwargs.update(fixed_kwargs) | |
| sig = inspect.signature(cls.__init__) | |
| valid_args = { | |
| k: v for k, v in kwargs.items() | |
| if k in sig.parameters or any(p.kind == p.VAR_KEYWORD for p in sig.parameters.values()) | |
| } | |
| return cls(**valid_args) | |
| _factory.__name__ = f"factory_{name}" | |
| self._registry[name] = _factory | |
| def available(self): | |
| return sorted(set(self._registry) | set(self._lazy_imports)) | |
| MODEL_REGISTRY = ModelRegistry() | |
# ── Benchmark Registry ────────────────────────────────────────────────────────
class BenchmarkRegistry:
    """Map benchmark names to a loader factory, an evaluator, a SOTA value and a description.

    Loader factories and evaluators receive the benchmark name as their first
    argument, so a single implementation can serve several benchmarks.
    """

    def __init__(self):
        self._loaders: Dict[str, Callable] = {}
        self._evals: Dict[str, Callable] = {}
        # SOTA reference values; ``None`` when none was supplied at registration.
        self._sota: Dict[str, Optional[float]] = {}
        self._desc: Dict[str, str] = {}

    def register(
        self,
        name: str,
        make_loader: Callable,
        evaluate: Callable,
        sota: Optional[float] = None,
        description: str = "",
    ) -> None:
        """Register (or overwrite) benchmark ``name``."""
        self._loaders[name] = make_loader
        self._evals[name] = evaluate
        self._sota[name] = sota
        self._desc[name] = description

    def available(self) -> List[str]:
        """Return the sorted names of all registered benchmarks."""
        return sorted(self._loaders)

    def _require(self, name: str, table: dict) -> None:
        """Raise ``ValueError`` (listing valid names) if ``name`` is not in ``table``."""
        if name not in table:
            names = ", ".join(self.available())
            raise ValueError(f"Unknown benchmark {name!r}. Available: {names}")

    def make_loader(self, name: str, split: str, batch_size: int, **kwargs):
        """Return ``factory(name, split, batch_size, **kwargs)`` for benchmark ``name``.

        Raises:
            ValueError: if ``name`` is not registered.
        """
        self._require(name, self._loaders)
        return self._loaders[name](name, split, batch_size, **kwargs)

    def evaluate(self, name: str, model_fn: Callable) -> float:
        """Return ``evaluator(name, model_fn)`` for benchmark ``name``.

        Raises:
            ValueError: if ``name`` is not registered.
        """
        self._require(name, self._evals)
        return self._evals[name](name, model_fn)

    def sota(self, name: str) -> Optional[float]:
        """Return the SOTA value registered for ``name`` (``None`` if none was given)."""
        self._require(name, self._sota)
        return self._sota[name]

    def description(self, name: str) -> str:
        """Return the description registered for ``name`` (``""`` if none was given)."""
        self._require(name, self._desc)
        return self._desc[name]


# Process-wide benchmark registry populated by ``_register_defaults``.
BENCHMARK_REGISTRY = BenchmarkRegistry()
# ── Populate registries from existing codebase ────────────────────────────────
def _register_defaults():
    """Fill MODEL_REGISTRY and BENCHMARK_REGISTRY with the built-in entries.

    Models are registered lazily, so each ``models.*`` module is imported only
    when that model is first built. Benchmarks are registered with the loader,
    evaluator and SOTA values provided by the ``data`` package.
    """
    from data.prepare import make_dataloader, evaluate_l2_rel
    from data.benchmarks_ext import EXT_BENCHMARKS, EXT_SOTA, make_ext_dataloader, evaluate_l2_rel_ext
    from data.simulations import SIM_BENCHMARKS, SIM_SOTA, make_sim_dataloader, evaluate_l2_rel_sim

    # ── Model registrations (lazy) ────────────────────────────────────────────
    fw = FRAMEWORK.lower()
    # (registry name, module under ``models``, class name)
    lazy_models = (
        ("TestNet", f"testnet_{fw}", "TestNet"),
        ("DualModelTest", f"dualmodeltest_{fw}", "DualModelTest"),
        ("MambaFNO", "mambafno", "MambaFNO"),
        ("FNO", "fno", "FNO1d"),
        ("RFNO", "fno", "RFNO1d"),
        ("AFNO", "afno", "AFNO1d"),
        ("FFNO", "afno", "FFNO1d"),
        ("UNO", "fno", "UNO1d"),
        ("WNO", "wno", "WNO1d"),
        ("KAN_FNO", "kan", "KAN_FNO"),
        ("cPIKAN_FNO", "chebyshev_kan", "cPIKAN_FNO"),
        ("DeepONet", "deeponet", "DeepONet"),
        ("PODDeepONet", "deeponet", "PODDeepONet"),
        ("S4NO", "s4d", "S4NO1d"),
        ("SSNO", "ssno", "SSNO1d"),
        ("UNO2d", "fno", "UNO2d"),
        ("WNO2d", "wno", "WNO2d"),
        ("GNOT", "gnot", "GNOT1d"),
        ("GNOT2d", "gnot", "GNOT2d"),
        ("MambaNO", "mamba_no", "MambaNO1d"),
        ("PINO", "pinn", "PINO1d"),
        ("PINN", "pinn", "PINN"),
        ("ModalPINN", "pinn", "ModalPINN"),
        ("FNO2D", "fno", "FNO2d"),
        ("TFNO", "tfno", "TFNO1d"),
        ("TFNO2D", "tfno", "TFNO2d"),
        ("Transolver", "transolver", "Transolver1d"),
        ("Transolver2D", "transolver", "Transolver2d"),
        ("TimeDeepONet", "time_deeponet", "TimeDeepONet1d"),
        ("HNN", "hnn", "HamiltonianNO1d"),
        ("NeuralODE", "neural_ode", "NeuralODE1d"),
        ("PACMANN", "pacmann", "PACMANN"),
        ("VSMNO2D", "vsmno", "VSMNO2d"),
        ("MemNO", "mem_no", "MemNO1d"),
        ("SARModel2d", "sar", "SARModel2d"),
    )
    for model_name, module_name, class_name in lazy_models:
        MODEL_REGISTRY.register_lazy(model_name, module_name, class_name)

    # ── Benchmarks ────────────────────────────────────────────────────────────
    # 0.0149 is the SOTA reference value recorded for burgers_1d.
    BENCHMARK_REGISTRY.register("burgers_1d", make_dataloader, evaluate_l2_rel, 0.0149, "1D Burgers")
    for bm in EXT_BENCHMARKS:
        BENCHMARK_REGISTRY.register(bm, make_ext_dataloader, evaluate_l2_rel_ext, EXT_SOTA.get(bm))

    # ── Simulation benchmarks ─────────────────────────────────────────────────
    # NOTE(review): a name present in both EXT_BENCHMARKS and SIM_BENCHMARKS
    # would be silently overwritten by the simulation entry — confirm disjoint.
    for bm in SIM_BENCHMARKS:
        BENCHMARK_REGISTRY.register(bm, make_sim_dataloader, evaluate_l2_rel_sim, SIM_SOTA.get(bm))


# Populate the registries as a side effect of importing this module.
_register_defaults()