Spaces:
Running
Running
| import json | |
| from huggingface_hub import hf_hub_download | |
| from safetensors.torch import load_file | |
| from dynamix.dynamix import DynaMix | |
| import plotly.graph_objects as go | |
| import plotly.subplots as sp | |
| import numpy as np | |
| """ | |
| Loading models from HuggingFace Hub | |
| """ | |
def load_hf_model_config(model_name):
    """Download and parse the JSON configuration of a DynaMix checkpoint.

    The config file name is derived from ``model_name`` by removing the
    ``"dynamix-"`` text and wrapping the rest as ``config_<rest>.json``, e.g.
    ``"dynamix-3d-alrnn-v1.0"`` -> ``"config_3d-alrnn-v1.0.json"``. The file is
    fetched from the ``DurstewitzLab/dynamix`` repository on the HuggingFace
    Hub via ``hf_hub_download``.

    Args:
        model_name: Checkpoint name, e.g. ``"dynamix-3d-alrnn-v1.0"``.

    Returns:
        The parsed JSON content (whatever ``json.load`` produces; callers
        index it as a dict with an ``"architecture"`` key).
    """
    config_path = hf_hub_download(
        repo_id="DurstewitzLab/dynamix",
        filename="config_" + model_name.replace("dynamix-", "") + ".json",
    )
    # JSON is UTF-8 by specification; pin the encoding instead of relying on
    # the platform's locale default (which differs e.g. on Windows).
    with open(config_path, encoding="utf-8") as f:
        return json.load(f)
def load_hf_model(model_name):
    """Build a DynaMix model and load its pretrained weights from the Hub.

    Reads the ``"architecture"`` section of the model's config (via
    ``load_hf_model_config``), constructs a ``DynaMix`` instance from those
    hyperparameters, downloads ``<model_name>.safetensors`` from the
    ``DurstewitzLab/dynamix`` repository, loads the weights into the model and
    switches it to evaluation mode.

    Args:
        model_name: Checkpoint name, e.g. ``"dynamix-3d-alrnn-v1.0"``.

    Returns:
        The ``DynaMix`` model with loaded weights, in eval mode.

    Raises:
        ValueError: If any step fails (config download/parsing, a missing
            config key, model construction, weight download or loading).
            The underlying exception is chained as ``__cause__``.
    """
    try:
        # Load model configuration
        model_config = load_hf_model_config(model_name)
        architecture = model_config["architecture"]

        # Extract hyperparameters from config
        M = architecture["M"]  # Latent state dimension
        N = architecture["N"]  # Observation space dimension
        EXPERTS = architecture["Experts"]  # Number of experts
        P = architecture["P"]  # Number of ReLU dimensions
        HIDDEN_DIM = architecture["hidden_dim"]
        expert_type = architecture["expert_type"]
        probabilistic_expert = architecture["probabilistic_expert"]

        # Create model with config parameters
        model = DynaMix(
            M=M,
            N=N,
            Experts=EXPERTS,
            expert_type=expert_type,
            P=P,
            hidden_dim=HIDDEN_DIM,
            probabilistic_expert=probabilistic_expert,
        )

        # Load model weights
        model_path = hf_hub_download(
            repo_id="DurstewitzLab/dynamix",
            filename=model_name + ".safetensors",
        )
        model_state_dict = load_file(model_path)
        model.load_state_dict(model_state_dict)
        model.eval()
    except Exception as e:
        # Callers only ever see ValueError from this function. Chain the
        # original exception explicitly so the real reason (e.g. a download
        # failure or a weight-loading error) stays attached as __cause__.
        print(f"Error loading model {model_name}: {e}")
        raise ValueError(f"Model {model_name} not found") from e
    return model
