# mobius_markov / mobius_markov.py
# NOTE: the lines below are Hugging Face page residue (author 1990two,
# commit b8eb72d "Update mobius_markov.py", verified) converted to
# comments so the file remains valid Python.
###########################################################################################################################################
#||||- - - |8.19.2025| - - - || MÖBIUS MARKOV || - - - |1990two| - - -|||| #
###########################################################################################################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple, Optional
# Numeric guard-rails shared by every module below: finite values are clamped
# into [SAFE_MIN, SAFE_MAX] and divisions are bounded away from zero by EPS.
SAFE_MIN = -1e6
SAFE_MAX = 1e6
EPS = 1e-8
#||||- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 𓅸 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -||||#
def make_safe(tensor, min_val=SAFE_MIN, max_val=SAFE_MAX):
    """Sanitize *tensor*: NaN -> 0, +inf -> max_val, -inf -> min_val, then clamp.

    Fix: the previous implementation replaced *every* infinity (including
    -inf) with max_val, silently flipping the sign of negative overflows
    before the final clamp. torch.nan_to_num maps each infinity to the
    matching bound in a single pass.

    Args:
        tensor: any float tensor.
        min_val, max_val: inclusive bounds of the safe range.

    Returns:
        A tensor with no NaN/Inf, elementwise within [min_val, max_val].
    """
    tensor = torch.nan_to_num(tensor, nan=0.0, posinf=max_val, neginf=min_val)
    return torch.clamp(tensor, min_val, max_val)
def safe_complex_division(numerator, denominator, eps=EPS):
    """Divide complex tensors, keeping the denominator's squared norm >= eps.

    Multiplies through by the conjugate so the actual division is by a real,
    strictly positive quantity.
    """
    conj_den = torch.conj(denominator)
    norm_squared = (denominator * conj_den).real.clamp(min=eps)
    return numerator * conj_den / norm_squared
