import json
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from dynamix.dynamix import DynaMix
import plotly.graph_objects as go
import plotly.subplots as sp
import numpy as np
import base64
import zlib
import struct


# Loading models from HuggingFace Hub

# Hugging Face Hub repository that hosts both the DynaMix configs and weights.
_HF_REPO_ID = "DurstewitzLab/dynamix"


def load_hf_model_config(model_name):
    """Load a DynaMix model configuration from the HuggingFace Hub.

    The config file name is derived from the model name by removing
    ``"dynamix-"``, e.g. ``"dynamix-3d-alrnn-v1.0"`` ->
    ``"config_3d-alrnn-v1.0.json"``.

    Args:
        model_name: Model identifier such as ``"dynamix-3d-alrnn-v1.0"``.

    Returns:
        The parsed JSON configuration (``load_hf_model`` expects a dict
        with an ``"architecture"`` entry).

    Raises:
        Any exception raised by ``hf_hub_download`` when the file cannot be
        fetched, and ``json.JSONDecodeError`` if the file is not valid JSON.
    """
    config_path = hf_hub_download(
        repo_id=_HF_REPO_ID,
        filename="config_" + model_name.replace("dynamix-", "") + ".json",
    )
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)


def load_hf_model(model_name):
    """Build a DynaMix model from its Hub config and load its weights.

    Args:
        model_name: Model identifier such as ``"dynamix-3d-alrnn-v1.0"``;
            the weights are fetched from ``<model_name>.safetensors``.

    Returns:
        The ``DynaMix`` instance with its weights loaded, put in eval mode.

    Raises:
        ValueError: if fetching/parsing the config, constructing the model,
            downloading the weights, or loading the state dict fails. The
            original exception is chained as ``__cause__``.
    """
    try:
        architecture = load_hf_model_config(model_name)["architecture"]

        # All constructor hyperparameters come straight from the config.
        model = DynaMix(
            M=architecture["M"],  # latent state dimension
            N=architecture["N"],  # observation space dimension
            Experts=architecture["Experts"],  # number of experts
            expert_type=architecture["expert_type"],
            P=architecture["P"],  # number of ReLU dimensions
            hidden_dim=architecture["hidden_dim"],
            probabilistic_expert=architecture["probabilistic_expert"],
        )

        weights_path = hf_hub_download(
            repo_id=_HF_REPO_ID,
            filename=model_name + ".safetensors",
        )
        model.load_state_dict(load_file(weights_path))
        model.eval()
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        # Keep raising the ValueError callers already handle, but chain the
        # real cause so e.g. a state-dict mismatch is not hidden behind it.
        raise ValueError(
            f"Model {model_name} not found or could not be loaded"
        ) from e
    return model


# Model selection function
def auto_model_selection(context):
    """Select the pretrained DynaMix model to use for forecasting *context*.

    Two- and three-dimensional series use the 3-D model; univariate series
    and series with more than three dimensions use the 6-D model.

    Args:
        context: Array-like with a ``shape`` attribute whose ``shape[1]`` is
            the number of dimensions (numpy array or torch tensor). A 1-D
            array is treated as a single dimension.

    Returns:
        The model name to pass to ``load_hf_model``.

    Raises:
        ValueError: if *context* has no dimensions (zero columns or a scalar).
    """
    shape = context.shape
    # 2-D+ -> number of columns; 1-D -> one series; 0-D (scalar) -> none.
    n_dims = shape[1] if len(shape) >= 2 else len(shape)
    if n_dims < 1:
        raise ValueError("context must have at least one dimension")
    if 2 <= n_dims <= 3:
        return "dynamix-3d-alrnn-v1.0"
    # Univariate series and every dimensionality above 3 (4, 5, 6, ...)
    # use the 6-D model, so no input falls through and returns None.
    return "dynamix-6d-alrnn-v1.0"


# NOTE(review): the line below appears corrupted in the reviewed copy:
# print_logs breaks off inside struct.pack(", and the text that continues on
# the following lines belongs to a plotting function whose beginning is
# missing. It is left exactly as found rather than guessed at -- restore
# this section from version control.
# Logging forecast def print_logs(current_time, data_name, context, forecast, groups=32, window=4): ts = np.concatenate([np.asarray(context), np.asarray(forecast)], 0) n, D = (ts.shape[0] // window) * window, ts.shape[1] ds = ts[:n].reshape(n // window, window, D).mean(1) sp = np.clip(ds.max(0) - ds.min(0), 1e-12, None) q = np.clip(np.floor((ds - ds.min(0)) / sp * groups), 0, groups - 1).astype(np.uint8) blob = b"DMX1" + struct.pack("x: %{{x}}
y: %{{y}}" ) # Forecast forecast_trace = go.Scatter( x=forecast_time, y=reconstruction_ts_np[:, d], mode='lines', line=dict(color='#FF4242', width=2.5), name=f"forecast_{d+1}", showlegend=False, hovertemplate=f"forecast_{d+1}
x: %{{x}}
y: %{{y}}" ) fig.add_trace(historical_trace, row=d+1, col=1) fig.add_trace(forecast_trace, row=d+1, col=1) fig.update_layout( plot_bgcolor='#1f2937', paper_bgcolor='#1f2937', font=dict(color='white'), showlegend=False, title=None, margin=dict(l=50, r=50, t=30, b=50), xaxis=dict( gridcolor='rgba(255, 255, 255, 0.2)', zerolinecolor='rgba(255, 255, 255, 0.2)', showgrid=True ), yaxis=dict( gridcolor='rgba(255, 255, 255, 0.2)', zerolinecolor='rgba(255, 255, 255, 0.2)', showgrid=True, ), height=300 if plot_dims == 1 else 250 * plot_dims, width=None ) for i in range(plot_dims): fig.update_xaxes( gridcolor='rgba(255, 255, 255, 0.2)', zerolinecolor='rgba(255, 255, 255, 0.2)', showgrid=True, row=i+1, col=1 ) fig.update_yaxes( gridcolor='rgba(255, 255, 255, 0.2)', zerolinecolor='rgba(255, 255, 255, 0.2)', showgrid=True, row=i+1, col=1 ) return fig