| from .mlp import MLP | |
| from .siren import SIREN | |
| from .wire import WIRE | |
| from .activations import get_activation, list_activations | |
| from flax import serialization | |
| import os | |
| import yaml | |
| import jax | |
| import jax.numpy as jnp | |
| from ml_collections import ConfigDict | |
# Registry that maps the `model_name` strings accepted by `get_model`
# to the model classes imported at the top of this module.
model_key_dict = dict(
    MLP=MLP,
    SIREN=SIREN,
    WIRE=WIRE,
)
def make_model(config):
    """
    Create and configure a flax neural network nn.Module based on configuration.

    Args:
        config: Model configuration exposing the attributes:
            - model_name: Type of model; must be a key of `model_key_dict`
            - output_dim: Number of output dimensions
            - hidden_dim: Hidden layer dimensions
            - num_layers: Number of layers
            - activation: Activation function name (only read for models
              that take an `act` argument)
            - extra_model_args: Optional mapping of additional
              model-specific keyword arguments; may be None or absent

    Returns:
        model (nn.Module): Configured flax nn.Module instance.

    Raises:
        ValueError: If `config.model_name` is not a supported model
            (raised by `get_model`).

    Note:
        WIRE and SIREN models don't accept an activation function as an
        argument, so `act` is never passed to them, regardless of whether
        extra model arguments are supplied.
    """
    model_cls = get_model(config.model_name)

    # Arguments shared by every supported model.
    base_kwargs = {
        "output_dim": config.output_dim,
        "hidden_dim": config.hidden_dim,
        "num_layers": config.num_layers,
    }

    # Only models other than WIRE/SIREN take an activation function.
    if config.model_name not in ("WIRE", "SIREN"):
        base_kwargs["act"] = get_activation(config.activation)

    # Treat a missing attribute the same as an explicit None: no extras.
    extra_kwargs = getattr(config, "extra_model_args", None)
    if extra_kwargs is None:
        extra_kwargs = {}

    # Both mappings are unpacked in the call (instead of being merged first)
    # so that a key duplicated between them still raises a TypeError rather
    # than being silently overridden.
    return model_cls(**base_kwargs, **extra_kwargs)
def load_metric_from_model(model_dir):
    """
    Load a trained model from a directory and return its metric function.

    The parameters are read from `params.msgpack` and the model
    configuration from `architecture.yml`, both located in `model_dir`.
    If the model has an output dimension of 10, meaning it was trained only
    on the symmetric part of the metric, the full metric tensor is
    reconstructed from those components.

    Args:
        model_dir (str): Directory containing `params.msgpack` and
            `architecture.yml`.

    Returns:
        callable: Function taking `coords` and returning the model output
            arranged as a (4, 4) metric tensor.

    Raises:
        ValueError: If the configured `output_dim` is neither 16 (full
            metric) nor 10 (symmetric part only).
    """
    metric_dim = 4
    full_size = metric_dim * metric_dim              # 16 components
    sym_size = metric_dim * (metric_dim + 1) // 2    # 10 independent components

    with open(os.path.join(model_dir, "params.msgpack"), "rb") as f:
        params = serialization.msgpack_restore(f.read())

    # NOTE(review): yaml.FullLoader can construct more than plain data types.
    # If architecture files may come from untrusted sources, consider
    # switching to yaml.safe_load.
    with open(os.path.join(model_dir, "architecture.yml"), "r") as f:
        config_model = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))

    architecture = config_model.architecture
    model = make_model(architecture)
    output_dim = architecture.output_dim

    if output_dim == full_size:
        def full_metric(coords):
            return model.apply(params, coords).reshape(metric_dim, metric_dim)

        return full_metric

    if output_dim == sym_size:
        def symmetric_metric(coords):
            # reconstruct_full_metric requires the matrix size explicitly.
            reduced = model.apply(params, coords)
            full = reconstruct_full_metric(reduced, metric_dim)
            return full.reshape(metric_dim, metric_dim)

        return symmetric_metric

    # Fail loudly instead of implicitly returning None, which would only
    # surface later as "'NoneType' object is not callable".
    raise ValueError(
        f"Unsupported output_dim {output_dim}: expected {full_size} "
        f"(full metric) or {sym_size} (symmetric part)."
    )
def reconstruct_full_metric(metric_sym: jax.Array, n : int) -> jax.Array:
    """Rebuild the full symmetric (n, n) tensor from its symmetry-reduced form.

    Args:
        metric_sym: Independent components of the tensor, ordered like the
            index pairs returned by `jnp.triu_indices(n, k=0)` (upper
            triangle including the diagonal).
        n: Side length of the square tensor to build.

    Returns:
        The (n, n) array whose upper triangle holds `metric_sym` and whose
        lower triangle mirrors it.
    """
    rows, cols = jnp.triu_indices(n)
    # Fill the upper triangle (diagonal included) of an all-zero matrix ...
    upper = jnp.zeros((n, n)).at[rows, cols].set(metric_sym)
    # ... then write the same values at the transposed positions.
    return upper.at[cols, rows].set(metric_sym)
def get_model(model_name : str):
    """
    Look up a model class in `model_key_dict` by its name.

    Args:
        model_name (str): Name of the model.

    Returns:
        nn.Module: The model class registered under `model_name`.

    Raises:
        ValueError: If `model_name` is not a key of `model_key_dict`.
    """
    # Success path first; anything else falls through to the error.
    if model_name in model_key_dict:
        return model_key_dict[model_name]

    raise ValueError(f"Model `{model_name}` is not supported. Supported models are: {list(model_key_dict.keys())}")
def create_model_configs():
    """
    Build the model-specific extra arguments for each model name.

    Returns:
        dict: Maps each model name ("MLP", "SIREN", "WIRE") to a newly
            created dict of its extra arguments; the MLP entry is empty.
    """
    siren_args = {"omega_0": 3.}
    wire_args = {
        "first_omega_0": 4.,
        "hidden_omega_0": 4.,
        "scale": 5.,
    }
    return {
        "MLP": {},
        "SIREN": siren_args,
        "WIRE": wire_args,
    }


# Module-level table of extra model arguments, built once at import time.
model_configs = create_model_configs()
def get_extra_model_cfg(model_name: str):
    """
    Return the extra configuration stored in `model_configs` for a model.

    Args:
        model_name (str): Name of the model.

    Returns:
        dict: The entry of `model_configs` for `model_name`. This is the
            stored dict itself, not a copy.

    Raises:
        ValueError: If `model_name` is not a key of `model_configs`.
    """
    is_known = model_name in model_configs
    if not is_known:
        raise ValueError(f"Model `{model_name}` is not supported. Available models are: {list(model_configs.keys())}")

    extra_cfg = model_configs[model_name]
    return extra_cfg