###########################################################################################################################################
####################################################- - - MÖBIUS TRANSFORM - - -#######################################################
class MobiusTransform(nn.Module):
def __init__(self, learnable=True, init_identity=True):
super().__init__()
self.learnable = learnable
if init_identity:
a_init, b_init, c_init, d_init = 1.0, 0.0, 0.0, 1.0
else:
a_init, d_init = 1.0, 1.0
b_init, c_init = 0.1, 0.1
if learnable:
self.a = nn.Parameter(torch.tensor([a_init, 0.0]))
self.b = nn.Parameter(torch.tensor([b_init, 0.0]))
self.c = nn.Parameter(torch.tensor([c_init, 0.0]))
self.d = nn.Parameter(torch.tensor([d_init, 0.0]))
else:
self.register_buffer('a', torch.tensor([a_init, 0.0]))
self.register_buffer('b', torch.tensor([b_init, 0.0]))
self.register_buffer('c', torch.tensor([c_init, 0.0]))
self.register_buffer('d', torch.tensor([d_init, 0.0]))
def to_complex(self, param):
return torch.complex(param[0], param[1])
def get_determinant(self):
a_complex = self.to_complex(self.a)
b_complex = self.to_complex(self.b)
c_complex = self.to_complex(self.c)
d_complex = self.to_complex(self.d)
det = a_complex * d_complex - b_complex * c_complex
return det
def normalize_parameters(self):
if self.learnable:
with torch.no_grad():
det = torch.abs(self.get_determinant())
if det < EPS:
one = torch.tensor([1.0, 0.0], device=self.a.device, dtype=self.a.dtype)
self.a.copy_(one)
self.d.copy_(one)
self.b.mul_(0.1)
self.c.mul_(0.1)
for p in (self.a, self.b, self.c, self.d):
p.clamp_(-10.0, 10.0)
def transform(self, z):
self.normalize_parameters()
a_complex = self.to_complex(self.a)
b_complex = self.to_complex(self.b)
c_complex = self.to_complex(self.c)
d_complex = self.to_complex(self.d)
numerator = a_complex * z + b_complex
denominator = c_complex * z + d_complex
transformed = safe_complex_division(numerator, denominator)
return transformed
def inverse_transform(self, w):
self.normalize_parameters()
a_complex = self.to_complex(self.a)
b_complex = self.to_complex(self.b)
c_complex = self.to_complex(self.c)
d_complex = self.to_complex(self.d)
numerator = d_complex * w - b_complex
denominator = -c_complex * w + a_complex
return safe_complex_division(numerator, denominator)
def get_transform_info(self):
det = self.get_determinant()
one = torch.tensor(1.0, device=det.device, dtype=det.real.dtype)
return {
'determinant': det,
'is_identity': torch.allclose(torch.abs(det), one, atol=1e-6),
'parameters': {'a': self.to_complex(self.a), 'b': self.to_complex(self.b),
'c': self.to_complex(self.c), 'd': self.to_complex(self.d)}
}
###########################################################################################################################################
#############################################- - - COMPLEX STATE MARKOV CHAIN - - -####################################################
class ComplexStateMarkovChain(nn.Module):
    """Markov chain whose states sit at learnable points in the complex plane.

    Transition probabilities combine a learned base logit matrix with a
    distance-dependent term computed between Möbius-transformed state
    positions, so an external MobiusTransform can reshape the chain's
    dynamics through its geometry.
    """

    def __init__(self, num_states, state_embedding_dim=64, distance_kernel='gaussian'):
        """
        Args:
            num_states: number of discrete states.
            state_embedding_dim: width of the per-state embedding vectors.
            distance_kernel: 'gaussian', 'inverse', or anything else for a
                linear (1 - distance) falloff.
        """
        super().__init__()
        self.num_states = num_states
        self.state_embedding_dim = state_embedding_dim
        self.distance_kernel = distance_kernel
        # Learnable complex coordinates of each state, spread ~N(0, 2^2).
        self.state_positions = nn.Parameter(
            torch.complex(
                torch.randn(num_states) * 2.0,
                torch.randn(num_states) * 2.0
            )
        )
        self.state_embeddings = nn.Parameter(torch.randn(num_states, state_embedding_dim) * 0.1)
        self.base_transition_logits = nn.Parameter(torch.randn(num_states, num_states) * 0.1)
        self.distance_scale = nn.Parameter(torch.tensor(1.0))
        self.distance_bias = nn.Parameter(torch.tensor(0.0))
        # Kernel-specific learnable shape parameter (none for the linear fallback).
        if distance_kernel == 'gaussian':
            self.kernel_width = nn.Parameter(torch.tensor(1.0))
        elif distance_kernel == 'inverse':
            self.kernel_power = nn.Parameter(torch.tensor(1.0))

    def compute_transformed_distances(self, mobius_transform):
        """Return ([num_states, num_states] pairwise distances, transformed positions)."""
        transformed_positions = mobius_transform.transform(self.state_positions)
        pos_i = transformed_positions.unsqueeze(0)  # [1, num_states]
        pos_j = transformed_positions.unsqueeze(1)  # [num_states, 1]
        distances = torch.abs(pos_i - pos_j)  # symmetric modulus of the difference
        return distances, transformed_positions

    def distance_to_probability(self, distances):
        """Map pairwise distances to unnormalized transition affinities
        using the kernel chosen at construction time."""
        distances = torch.clamp(distances, min=EPS)
        if self.distance_kernel == 'gaussian':
            width = torch.clamp(self.kernel_width, min=0.1, max=10.0)
            prob_contrib = torch.exp(-distances**2 / (2 * width**2))
        elif self.distance_kernel == 'inverse':
            power = torch.clamp(self.kernel_power, min=0.5, max=3.0)
            prob_contrib = 1.0 / (distances**power + EPS)
        else:
            # Linear falloff, floored at zero.
            prob_contrib = torch.clamp(1.0 - distances, min=0.0)
        return prob_contrib

    def compute_transition_matrix(self, mobius_transform):
        """Build a row-stochastic transition matrix from geometry + base logits."""
        distances, transformed_positions = self.compute_transformed_distances(mobius_transform)
        distance_contrib = self.distance_to_probability(distances)
        scale = torch.clamp(self.distance_scale, min=0.1, max=10.0)
        bias = torch.clamp(self.distance_bias, min=-5.0, max=5.0)
        transition_logits = self.base_transition_logits + scale * distance_contrib + bias
        # Small diagonal boost mildly favors self-transitions.
        transition_logits = transition_logits + torch.eye(self.num_states, device=transition_logits.device) * 0.05
        transition_matrix = F.softmax(transition_logits, dim=1)
        return transition_matrix, transformed_positions

    def forward(self, initial_state, num_steps, mobius_transform):
        """Evolve a state distribution for num_steps under the current geometry.

        Args:
            initial_state: [num_states] or [batch, num_states] probability rows
                (a 1-D input is promoted to a batch of one and stays batched
                in the outputs).
            num_steps: number of chain steps to take.
            mobius_transform: MobiusTransform used to reposition the states.

        Returns:
            Dict with 'trajectory' [num_steps+1, batch, num_states],
            'final_state', 'state_positions' (complex position of the argmax
            state at each step), 'transition_matrix', 'transformed_positions'.
        """
        # Fix: removed an unused `batch_size` local that was computed but never read.
        if initial_state.dim() == 1:
            current_state = initial_state.unsqueeze(0)
        else:
            current_state = initial_state
        transition_matrix, transformed_positions = self.compute_transition_matrix(mobius_transform)
        trajectory = [current_state.clone()]
        state_positions = [transformed_positions[current_state.argmax(dim=-1)]]
        for _ in range(num_steps):
            current_state = torch.matmul(current_state, transition_matrix)
            trajectory.append(current_state.clone())
            most_likely_states = current_state.argmax(dim=-1)
            state_positions.append(transformed_positions[most_likely_states])
        return {
            'trajectory': torch.stack(trajectory),
            'final_state': current_state,
            'state_positions': torch.stack(state_positions),
            'transition_matrix': transition_matrix,
            'transformed_positions': transformed_positions
        }