# Model selection function
def auto_model_selection(context):
    """
    Select the pretrained DynaMix checkpoint to use for forecasting.

    The choice depends only on the number of observed dimensions, read from
    ``context.shape[1]``:

    * 1 dimension       -> ``"dynamix-6d-alrnn-v1.0"``
    * 2 or 3 dimensions -> ``"dynamix-3d-alrnn-v1.0"``
    * 4 or more         -> ``"dynamix-6d-alrnn-v1.0"``

    Args:
        context: Array-like with at least two axes; the size of axis 1 is
            taken as the number of dimensions.

    Returns:
        The checkpoint name as a string.

    Raises:
        ValueError: If ``context`` has no columns (``shape[1] < 1``).
    """
    n_dims = context.shape[1]
    if n_dims < 1:
        raise ValueError(
            f"context must have at least one dimension, got shape {context.shape}"
        )
    if 2 <= n_dims <= 3:
        return "dynamix-3d-alrnn-v1.0"
    # Every other width (1, and anything above 3) goes to the 6-D checkpoint.
    # The old ">= 6" test left 4 and 5 dims matching no branch, so they
    # returned None and only failed later inside model loading.
    # NOTE(review): confirm 4-5 dim inputs are intended for the 6-D model.
    return "dynamix-6d-alrnn-v1.0"
| """ | |
| Plotting functions | |
| """ | |
def create_forecast_plot(values, reconstruction_ts_np, horizon):
    """
    Build a stacked Plotly figure of context and forecast, one row per dimension.

    Row ``d`` draws the context series ``values[:, d]`` in blue over the time
    steps ``-len(values) .. -1`` and the forecast ``reconstruction_ts_np[:, d]``
    in red over ``0 .. int(horizon) - 1``. The number of rows is the size of
    the last axis of ``reconstruction_ts_np``, capped at 15. The figure uses a
    dark background, white text and faint white grid/zero lines, with no
    legend and no title.

    Args:
        values: Context series, indexed as ``values[:, d]`` with time along axis 0.
        reconstruction_ts_np: Forecast series, indexed as
            ``reconstruction_ts_np[:, d]``; its last axis sets the row count.
        horizon: Forecast length used to build the forecast time axis (cast to ``int``).

    Returns:
        The assembled ``plotly.graph_objects.Figure``.
    """
    n_rows = min(reconstruction_ts_np.shape[-1], 15)  # never draw more than 15 rows

    x_context = np.arange(-len(values), 0)
    x_forecast = np.arange(0, int(horizon))

    # Taller stacks need tighter gaps; n_rows is at most 15 at this point.
    if n_rows <= 3:
        gap = 0.1
    elif n_rows <= 6:
        gap = 0.05
    else:
        gap = 0.02

    fig = sp.make_subplots(rows=n_rows, cols=1, vertical_spacing=gap)

    def _line(x, y, colour, label):
        # One solid line trace whose hover box shows its label plus x and y.
        return go.Scatter(
            x=x,
            y=y,
            mode='lines',
            line=dict(color=colour, width=2.5),
            name=label,
            showlegend=False,
            hovertemplate=f"{label}<br>x: %{{x}}<br>y: %{{y}}<extra></extra>",
        )

    for row, dim in enumerate(range(n_rows), start=1):
        past = _line(x_context, values[:, dim], '#4169E1', f"context_{dim+1}")
        future = _line(x_forecast, reconstruction_ts_np[:, dim], '#FF4242', f"forecast_{dim+1}")
        fig.add_trace(past, row=row, col=1)
        fig.add_trace(future, row=row, col=1)

    dark = '#1f2937'
    fig.update_layout(
        plot_bgcolor=dark,
        paper_bgcolor=dark,
        font=dict(color='white'),
        showlegend=False,
        title=None,
        margin=dict(l=50, r=50, t=30, b=50),
        height=300 if n_rows == 1 else 250 * n_rows,
        width=None,
    )

    # Same faint grid on every subplot axis; with a single column, updating
    # all axes at once touches exactly the per-row axes.
    grid = dict(
        gridcolor='rgba(255, 255, 255, 0.2)',
        zerolinecolor='rgba(255, 255, 255, 0.2)',
        showgrid=True,
    )
    fig.update_xaxes(**grid)
    fig.update_yaxes(**grid)
    return fig