###########################################################################################################################################
#||||- - - |8.19.2025| - - - || LIQUID BAYES || - - - |1990two| - - -|||| #
###########################################################################################################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
SAFE_MIN = -1e6
SAFE_MAX = 1e6
EPS = 1e-8
#||||- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 𓅸 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -||||#
def make_safe(tensor, min_val=SAFE_MIN, max_val=SAFE_MAX):
    # Map NaN -> 0 and +/-inf to the matching finite bound, then clamp.
    tensor = torch.nan_to_num(tensor, nan=0.0, posinf=max_val, neginf=min_val)
    return torch.clamp(tensor, min_val, max_val)
def safe_softmax(x, dim=-1, temperature=1.0):
x = x.to(dtype=torch.float32)
x = make_safe(x, min_val=-50, max_val=50)
if isinstance(temperature, torch.Tensor):
temperature = float(temperature.detach().cpu().item())
temperature = max(float(temperature), EPS)
x = x / temperature
x = x - x.amax(dim=dim, keepdim=True)
return F.softmax(x, dim=dim)
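# Minimal sketch of the numeric guards above on pathological inputs; the helper
# name and the sample values are illustrative, not part of the original API.
def _demo_safe_softmax():
    x = torch.tensor([[1.0, float('nan'), float('inf'), -float('inf')]])
    print(make_safe(x))  # NaN -> 0, +/-inf -> the finite bounds, then clamped
    probs = safe_softmax(x, dim=-1, temperature=0.5)
    print(probs, probs.sum(dim=-1))  # finite, each row sums to 1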
###########################################################################################################################################
#################################################- - - LIQUID DYNAMICS CORE - - -######################################################
class LiquidDynamicsCore(nn.Module):
def __init__(self, state_dim, input_dim, liquid_time_constant=1.0):
super().__init__()
self.state_dim = state_dim
self.input_dim = input_dim
self.liquid_time_constant = nn.Parameter(torch.tensor(liquid_time_constant))
self.W_rec = nn.Parameter(torch.randn(state_dim, state_dim) * 0.1) # Recurrent weights
self.W_in = nn.Parameter(torch.randn(state_dim, input_dim) * 0.1) # Input weights
self.bias = nn.Parameter(torch.zeros(state_dim))
self.activation = nn.Tanh()
self.register_buffer('liquid_state', torch.zeros(1, state_dim))
self.noise_scale = nn.Parameter(torch.tensor(0.1))
self.exploration_rate = nn.Parameter(torch.tensor(0.05))
    def reset_state(self, batch_size=1):
        # Allocate a fresh zero state rather than zeroing in place, so a tensor
        # still referenced by a previous forward pass's graph is never mutated.
        self.liquid_state = torch.zeros(
            batch_size, self.state_dim,
            device=self.liquid_state.device,
            dtype=self.liquid_state.dtype,
        )
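    # evolve_liquid: one explicit-Euler step of the liquid ODE
    #   x <- x + dt * (-x / tau + tanh(x) @ W_rec.T + u @ W_in.T + bias),
    # gated by a confidence weight; low confidence shifts weight toward
    # exploration noise instead of the deterministic dynamics.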
def evolve_liquid(self, input_signal, confidence_weight=1.0, dt=0.1):
batch_size = input_signal.shape[0]
if self.liquid_state.shape[0] != batch_size:
self.reset_state(batch_size)
tau = torch.clamp(self.liquid_time_constant, 0.1, 10.0)
recurrent_input = torch.matmul(self.activation(self.liquid_state), self.W_rec.T)
external_input = torch.matmul(input_signal, self.W_in.T)
dynamics = (-self.liquid_state / tau + recurrent_input + external_input + self.bias)
if isinstance(confidence_weight, torch.Tensor):
if confidence_weight.dim() == 1:
confidence_weight = confidence_weight.unsqueeze(-1)
confidence_weight = confidence_weight.to(self.liquid_state.dtype)
else:
confidence_weight = torch.tensor(confidence_weight, device=self.liquid_state.device, dtype=self.liquid_state.dtype)
exploration_noise = torch.randn_like(self.liquid_state) * self.noise_scale
exploration_strength = (1.0 - confidence_weight) * self.exploration_rate
modulated_dynamics = confidence_weight * dynamics + exploration_strength * exploration_noise
        # Out-of-place Euler update keeps the differentiable state free of
        # in-place version-counter issues when backpropagating through steps.
        self.liquid_state = self.liquid_state + dt * make_safe(modulated_dynamics)
        return self.liquid_state.clone()
def get_liquid_features(self):
return {
'raw_state': self.liquid_state.clone(),
'activated_state': self.activation(self.liquid_state),
'state_energy': torch.sum(self.liquid_state ** 2, dim=-1, keepdim=True),
'state_entropy': self._compute_state_entropy()
}
def _compute_state_entropy(self):
state_probs = safe_softmax(self.liquid_state, dim=-1, temperature=1.0)
entropy = -torch.sum(state_probs * torch.log(state_probs + EPS), dim=-1, keepdim=True)
return entropy
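# Minimal sketch of driving the liquid core; the dimensions and the helper
# name are assumptions for illustration.
def _demo_liquid_core():
    core = LiquidDynamicsCore(state_dim=16, input_dim=8)
    core.reset_state(batch_size=4)
    u = torch.randn(4, 8)
    for _ in range(5):
        state = core.evolve_liquid(u, confidence_weight=0.7, dt=0.1)
    feats = core.get_liquid_features()
    print(state.shape, feats['state_energy'].shape, feats['state_entropy'].shape)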
###########################################################################################################################################
############################################- - - BAYESIAN CONFIDENCE NETWORK - - -####################################################
class BayesianConfidenceNetwork(nn.Module):
def __init__(self, state_dim, num_variables=5, num_states_per_var=3):
super().__init__()
self.state_dim = state_dim
self.num_variables = num_variables
self.num_states_per_var = num_states_per_var
self.feature_extractor = nn.Sequential(
nn.Linear(state_dim, state_dim * 2),
nn.LayerNorm(state_dim * 2),
nn.ReLU(),
nn.Linear(state_dim * 2, num_variables * num_states_per_var)
)
self.conditional_prob_tables = nn.ParameterList([
nn.Parameter(torch.randn(num_states_per_var, num_states_per_var * (num_variables - 1)) * 0.1)
for _ in range(num_variables)
])
self.priors = nn.Parameter(torch.ones(num_variables, num_states_per_var))
self.confidence_net = nn.Sequential(
nn.Linear(num_variables, num_variables * 2),
nn.ReLU(),
nn.Linear(num_variables * 2, 1),
nn.Sigmoid()
)
self.uncertainty_estimator = nn.Sequential(
nn.Linear(state_dim, state_dim),
nn.ReLU(),
nn.Linear(state_dim, 1),
nn.Sigmoid()
)
def extract_variable_beliefs(self, liquid_features):
liquid_state = liquid_features['activated_state']
evidence = self.feature_extractor(liquid_state)
evidence = evidence.view(-1, self.num_variables, self.num_states_per_var)
variable_beliefs = safe_softmax(evidence, dim=-1)
return variable_beliefs
def bayesian_inference(self, variable_beliefs):
batch_size = variable_beliefs.shape[0]
device = variable_beliefs.device
current_beliefs = safe_softmax(self.priors.unsqueeze(0).expand(batch_size, -1, -1), dim=-1)
        for _ in range(3):  # a few fixed-point iterations, kept small for efficiency
new_beliefs = current_beliefs.clone()
for var_idx in range(self.num_variables):
evidence = variable_beliefs[:, var_idx, :]
if self.num_variables > 1:
other_var_beliefs = torch.cat([
current_beliefs[:, :var_idx].flatten(1),
current_beliefs[:, var_idx+1:].flatten(1)
], dim=1)
else:
other_var_beliefs = torch.zeros(batch_size, 0, device=device)
if other_var_beliefs.shape[1] > 0:
cond_probs = torch.matmul(other_var_beliefs, self.conditional_prob_tables[var_idx].T)
cond_probs = safe_softmax(cond_probs, dim=-1)
else:
cond_probs = torch.ones_like(evidence) / self.num_states_per_var
combined = evidence * cond_probs
new_beliefs[:, var_idx, :] = safe_softmax(combined, dim=-1)
current_beliefs = new_beliefs
return current_beliefs
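    # The loop above is a small fixed-point refinement in the spirit of loopy
    # belief propagation: each variable's evidence is multiplied by conditional
    # probabilities predicted from the other variables' current beliefs, then
    # renormalized, for three rounds.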
def compute_confidence(self, beliefs, liquid_features):
belief_entropy = -torch.sum(beliefs * torch.log(beliefs + EPS), dim=-1)
avg_entropy = belief_entropy.mean(dim=-1, keepdim=True)
max_entropy = math.log(self.num_states_per_var)
entropy_confidence = 1.0 - (avg_entropy / max_entropy)
nn_confidence = self.confidence_net(belief_entropy)
liquid_uncertainty = self.uncertainty_estimator(liquid_features['raw_state'])
state_confidence = 1.0 - liquid_uncertainty
total_confidence = 0.4 * entropy_confidence + 0.3 * nn_confidence + 0.3 * state_confidence
return torch.clamp(total_confidence, 0.0, 1.0)
def forward(self, liquid_features):
variable_beliefs = self.extract_variable_beliefs(liquid_features)
posterior_beliefs = self.bayesian_inference(variable_beliefs)
confidence = self.compute_confidence(posterior_beliefs, liquid_features)
return {
'beliefs': posterior_beliefs,
'confidence': confidence,
'variable_beliefs': variable_beliefs
}
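# Minimal sketch wiring liquid features into the confidence network; the
# dimensions and the helper name are illustrative assumptions.
def _demo_confidence_network():
    core = LiquidDynamicsCore(state_dim=16, input_dim=8)
    core.reset_state(batch_size=2)
    core.evolve_liquid(torch.randn(2, 8))
    net = BayesianConfidenceNetwork(state_dim=16)
    out = net(core.get_liquid_features())
    print(out['beliefs'].shape, out['confidence'])  # (2, 5, 3) and (2, 1) in [0, 1]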
###########################################################################################################################################
############################################- - - LIQUID BAYES CHAIN - - -############################################################
class LiquidBayesChain(nn.Module):
def __init__(self, input_dim, state_dim, output_dim, num_chain_steps=3):
super().__init__()
self.input_dim = input_dim
self.state_dim = state_dim
self.output_dim = output_dim
self.num_chain_steps = num_chain_steps
self.liquid_core = LiquidDynamicsCore(state_dim, input_dim)
self.bayesian_confidence = BayesianConfidenceNetwork(state_dim)
self.final_predictor = nn.Sequential(
nn.Linear(state_dim, state_dim * 2),
nn.LayerNorm(state_dim * 2),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(state_dim * 2, output_dim)
)
self.final_bayesian = BayesianConfidenceNetwork(output_dim, num_variables=3, num_states_per_var=4)
self.step_weights = nn.Parameter(torch.ones(num_chain_steps))
def single_chain_step(self, input_signal, step_idx=0):
if step_idx == 0:
liquid_state = self.liquid_core.evolve_liquid(input_signal, confidence_weight=1.0)
else:
liquid_features = self.liquid_core.get_liquid_features()
bayes_output = self.bayesian_confidence(liquid_features)
confidence = bayes_output['confidence']
liquid_state = self.liquid_core.evolve_liquid(input_signal, confidence_weight=confidence)
liquid_features = self.liquid_core.get_liquid_features()
bayes_output = self.bayesian_confidence(liquid_features)
return {
'liquid_state': liquid_state,
'liquid_features': liquid_features,
'bayes_output': bayes_output,
'confidence': bayes_output['confidence']
}
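    # Step 0 evolves at full confidence to seed the liquid state; later steps
    # gate the dynamics by the Bayesian confidence estimated from the state
    # left by the previous step.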
def forward(self, input_signal, return_chain_states=False):
batch_size = input_signal.shape[0]
self.liquid_core.reset_state(batch_size)
chain_states = []
for step in range(self.num_chain_steps):
step_output = self.single_chain_step(input_signal, step_idx=step)
step_output['step_idx'] = step
chain_states.append(step_output)
final_liquid_state = chain_states[-1]['liquid_features']['activated_state']
prediction_logits = self.final_predictor(final_liquid_state)
prediction_features = {
'raw_state': prediction_logits,
'activated_state': torch.tanh(prediction_logits)
}
final_bayes = self.final_bayesian(prediction_features)
step_weights = safe_softmax(self.step_weights, dim=0)
weighted_confidence = sum(
step_weights[i] * chain_states[i]['confidence']
for i in range(self.num_chain_steps)
)
output = {
'prediction': prediction_logits,
'final_confidence': weighted_confidence,
'final_beliefs': final_bayes['beliefs'],
'prediction_uncertainty': 1.0 - final_bayes['confidence']
}
if return_chain_states:
output['chain_states'] = chain_states
return output
def predict_with_uncertainty(self, input_signal):
output = self.forward(input_signal, return_chain_states=True)
uncertainty_info = {
'prediction': output['prediction'],
'confidence': output['final_confidence'],
'prediction_uncertainty': output['prediction_uncertainty'],
'chain_confidences': [state['confidence'] for state in output['chain_states']],
'liquid_entropies': [state['liquid_features']['state_entropy'] for state in output['chain_states']]
}
return uncertainty_info
###########################################################################################################################################
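# Minimal smoke test under assumed dimensions, runnable as a script.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = LiquidBayesChain(input_dim=8, state_dim=16, output_dim=4)
    model.eval()  # disable dropout in the final predictor for the demo
    x = torch.randn(2, 8)
    info = model.predict_with_uncertainty(x)
    print('prediction:', info['prediction'].shape)        # (2, 4)
    print('confidence:', info['confidence'].squeeze(-1))  # per-sample, in [0, 1]
    print('chain confidences:', [c.mean().item() for c in info['chain_confidences']])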