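# DynaMix / dynamix/utilities.py
# Hugging Face Hub page header (kept for provenance, not code):
#   uploaded by: Dschobby ("Upload 14 files")
#   commit: 776877d (verified)
#   page links: raw | history | blame
#   file size: 5.2 kB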
import json
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from dynamix.dynamix import DynaMix
import plotly.graph_objects as go
import plotly.subplots as sp
import numpy as np
"""
Loading models from HuggingFace Hub
"""
def load_hf_model_config(model_name):
    """Load a DynaMix model configuration from the HuggingFace Hub.

    Downloads ``config_<suffix>.json`` from the ``DurstewitzLab/dynamix``
    repository, where ``<suffix>`` is ``model_name`` with a leading
    ``"dynamix-"`` prefix removed (e.g. ``"dynamix-3d-alrnn-v1.0"`` ->
    ``"config_3d-alrnn-v1.0.json"``).

    Args:
        model_name: Model identifier, e.g. ``"dynamix-3d-alrnn-v1.0"``.

    Returns:
        The parsed JSON configuration, as returned by ``json.load``.
    """
    prefix = "dynamix-"
    # Strip only a *leading* prefix; str.replace would also remove any later
    # occurrence of the substring inside the name.
    if model_name.startswith(prefix):
        suffix = model_name[len(prefix):]
    else:
        suffix = model_name
    config_path = hf_hub_download(
        repo_id="DurstewitzLab/dynamix",
        filename="config_" + suffix + ".json",
    )
    # JSON is UTF-8; don't depend on the platform's locale default encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)
def load_hf_model(model_name):
    """Load a specific DynaMix model with its configuration.

    The architecture hyperparameters are read from the model's config file
    (via ``load_hf_model_config``) and used to build a ``DynaMix`` instance.
    The ``<model_name>.safetensors`` weights are then downloaded from the
    ``DurstewitzLab/dynamix`` repository and loaded into it. The model is
    switched to eval mode before it is returned.

    Args:
        model_name: Model identifier, e.g. ``"dynamix-3d-alrnn-v1.0"``.

    Returns:
        The loaded ``DynaMix`` model, in eval mode.

    Raises:
        ValueError: If any step fails. This covers the download, config
            parsing, model construction and weight loading. The message
            says "not found" for every failure. The original exception is
            chained as ``__cause__`` so the real reason is kept.
    """
    try:
        # Load model configuration
        model_config = load_hf_model_config(model_name)
        architecture = model_config["architecture"]

        # Extract hyperparameters from config
        M = architecture["M"]  # Latent state dimension
        N = architecture["N"]  # Observation space dimension
        EXPERTS = architecture["Experts"]  # Number of experts
        P = architecture["P"]  # Number of ReLU dimensions
        HIDDEN_DIM = architecture["hidden_dim"]
        expert_type = architecture["expert_type"]
        probabilistic_expert = architecture["probabilistic_expert"]

        # Create model with config parameters
        model = DynaMix(
            M=M,
            N=N,
            Experts=EXPERTS,
            expert_type=expert_type,
            P=P,
            hidden_dim=HIDDEN_DIM,
            probabilistic_expert=probabilistic_expert,
        )

        # Load model weights
        model_path = hf_hub_download(
            repo_id="DurstewitzLab/dynamix",
            filename=model_name + ".safetensors",
        )
        model_state_dict = load_file(model_path)
        model.load_state_dict(model_state_dict)
        model.eval()
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        # Chain the original error. Without `from e` the real cause (network
        # failure, missing config key, weight-shape mismatch, ...) is hidden
        # behind the generic "not found" message.
        raise ValueError(f"Model {model_name} not found") from e
    return model
# Model selection function
def auto_model_selection(context):
    """
    Select the model to use for forecasting.

    The choice depends only on ``context.shape[1]``, which is treated as the
    number of observed dimensions. NOTE(review): this presumably assumes a
    ``(time, dims)`` layout; confirm for batched inputs.

    Args:
        context: Context time series exposing a ``.shape`` attribute.

    Returns:
        ``"dynamix-3d-alrnn-v1.0"`` for 2 or 3 dimensions, otherwise
        ``"dynamix-6d-alrnn-v1.0"``.

    Raises:
        ValueError: If ``context.shape[1] < 1`` (no dimensions to forecast).
    """
    n_dims = context.shape[1]
    if n_dims < 1:
        raise ValueError(
            f"Context must have at least one dimension, got shape {tuple(context.shape)}"
        )
    if 2 <= n_dims <= 3:
        return "dynamix-3d-alrnn-v1.0"
    # The 6-D model is already used for 1-D (fewer than 6) and >=6-D
    # contexts. 4-D and 5-D contexts are routed to it too, instead of
    # falling through and returning None.
    # NOTE(review): confirm the 6-D model handles 4-5 dims like it does 1-D.
    return "dynamix-6d-alrnn-v1.0"
"""
Plotting functions
"""
def create_forecast_plot(values, reconstruction_ts_np, horizon):
    """
    Build a dark-themed Plotly figure of context vs. forecast, one row per dimension.

    At most 15 dimensions are drawn (one subplot row each):
    - The context ``values[:, d]`` is drawn in blue on the time axis
      ``-len(values) .. -1``.
    - The forecast ``reconstruction_ts_np[:, d]`` is drawn in red on
      ``0 .. int(horizon) - 1``.

    Args:
        values: Context series, indexed as ``values[:, d]``.
        reconstruction_ts_np: Forecast array. Its last axis gives the number
            of dimensions, and it is indexed as ``reconstruction_ts_np[:, d]``.
        horizon: Forecast length used to build the forecast time axis.

    Returns:
        The assembled Plotly figure.
    """
    n_rows = min(reconstruction_ts_np.shape[-1], 15)  # cap the number of subplots

    past_axis = np.arange(-len(values), 0)
    future_axis = np.arange(0, int(horizon))

    # Shrink the gap between rows as the row count grows (n_rows <= 15 always).
    if n_rows <= 3:
        spacing = 0.1
    elif n_rows <= 6:
        spacing = 0.05
    else:
        spacing = 0.02

    fig = sp.make_subplots(rows=n_rows, cols=1, vertical_spacing=spacing)

    # (label prefix, time axis, data source, line colour) in drawing order.
    series = (
        ("context", past_axis, values, '#4169E1'),
        ("forecast", future_axis, reconstruction_ts_np, '#FF4242'),
    )
    for row in range(1, n_rows + 1):
        col_idx = row - 1
        for prefix, time_axis, source, colour in series:
            label = f"{prefix}_{row}"
            fig.add_trace(
                go.Scatter(
                    x=time_axis,
                    y=source[:, col_idx],
                    mode='lines',
                    line=dict(color=colour, width=2.5),
                    name=label,
                    showlegend=False,
                    hovertemplate=f"{label}<br>x: %{{x}}<br>y: %{{y}}<extra></extra>",
                ),
                row=row,
                col=1,
            )

    grid_style = dict(
        gridcolor='rgba(255, 255, 255, 0.2)',
        zerolinecolor='rgba(255, 255, 255, 0.2)',
        showgrid=True,
    )
    fig.update_layout(
        plot_bgcolor='#1f2937',
        paper_bgcolor='#1f2937',
        font=dict(color='white'),
        showlegend=False,
        title=None,
        margin=dict(l=50, r=50, t=30, b=50),
        xaxis=dict(grid_style),
        yaxis=dict(grid_style),
        height=300 if n_rows == 1 else 250 * n_rows,
        width=None,
    )
    # Without row/col selectors these apply to every subplot's axes at once.
    fig.update_xaxes(**grid_style)
    fig.update_yaxes(**grid_style)
    return fig