###########################################################################################################################################
#############################################- - - MÖBIUS MARKOV SYSTEM - - -##########################################################
class MobiusMarkovSystem(nn.Module):
    """Markov chain whose transition geometry co-evolves with its state.

    Each evolution step (1) encodes the current state distribution,
    (2) nudges the Möbius transform's coefficients from that encoding, and
    (3) advances the chain one step under the updated geometry.
    """

    def __init__(self, num_states, state_embedding_dim=64, evolution_steps=10):
        """
        Args:
            num_states: size of the discrete state space.
            state_embedding_dim: width of the state encoder/decoder MLPs.
            evolution_steps: geometry/chain update iterations per forward().
        """
        super().__init__()
        self.num_states = num_states
        self.evolution_steps = evolution_steps
        self.mobius_transform = MobiusTransform(learnable=True, init_identity=True)
        self.markov_chain = ComplexStateMarkovChain(num_states, state_embedding_dim)
        # Maps a state embedding to 8 reals = 4 complex Möbius coefficients.
        self.mobius_evolution = nn.Sequential(
            nn.Linear(state_embedding_dim, state_embedding_dim),
            nn.Tanh(),
            nn.Linear(state_embedding_dim, 8),  # 4 complex parameters = 8 real values
        )
        self.state_encoder = nn.Sequential(
            nn.Linear(num_states, state_embedding_dim),
            nn.LayerNorm(state_embedding_dim),
            nn.ReLU(),
            nn.Linear(state_embedding_dim, state_embedding_dim)
        )
        # Decoder ends in Softmax, so predictions are probability rows.
        self.state_decoder = nn.Sequential(
            nn.Linear(state_embedding_dim, state_embedding_dim),
            nn.ReLU(),
            nn.Linear(state_embedding_dim, num_states),
            nn.Softmax(dim=-1)
        )
        # Scalar gate on how fast the geometry is allowed to drift.
        self.geometry_controller = nn.Parameter(torch.tensor(0.1))

    def evolve_mobius_parameters(self, state_embedding):
        """In-place nudge of the Möbius coefficients from a state embedding.

        NOTE(review): the parameter updates run under no_grad, so the
        mobius_evolution network gets no gradient through this path —
        presumably intentional, but worth confirming.
        """
        evolution_signal = self.mobius_evolution(state_embedding)
        evolution_rate = torch.clamp(self.geometry_controller, 0.01, 1.0)
        if self.mobius_transform.learnable:
            with torch.no_grad():
                # Reshape the 8 outputs into four [real, imag] rows; the 0.01
                # factor keeps each step a small perturbation.
                updates = (evolution_signal.view(4, 2) * evolution_rate * 0.01)\
                    .to(device=self.mobius_transform.a.device, dtype=self.mobius_transform.a.dtype)
                self.mobius_transform.a.add_(updates[0])
                self.mobius_transform.b.add_(updates[1])
                self.mobius_transform.c.add_(updates[2])
                self.mobius_transform.d.add_(updates[3])
                self.mobius_transform.normalize_parameters()

    def forward(self, initial_state, return_full_trajectory=False):
        """Run evolution_steps coupled geometry/state updates.

        Args:
            initial_state: state distribution; assumes [batch, num_states] —
                mean(dim=0) below reduces the batch axis, so a 1-D input
                would collapse to a scalar (TODO confirm 1-D inputs are
                not expected here).
            return_full_trajectory: if True, also return per-step history.

        Returns:
            Dict with 'final_state', 'final_prediction', 'final_embedding',
            'final_geometry', plus 'evolution_history' when requested.
        """
        # NOTE(review): this embedding is never used — it is recomputed
        # inside the loop below; dead assignment kept as-is.
        state_embedding = self.state_encoder(initial_state)
        evolution_history = {
            'states': [],
            'geometries': [],
            'transition_matrices': [],
            'transformed_positions': []
        }
        current_state = initial_state
        for step in range(self.evolution_steps):
            state_embedding = self.state_encoder(current_state)
            # Batch-averaged embedding drives one shared geometry update.
            self.evolve_mobius_parameters(state_embedding.mean(dim=0))
            markov_output = self.markov_chain.forward(
                current_state,
                num_steps=1,
                mobius_transform=self.mobius_transform
            )
            current_state = markov_output['final_state']
            if return_full_trajectory:
                evolution_history['states'].append(current_state.clone())
                evolution_history['geometries'].append(self.mobius_transform.get_transform_info())
                evolution_history['transition_matrices'].append(markov_output['transition_matrix'])
                evolution_history['transformed_positions'].append(markov_output['transformed_positions'])
        final_embedding = self.state_encoder(current_state)
        final_prediction = self.state_decoder(final_embedding)
        output = {
            'final_state': current_state,
            'final_prediction': final_prediction,
            'final_embedding': final_embedding,
            'final_geometry': self.mobius_transform.get_transform_info()
        }
        if return_full_trajectory:
            output['evolution_history'] = evolution_history
        return output

    def predict_sequence(self, initial_state, sequence_length):
        """Autoregressively roll the system forward sequence_length times,
        feeding each final_state back in; returns stacked predictions."""
        predictions = []
        current_state = initial_state
        for _ in range(sequence_length):
            output = self.forward(current_state)
            predictions.append(output['final_prediction'])
            current_state = output['final_state']
        return torch.stack(predictions)

    def get_system_info(self):
        """Summary dict: sizes, current geometry, raw state positions, and
        the geometry drift rate as a Python float."""
        return {
            'num_states': self.num_states,
            'evolution_steps': self.evolution_steps,
            'current_geometry': self.mobius_transform.get_transform_info(),
            'state_positions': self.markov_chain.state_positions,
            'geometry_evolution_rate': self.geometry_controller.item()
        }