Spaces:
Running
Running
File size: 5,203 Bytes
776877d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
import json
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from dynamix.dynamix import DynaMix
import plotly.graph_objects as go
import plotly.subplots as sp
import numpy as np
"""
Loading models from HuggingFace Hub
"""
def load_hf_model_config(model_name):
    """Download and parse the JSON configuration of a DynaMix model.

    The file is fetched from the ``DurstewitzLab/dynamix`` repository via
    ``hf_hub_download``. Its name is built from ``model_name`` by removing
    every ``"dynamix-"`` substring, e.g. ``"dynamix-3d-alrnn-v1.0"`` becomes
    ``"config_3d-alrnn-v1.0.json"``.

    Args:
        model_name: Model identifier, e.g. ``"dynamix-3d-alrnn-v1.0"``.

    Returns:
        The parsed JSON document, exactly as produced by ``json.load``.

    Raises:
        Whatever ``hf_hub_download``, ``open`` or ``json.load`` raise; no
        exception is handled here.
    """
    config_filename = "config_" + model_name.replace("dynamix-", "") + ".json"
    config_path = hf_hub_download(
        repo_id="DurstewitzLab/dynamix",
        filename=config_filename,
    )
    # JSON is exchanged as UTF-8; do not depend on the platform's locale
    # encoding when decoding the downloaded file.
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)
def load_hf_model(model_name):
    """Build a DynaMix model from its Hub configuration and load its weights.

    Steps:
      1. Load the model configuration with ``load_hf_model_config``.
      2. Instantiate ``DynaMix`` from the ``"architecture"`` section.
      3. Download ``<model_name>.safetensors`` from ``DurstewitzLab/dynamix``
         and load it into the model with ``load_state_dict``.
      4. Call ``model.eval()`` before returning.

    Args:
        model_name: Model identifier, e.g. ``"dynamix-3d-alrnn-v1.0"``.

    Returns:
        The ``DynaMix`` instance with weights loaded.

    Raises:
        ValueError: If any step above fails for any reason. The underlying
            exception is attached as ``__cause__`` so the real failure
            (missing config key, download error, state-dict mismatch, ...)
            stays visible in the traceback.
    """
    try:
        # Load model configuration
        model_config = load_hf_model_config(model_name)
        architecture = model_config["architecture"]

        # Create model with the hyperparameters from the config
        model = DynaMix(
            M=architecture["M"],  # Latent state dimension
            N=architecture["N"],  # Observation space dimension
            Experts=architecture["Experts"],  # Number of experts
            expert_type=architecture["expert_type"],
            P=architecture["P"],  # Number of ReLU dimensions
            hidden_dim=architecture["hidden_dim"],
            probabilistic_expert=architecture["probabilistic_expert"],
        )

        # Load model weights
        model_path = hf_hub_download(
            repo_id="DurstewitzLab/dynamix",
            filename=model_name + ".safetensors",
        )
        model.load_state_dict(load_file(model_path))
        model.eval()
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        # Keep the ValueError contract callers rely on, but chain the cause
        # instead of hiding it behind a generic "not found".
        raise ValueError(f"Model {model_name} not found") from e
    return model
# Model selection function
def auto_model_selection(context):
    """Select the name of the model to use for forecasting.

    Only ``context.shape[1]`` is inspected:

    * 1 column       -> ``"dynamix-6d-alrnn-v1.0"``
    * 2 or 3 columns -> ``"dynamix-3d-alrnn-v1.0"``
    * 4 or more      -> ``"dynamix-6d-alrnn-v1.0"``

    Args:
        context: Object exposing a ``shape`` whose second entry is the
            number of columns of the context.

    Returns:
        The model name as a string.

    Raises:
        ValueError: If ``context.shape[1]`` is smaller than 1.
    """
    n_dims = context.shape[1]
    if n_dims < 1:
        raise ValueError(
            f"Cannot select a model for a context with {n_dims} dimensions"
        )
    if 2 <= n_dims <= 3:
        return "dynamix-3d-alrnn-v1.0"
    # NOTE(review): 4 and 5 columns previously matched no branch and returned
    # None. They are routed to the 6d model here because that model is
    # already selected for 1-column input -- confirm this is the intent.
    return "dynamix-6d-alrnn-v1.0"
"""
Plotting functions
"""
def create_forecast_plot(values, reconstruction_ts_np, horizon):
    """Create a dark-themed Plotly figure of context and forecast, one row per dimension.

    One subplot row is created per dimension, capped at 15. In each row the
    context is drawn in blue on the time steps ``-len(values) .. -1`` and the
    forecast in red on the time steps ``0 .. int(horizon) - 1``.

    Args:
        values: Context data, indexed as ``values[:, d]``; ``len(values)``
            determines the context time axis.
        reconstruction_ts_np: Forecast data, indexed as
            ``reconstruction_ts_np[:, d]``; its last axis gives the number
            of dimensions.
        horizon: Number of forecast steps used for the forecast time axis
            (converted with ``int``).

    Returns:
        The Plotly figure produced by ``make_subplots`` with all traces and
        styling applied.

    NOTE(review): ``values`` is indexed with the dimension count taken from
    ``reconstruction_ts_np``, so it presumably needs at least as many columns
    as are plotted -- confirm against the caller.
    """
    dims = reconstruction_ts_np.shape[-1]
    plot_dims = min(dims, 15)  # plot up to 15 dimensions

    context_time = np.arange(-len(values), 0)
    forecast_time = np.arange(0, int(horizon))

    # Tighter spacing as the number of stacked subplots grows.
    if plot_dims <= 3:
        vertical_spacing = 0.1
    elif plot_dims <= 6:
        vertical_spacing = 0.05
    else:  # plot_dims is capped at 15 above
        vertical_spacing = 0.02

    fig = sp.make_subplots(
        rows=plot_dims,
        cols=1,
        vertical_spacing=vertical_spacing,
    )

    # Add the context trace, then the forecast trace, for each dimension.
    for d in range(plot_dims):
        row = d + 1
        series = (
            (f"context_{row}", context_time, values[:, d], '#4169E1'),
            (f"forecast_{row}", forecast_time, reconstruction_ts_np[:, d], '#FF4242'),
        )
        for label, x_vals, y_vals, color in series:
            fig.add_trace(
                go.Scatter(
                    x=x_vals,
                    y=y_vals,
                    mode='lines',
                    line=dict(color=color, width=2.5),
                    name=label,
                    showlegend=False,
                    hovertemplate=f"{label}<br>x: %{{x}}<br>y: %{{y}}<extra></extra>",
                ),
                row=row,
                col=1,
            )

    fig.update_layout(
        plot_bgcolor='#1f2937',
        paper_bgcolor='#1f2937',
        font=dict(color='white'),
        showlegend=False,
        title=None,
        margin=dict(l=50, r=50, t=30, b=50),
        height=300 if plot_dims == 1 else 250 * plot_dims,
        width=None,
    )

    # One shared grid style for every subplot; without row/col arguments the
    # update is applied to all x and y axes of the figure.
    axis_style = dict(
        gridcolor='rgba(255, 255, 255, 0.2)',
        zerolinecolor='rgba(255, 255, 255, 0.2)',
        showgrid=True,
    )
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)
    return fig