import json

import numpy as np
import plotly.graph_objects as go
import plotly.subplots as sp
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from dynamix.dynamix import DynaMix

# ---------------------------------------------------------------------------
# Loading models from the Hugging Face Hub
# ---------------------------------------------------------------------------

# Hub repository that hosts both the DynaMix config files and the weights.
HF_REPO_ID = "DurstewitzLab/dynamix"


def load_hf_model_config(model_name):
    """Download and parse a DynaMix model configuration from the Hugging Face Hub.

    The config file name is derived from ``model_name`` by removing the
    ``"dynamix-"`` text, e.g. ``"dynamix-3d-alrnn-v1.0"`` ->
    ``"config_3d-alrnn-v1.0.json"``.

    Args:
        model_name: Name of the model on the Hub.

    Returns:
        The parsed JSON configuration (``load_hf_model`` reads its
        ``"architecture"`` entry).
    """
    config_path = hf_hub_download(
        repo_id=HF_REPO_ID,
        filename="config_" + model_name.replace("dynamix-", "") + ".json",
    )
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)


def load_hf_model(model_name):
    """Build a DynaMix model from its Hub config and load its pretrained weights.

    Args:
        model_name: Name of the model on the Hub, e.g. ``"dynamix-3d-alrnn-v1.0"``.
            Weights are fetched from ``<model_name>.safetensors``.

    Returns:
        The ``DynaMix`` model with its state dict loaded, switched to eval mode.

    Raises:
        ValueError: If downloading, parsing the config, constructing the model
            or loading the weights fails. The original exception is chained as
            ``__cause__`` so the real reason is not lost.
    """
    try:
        architecture = load_hf_model_config(model_name)["architecture"]

        # Build the model with the hyperparameters stored in the config.
        model = DynaMix(
            M=architecture["M"],  # latent state dimension
            N=architecture["N"],  # observation space dimension
            Experts=architecture["Experts"],  # number of experts
            expert_type=architecture["expert_type"],
            P=architecture["P"],  # number of ReLU dimensions
            hidden_dim=architecture["hidden_dim"],
            probabilistic_expert=architecture["probabilistic_expert"],
        )

        model_path = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=model_name + ".safetensors",
        )
        model.load_state_dict(load_file(model_path))
        model.eval()
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        raise ValueError(f"Model {model_name} not found") from e
    return model


def auto_model_selection(context):
    """Pick the pretrained DynaMix model to use for forecasting ``context``.

    Args:
        context: Array-like with a ``shape`` attribute; ``shape[1]`` is taken
            as the number of observed dimensions.

    Returns:
        ``"dynamix-3d-alrnn-v1.0"`` for 2-3 dimensions,
        ``"dynamix-6d-alrnn-v1.0"`` for everything else.
    """
    n_dims = context.shape[1]
    if 2 <= n_dims <= 3:
        return "dynamix-3d-alrnn-v1.0"
    # 1-D series and series with more than 3 dimensions use the 6-D model.
    # 4 and 5 dimensions previously matched no branch and returned None.
    return "dynamix-6d-alrnn-v1.0"


# ---------------------------------------------------------------------------
# Plotting functions
# ---------------------------------------------------------------------------

_MAX_PLOT_DIMS = 15  # plot at most this many dimensions
_BG_COLOR = "#1f2937"
_GRID_COLOR = "rgba(255, 255, 255, 0.2)"
_CONTEXT_COLOR = "#4169E1"
_FORECAST_COLOR = "#FF4242"


def _line_trace(x, y, color, label):
    """Return a legend-less line trace whose hover box shows ``label``, x and y."""
    return go.Scatter(
        x=x,
        y=y,
        mode="lines",
        line=dict(color=color, width=2.5),
        name=label,
        showlegend=False,
        # Plotly hover templates are HTML: <br> is the line break; %{{x}} in the
        # f-string renders as plotly's %{x} placeholder.
        hovertemplate=f"{label}<br>x: %{{x}}<br>y: %{{y}}",
    )


def create_forecast_plot(values, reconstruction_ts_np, horizon):
    """Create a dark-themed Plotly figure showing context and forecast per dimension.

    One subplot row is drawn per dimension, up to 15 rows. Each row shows the
    context in blue at time steps ``-len(values) .. -1`` and the forecast in red
    at time steps ``0 .. horizon - 1``.

    Args:
        values: Context series, indexed as ``values[:, d]``. Presumably shaped
            (context_len, n_dims); TODO confirm against callers.
        reconstruction_ts_np: Forecast series, indexed as
            ``reconstruction_ts_np[:, d]``. Its last axis is taken as the number
            of dimensions. Its first axis is expected to have length ``horizon``.
        horizon: Number of forecast steps (converted with ``int``).

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    plot_dims = min(reconstruction_ts_np.shape[-1], _MAX_PLOT_DIMS)
    context_time = np.arange(-len(values), 0)
    forecast_time = np.arange(0, int(horizon))

    # Tighter vertical spacing as the number of stacked subplots grows.
    if plot_dims <= 3:
        vertical_spacing = 0.1
    elif plot_dims <= 6:
        vertical_spacing = 0.05
    else:
        vertical_spacing = 0.02

    fig = sp.make_subplots(rows=plot_dims, cols=1, vertical_spacing=vertical_spacing)

    for d in range(plot_dims):
        row = d + 1
        fig.add_trace(
            _line_trace(context_time, values[:, d], _CONTEXT_COLOR, f"context_{row}"),
            row=row,
            col=1,
        )
        fig.add_trace(
            _line_trace(forecast_time, reconstruction_ts_np[:, d], _FORECAST_COLOR,
                        f"forecast_{row}"),
            row=row,
            col=1,
        )

    fig.update_layout(
        plot_bgcolor=_BG_COLOR,
        paper_bgcolor=_BG_COLOR,
        font=dict(color="white"),
        showlegend=False,
        title=None,
        margin=dict(l=50, r=50, t=30, b=50),
        height=300 if plot_dims == 1 else 250 * plot_dims,
        width=None,
    )

    # Called without row/col, these style the axes of every subplot at once.
    grid_style = dict(gridcolor=_GRID_COLOR, zerolinecolor=_GRID_COLOR, showgrid=True)
    fig.update_xaxes(**grid_style)
    fig.update_yaxes(**grid_style)
    return fig