|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
import math |
|
|
import matplotlib.pyplot as plt |
|
|
from typing import List, Dict, Tuple, Optional |
|
|
|
|
|
# Clamp bounds used by make_safe() to keep tensor values in a finite range.
SAFE_MIN = -1e6
SAFE_MAX = 1e6

# Small positive floor used to avoid division by (near-)zero magnitudes.
EPS = 1e-8
|
|
|
|
|
|
|
|
|
|
|
def make_safe(tensor, min_val=SAFE_MIN, max_val=SAFE_MAX):
    """Replace non-finite values in *tensor* and clamp it into [min_val, max_val].

    NaN becomes 0, +inf becomes max_val, and -inf becomes min_val.  (The
    previous implementation mapped -inf to max_val as well, silently flipping
    its sign before the clamp.)

    Args:
        tensor: real-valued input tensor.
        min_val: lower clamp bound.
        max_val: upper clamp bound.

    Returns:
        Tensor of the same shape/dtype containing only finite, in-range values.
    """
    # nan_to_num handles NaN / +inf / -inf sign-correctly in a single pass.
    tensor = torch.nan_to_num(tensor, nan=0.0, posinf=max_val, neginf=min_val)
    return torch.clamp(tensor, min_val, max_val)
|
|
|
|
|
|
|
|
def safe_complex_division(numerator, denominator, eps=EPS):
    """Divide complex tensors while guarding against a vanishing denominator.

    Computes ``numerator / denominator`` as ``numerator * conj(denominator)``
    over ``|denominator|**2``, flooring the squared magnitude at ``eps`` so
    the quotient stays finite.
    """
    conj_den = torch.conj(denominator)
    # |denominator|^2 == Re(d * conj(d)); floor it to keep the result finite.
    mag_sq = torch.clamp_min(torch.real(denominator * conj_den), eps)
    return numerator * conj_den / mag_sq
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MobiusTransform(nn.Module):
    """Learnable Möbius (fractional linear) transform of the complex plane.

    Maps z -> (a*z + b) / (c*z + d).  Each coefficient is a complex number
    stored as a length-2 real tensor [real, imag] so it can be held in an
    ``nn.Parameter`` / buffer without complex-valued parameters.
    """

    def __init__(self, learnable=True, init_identity=True):
        """
        Args:
            learnable: if True, coefficients are nn.Parameters; otherwise
                they are registered as fixed buffers.
            init_identity: if True, start from the identity map
                (a=d=1, b=c=0); otherwise start from a mild perturbation.
        """
        super().__init__()
        self.learnable = learnable

        if init_identity:
            a_init, b_init, c_init, d_init = 1.0, 0.0, 0.0, 1.0
        else:
            a_init, d_init = 1.0, 1.0
            b_init, c_init = 0.1, 0.1

        if learnable:
            # Each coefficient is stored as [real, imag].
            self.a = nn.Parameter(torch.tensor([a_init, 0.0]))
            self.b = nn.Parameter(torch.tensor([b_init, 0.0]))
            self.c = nn.Parameter(torch.tensor([c_init, 0.0]))
            self.d = nn.Parameter(torch.tensor([d_init, 0.0]))
        else:
            self.register_buffer('a', torch.tensor([a_init, 0.0]))
            self.register_buffer('b', torch.tensor([b_init, 0.0]))
            self.register_buffer('c', torch.tensor([c_init, 0.0]))
            self.register_buffer('d', torch.tensor([d_init, 0.0]))

    def to_complex(self, param):
        # [real, imag] -> 0-dim complex tensor.
        return torch.complex(param[0], param[1])

    def get_determinant(self):
        """Return the complex determinant a*d - b*c of the coefficient matrix."""
        a_complex = self.to_complex(self.a)
        b_complex = self.to_complex(self.b)
        c_complex = self.to_complex(self.c)
        d_complex = self.to_complex(self.d)

        det = a_complex * d_complex - b_complex * c_complex
        return det

    def normalize_parameters(self):
        """Keep the transform non-degenerate.

        If the determinant has (nearly) collapsed, snap a and d back to 1 and
        shrink b and c; then clamp all coefficients to [-10, 10].  Runs under
        no_grad and mutates the parameters in place — the statement order
        (reset before clamp) matters.
        """
        if self.learnable:
            with torch.no_grad():
                det = torch.abs(self.get_determinant())
                if det < EPS:
                    # Degenerate (non-invertible) map: reset toward identity.
                    one = torch.tensor([1.0, 0.0], device=self.a.device, dtype=self.a.dtype)
                    self.a.copy_(one)
                    self.d.copy_(one)
                    self.b.mul_(0.1)
                    self.c.mul_(0.1)
                for p in (self.a, self.b, self.c, self.d):
                    p.clamp_(-10.0, 10.0)

    def transform(self, z):
        """Apply w = (a*z + b) / (c*z + d) elementwise to complex tensor z.

        Side effect: calls normalize_parameters(), which may mutate the
        coefficients in place before the map is applied.
        """
        self.normalize_parameters()

        a_complex = self.to_complex(self.a)
        b_complex = self.to_complex(self.b)
        c_complex = self.to_complex(self.c)
        d_complex = self.to_complex(self.d)

        numerator = a_complex * z + b_complex
        denominator = c_complex * z + d_complex
        transformed = safe_complex_division(numerator, denominator)

        return transformed

    def inverse_transform(self, w):
        """Apply the inverse map z = (d*w - b) / (-c*w + a).

        Exact inverse only up to the eps-guard in safe_complex_division and
        any parameter normalization that happened in between.
        """
        self.normalize_parameters()

        a_complex = self.to_complex(self.a)
        b_complex = self.to_complex(self.b)
        c_complex = self.to_complex(self.c)
        d_complex = self.to_complex(self.d)

        numerator = d_complex * w - b_complex
        denominator = -c_complex * w + a_complex

        return safe_complex_division(numerator, denominator)

    def get_transform_info(self):
        """Return the determinant, a unit-determinant flag, and coefficients.

        NOTE(review): 'is_identity' actually tests |det| ≈ 1 (a normalized
        transform), not that the map is the identity — confirm the intended
        semantics with callers.
        """
        det = self.get_determinant()
        one = torch.tensor(1.0, device=det.device, dtype=det.real.dtype)
        return {
            'determinant': det,
            'is_identity': torch.allclose(torch.abs(det), one, atol=1e-6),
            'parameters': {'a': self.to_complex(self.a), 'b': self.to_complex(self.b),
                           'c': self.to_complex(self.c), 'd': self.to_complex(self.d)}
        }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ComplexStateMarkovChain(nn.Module):
    """Markov chain whose states sit at learnable points in the complex plane.

    Transition probabilities blend learned base logits with a kernel over the
    pairwise distances of the (Möbius-transformed) state positions, so moving
    the geometry reshapes the chain's dynamics.
    """

    def __init__(self, num_states, state_embedding_dim=64, distance_kernel='gaussian'):
        """
        Args:
            num_states: number of discrete states.
            state_embedding_dim: width of the per-state embedding vectors.
            distance_kernel: 'gaussian', 'inverse', or anything else for a
                linear (1 - d) falloff.
        """
        super().__init__()
        self.num_states = num_states
        self.state_embedding_dim = state_embedding_dim
        self.distance_kernel = distance_kernel

        # Learnable complex position of each state in the plane.
        self.state_positions = nn.Parameter(
            torch.complex(
                torch.randn(num_states) * 2.0,
                torch.randn(num_states) * 2.0
            )
        )

        self.state_embeddings = nn.Parameter(torch.randn(num_states, state_embedding_dim) * 0.1)

        # Distance-independent part of the logits plus an affine rescaling of
        # the distance-kernel contribution.
        self.base_transition_logits = nn.Parameter(torch.randn(num_states, num_states) * 0.1)
        self.distance_scale = nn.Parameter(torch.tensor(1.0))
        self.distance_bias = nn.Parameter(torch.tensor(0.0))

        # Kernel-specific learnable shape parameter.
        if distance_kernel == 'gaussian':
            self.kernel_width = nn.Parameter(torch.tensor(1.0))
        elif distance_kernel == 'inverse':
            self.kernel_power = nn.Parameter(torch.tensor(1.0))

    def compute_transformed_distances(self, mobius_transform):
        """Return (pairwise |z_i - z_j| matrix, transformed state positions)."""
        transformed_positions = mobius_transform.transform(self.state_positions)

        # Broadcast to an (N, N) grid of pairwise complex differences.
        pos_i = transformed_positions.unsqueeze(0)
        pos_j = transformed_positions.unsqueeze(1)

        complex_diff = pos_i - pos_j
        distances = torch.abs(complex_diff)

        return distances, transformed_positions

    def distance_to_probability(self, distances):
        """Convert a distance matrix to unnormalized transition affinities."""
        # Floor distances so the inverse kernel cannot blow up at d == 0.
        distances = torch.clamp(distances, min=EPS)

        if self.distance_kernel == 'gaussian':
            width = torch.clamp(self.kernel_width, min=0.1, max=10.0)
            prob_contrib = torch.exp(-distances**2 / (2 * width**2))
        elif self.distance_kernel == 'inverse':
            power = torch.clamp(self.kernel_power, min=0.5, max=3.0)
            prob_contrib = 1.0 / (distances**power + EPS)
        else:
            # Linear falloff, zero beyond distance 1.
            prob_contrib = torch.clamp(1.0 - distances, min=0.0)

        return prob_contrib

    def compute_transition_matrix(self, mobius_transform):
        """Build the row-stochastic transition matrix under the current geometry."""
        distances, transformed_positions = self.compute_transformed_distances(mobius_transform)

        distance_contrib = self.distance_to_probability(distances)

        # Clamp the affine parameters to keep logits in a sane range.
        scale = torch.clamp(self.distance_scale, min=0.1, max=10.0)
        bias = torch.clamp(self.distance_bias, min=-5.0, max=5.0)
        scaled_distance = scale * distance_contrib + bias

        transition_logits = self.base_transition_logits + scaled_distance
        # Small diagonal boost: a slight preference for self-transitions.
        transition_logits = transition_logits + torch.eye(self.num_states, device=transition_logits.device)*0.05

        # Rows sum to 1 (row-stochastic), matching `state @ T` propagation.
        transition_matrix = F.softmax(transition_logits, dim=1)

        return transition_matrix, transformed_positions

    def forward(self, initial_state, num_steps, mobius_transform):
        """Propagate a state distribution ``num_steps`` times.

        Args:
            initial_state: (num_states,) or (batch, num_states) distribution;
                a 1-D input is promoted to a batch of one (and stays batched
                in the outputs).
            num_steps: number of transitions to apply.
            mobius_transform: MobiusTransform reshaping the state geometry.

        Returns:
            dict with the full trajectory (num_steps+1 entries), the final
            distribution, the per-step argmax-state complex positions, the
            transition matrix, and the transformed positions.
        """
        if initial_state.dim() == 1:
            current_state = initial_state.unsqueeze(0)
        else:
            current_state = initial_state

        # One matrix is computed up front and reused for every step.
        transition_matrix, transformed_positions = self.compute_transition_matrix(mobius_transform)

        trajectory = [current_state.clone()]
        # Track the complex position of the most likely state at each step.
        state_positions = [transformed_positions[current_state.argmax(dim=-1)]]

        for _ in range(num_steps):
            current_state = torch.matmul(current_state, transition_matrix)
            trajectory.append(current_state.clone())

            most_likely_states = current_state.argmax(dim=-1)
            state_positions.append(transformed_positions[most_likely_states])

        return {
            'trajectory': torch.stack(trajectory),
            'final_state': current_state,
            'state_positions': torch.stack(state_positions),
            'transition_matrix': transition_matrix,
            'transformed_positions': transformed_positions
        }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MobiusMarkovSystem(nn.Module):
    """Couples a complex-plane Markov chain with a drifting Möbius transform.

    Each evolution step: encode the current state distribution, nudge the
    Möbius coefficients from the (batch-mean) embedding, then run one Markov
    transition under the updated geometry.
    """

    def __init__(self, num_states, state_embedding_dim=64, evolution_steps=10):
        """
        Args:
            num_states: number of discrete Markov states.
            state_embedding_dim: width of the state embedding vectors.
            evolution_steps: number of coupled geometry/Markov updates per
                forward pass.
        """
        super().__init__()
        self.num_states = num_states
        self.evolution_steps = evolution_steps

        self.mobius_transform = MobiusTransform(learnable=True, init_identity=True)
        self.markov_chain = ComplexStateMarkovChain(num_states, state_embedding_dim)

        # Maps a state embedding to 8 numbers = four [real, imag] deltas for
        # the Möbius coefficients a, b, c, d.
        self.mobius_evolution = nn.Sequential(
            nn.Linear(state_embedding_dim, state_embedding_dim),
            nn.Tanh(),
            nn.Linear(state_embedding_dim, 8),
        )

        self.state_encoder = nn.Sequential(
            nn.Linear(num_states, state_embedding_dim),
            nn.LayerNorm(state_embedding_dim),
            nn.ReLU(),
            nn.Linear(state_embedding_dim, state_embedding_dim)
        )

        self.state_decoder = nn.Sequential(
            nn.Linear(state_embedding_dim, state_embedding_dim),
            nn.ReLU(),
            nn.Linear(state_embedding_dim, num_states),
            nn.Softmax(dim=-1)
        )

        # Global step size for the geometry drift (clamped to [0.01, 1.0]
        # in evolve_mobius_parameters).
        self.geometry_controller = nn.Parameter(torch.tensor(0.1))

    def evolve_mobius_parameters(self, state_embedding):
        """Nudge the Möbius coefficients based on the current embedding.

        NOTE(review): the coefficient updates are applied under no_grad, so
        no gradient flows back into `mobius_evolution` through this path;
        its weights only train if used elsewhere — confirm this is intended.
        """
        evolution_signal = self.mobius_evolution(state_embedding)
        evolution_rate = torch.clamp(self.geometry_controller, 0.01, 1.0)
        if self.mobius_transform.learnable:
            with torch.no_grad():
                # 8 outputs -> four [real, imag] deltas, scaled down by 0.01
                # so the geometry drifts slowly.
                updates = (evolution_signal.view(4, 2) * evolution_rate * 0.01)\
                    .to(device=self.mobius_transform.a.device, dtype=self.mobius_transform.a.dtype)
                self.mobius_transform.a.add_(updates[0])
                self.mobius_transform.b.add_(updates[1])
                self.mobius_transform.c.add_(updates[2])
                self.mobius_transform.d.add_(updates[3])
            # Re-normalize after the in-place update to avoid degeneracy.
            self.mobius_transform.normalize_parameters()

    def forward(self, initial_state, return_full_trajectory=False):
        """Run `evolution_steps` coupled geometry/Markov updates.

        Args:
            initial_state: (batch, num_states) state distribution (a 1-D
                input is batched inside the Markov chain).
            return_full_trajectory: also collect per-step history.

        Returns:
            dict with final_state, final_prediction, final_embedding,
            final_geometry, and evolution_history when requested.
        """
        # (The previous version also encoded initial_state here; the result
        # was always overwritten inside the loop before use, so the dead
        # computation has been removed.)
        evolution_history = {
            'states': [],
            'geometries': [],
            'transition_matrices': [],
            'transformed_positions': []
        }

        current_state = initial_state

        for _ in range(self.evolution_steps):
            state_embedding = self.state_encoder(current_state)

            # Geometry drifts as a function of the batch-mean embedding.
            self.evolve_mobius_parameters(state_embedding.mean(dim=0))

            markov_output = self.markov_chain.forward(
                current_state,
                num_steps=1,
                mobius_transform=self.mobius_transform
            )

            current_state = markov_output['final_state']

            if return_full_trajectory:
                evolution_history['states'].append(current_state.clone())
                evolution_history['geometries'].append(self.mobius_transform.get_transform_info())
                evolution_history['transition_matrices'].append(markov_output['transition_matrix'])
                evolution_history['transformed_positions'].append(markov_output['transformed_positions'])

        final_embedding = self.state_encoder(current_state)
        final_prediction = self.state_decoder(final_embedding)

        output = {
            'final_state': current_state,
            'final_prediction': final_prediction,
            'final_embedding': final_embedding,
            'final_geometry': self.mobius_transform.get_transform_info()
        }

        if return_full_trajectory:
            output['evolution_history'] = evolution_history

        return output

    def predict_sequence(self, initial_state, sequence_length):
        """Autoregressively roll the system forward ``sequence_length`` times,
        collecting the decoder prediction after each full forward pass."""
        predictions = []
        current_state = initial_state

        for _ in range(sequence_length):
            output = self.forward(current_state)
            predictions.append(output['final_prediction'])
            current_state = output['final_state']

        return torch.stack(predictions)

    def get_system_info(self):
        """Snapshot of hyperparameters and the current learned geometry."""
        return {
            'num_states': self.num_states,
            'evolution_steps': self.evolution_steps,
            'current_geometry': self.mobius_transform.get_transform_info(),
            'state_positions': self.markov_chain.state_positions,
            'geometry_evolution_rate': self.geometry_controller.item()
        }
|
|
|
|
|
|