# TORI / coherence_monitor.py
# Personaz1
# ΔΣ::TORI - Copy TORUS modules locally for Hugging Face deployment
# 247545d
"""
Coherence Monitoring and Self-Reflection Module
This module implements the coherence assessment and self-reflection mechanisms
that are central to the toroidal diffusion model architecture.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Tuple, Optional
from collections import deque
class CoherenceMetrics:
    """
    Stateless collection of coherence metrics for assessing generation quality.

    Every metric turns a non-negative dispersion measure ``d`` into a score
    ``1 / (1 + d)``, so scores lie in (0, 1] and larger means more coherent.
    """

    @staticmethod
    def semantic_coherence(features: torch.Tensor, window_size: int = 3) -> torch.Tensor:
        """
        Compute semantic coherence based on local feature consistency.

        Args:
            features: Feature tensor of shape (batch, channels, height, width).
            window_size: Side length of the square local window. Odd and even
                sizes are both supported; the output keeps the input's
                spatial size either way.

        Returns:
            coherence: Tensor of shape (batch, 1, height, width) holding
                ``1 / (1 + v)``, where ``v`` is the channel-averaged local
                variance estimate.
        """
        # The four-way unpack doubles as a check that the input is 4-D.
        _, channels, _, _ = features.shape

        # Uniform averaging kernel, replicated once per channel so the
        # convolution is depthwise (groups=channels). It is created with the
        # input's dtype as well as its device; a default float32 kernel makes
        # conv2d fail on e.g. float64 inputs.
        kernel = torch.ones(
            1, 1, window_size, window_size,
            device=features.device, dtype=features.dtype,
        ) / (window_size ** 2)
        kernel = kernel.repeat(channels, 1, 1, 1)

        # Zero padding that preserves the spatial size for any window size.
        # For odd windows this equals the symmetric ``window_size // 2``
        # padding; for even windows the extra row/column goes on the
        # bottom/right, avoiding an (H+1, W+1) output that cannot be
        # subtracted from the input.
        pad_before = (window_size - 1) // 2
        pad_after = window_size // 2
        padding = (pad_before, pad_after, pad_before, pad_after)

        def local_average(x: torch.Tensor) -> torch.Tensor:
            return F.conv2d(F.pad(x, padding), kernel, groups=channels)

        # Mean within windows
        local_mean = local_average(features)
        # Windowed average of the squared deviation from the local mean
        local_var = local_average((features - local_mean) ** 2)

        # Coherence is inverse of variance (lower variance = higher coherence)
        return 1.0 / (1.0 + local_var.mean(dim=1, keepdim=True))

    @staticmethod
    def structural_coherence(features: torch.Tensor) -> torch.Tensor:
        """
        Compute structural coherence based on gradient consistency.

        Args:
            features: Feature tensor indexed as (batch, channels, height, width).

        Returns:
            coherence: Tensor of shape (batch, 1, height, width) holding
                ``1 / (1 + s)``, where ``s`` is the standard deviation of the
                gradient magnitude across channels. With a single channel
                there is no spread to measure, so the score is 1 everywhere.
        """
        # Backward differences along width and height. Prepending the last
        # column/row makes the first difference wrap around, i.e. the borders
        # are treated as periodic.
        grad_x = torch.diff(features, dim=3, prepend=features[:, :, :, -1:])
        grad_y = torch.diff(features, dim=2, prepend=features[:, :, -1:, :])

        # Gradient magnitude
        grad_mag = torch.sqrt(grad_x ** 2 + grad_y ** 2)

        if grad_mag.shape[1] > 1:
            channel_spread = torch.std(grad_mag, dim=1, keepdim=True)
        else:
            # The sample standard deviation of one value is undefined and
            # torch.std returns NaN for it; treat it as zero spread instead.
            channel_spread = torch.zeros_like(grad_mag)

        # Coherence based on gradient smoothness
        return 1.0 / (1.0 + channel_spread)

    @staticmethod
    def temporal_coherence(features_sequence: List[torch.Tensor]) -> torch.Tensor:
        """
        Compute temporal coherence across a sequence of features.

        Args:
            features_sequence: Non-empty sequence of equally shaped feature
                tensors from different timesteps. An empty sequence raises
                IndexError, because the output shape is taken from the first
                element.

        Returns:
            coherence: ``1 / (1 + d)`` with the channel dimension reduced to
                size 1, where ``d`` is the mean absolute difference between
                consecutive tensors averaged over channels and over all
                consecutive pairs. A single-element sequence yields all ones.
        """
        if len(features_sequence) < 2:
            return torch.ones_like(features_sequence[0][:, :1])

        # Channel-averaged absolute change between each pair of neighbours.
        # Access stays index-based so any indexable sequence is accepted.
        step_diffs = [
            torch.abs(features_sequence[i] - features_sequence[i - 1]).mean(dim=1, keepdim=True)
            for i in range(1, len(features_sequence))
        ]

        # Average temporal difference over all steps
        avg_temporal_diff = torch.stack(step_diffs).mean(dim=0)

        # Coherence is inverse of temporal variation
        return 1.0 / (1.0 + avg_temporal_diff)
class SelfReflectionModule(nn.Module):
    """
    Implements self-reflection mechanisms for the toroidal diffusion model.

    The module passes features through a stack of residual convolution
    blocks, scores the result with two hand-crafted metrics plus a learned
    coherence head, and proposes corrections scaled by how incoherent the
    features were judged to be.
    """

    def __init__(self, feature_dim: int, reflection_depth: int = 3):
        """
        Args:
            feature_dim: Number of channels of the feature maps.
            reflection_depth: Number of residual reflection blocks.
        """
        super().__init__()
        self.feature_dim = feature_dim
        self.reflection_depth = reflection_depth

        num_groups = self._resolve_num_groups(feature_dim)

        # Reflection network layers
        self.reflection_layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(feature_dim, feature_dim, 3, padding=1),
                nn.GroupNorm(num_groups, feature_dim),
                nn.SiLU(),
                nn.Conv2d(feature_dim, feature_dim, 3, padding=1),
                nn.GroupNorm(num_groups, feature_dim),
                nn.SiLU()
            ) for _ in range(reflection_depth)
        ])

        # Coherence assessment head. The hidden width is clamped to at least
        # one channel so that feature_dim == 1 does not create an empty layer.
        hidden_dim = max(1, feature_dim // 2)
        self.coherence_head = nn.Sequential(
            nn.Conv2d(feature_dim, hidden_dim, 1),
            nn.SiLU(),
            nn.Conv2d(hidden_dim, 1, 1),
            nn.Sigmoid()
        )

        # Correction suggestion head
        self.correction_head = nn.Sequential(
            nn.Conv2d(feature_dim, feature_dim, 3, padding=1),
            nn.GroupNorm(num_groups, feature_dim),
            nn.SiLU(),
            nn.Conv2d(feature_dim, feature_dim, 3, padding=1)
        )

    @staticmethod
    def _resolve_num_groups(feature_dim: int, max_groups: int = 8) -> int:
        """Pick a GroupNorm group count that evenly divides ``feature_dim``.

        Widths below ``max_groups`` use a single group. Otherwise the largest
        divisor of ``feature_dim`` not exceeding ``max_groups`` is used, which
        is ``max_groups`` itself whenever it divides the width. GroupNorm
        rejects group counts that do not divide the channel count.
        """
        if feature_dim < max_groups:
            return 1
        return next(g for g in range(max_groups, 0, -1) if feature_dim % g == 0)

    def analyze_coherence(self, features: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Analyze the coherence of current features.

        Args:
            features: Input feature tensor of shape
                (batch, feature_dim, height, width).

        Returns:
            analysis: Dictionary with the keys 'semantic_coherence',
                'structural_coherence' (both from CoherenceMetrics),
                'overall_coherence' (output of the learned sigmoid head) and
                'mean_coherence' (unweighted average of the three).
        """
        semantic_coh = CoherenceMetrics.semantic_coherence(features)
        structural_coh = CoherenceMetrics.structural_coherence(features)

        # Overall coherence score
        overall_coherence = self.coherence_head(features)

        return {
            'semantic_coherence': semantic_coh,
            'structural_coherence': structural_coh,
            'overall_coherence': overall_coherence,
            'mean_coherence': (semantic_coh + structural_coh + overall_coherence) / 3
        }

    def generate_corrections(self, features: torch.Tensor, coherence_analysis: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Generate correction suggestions based on coherence analysis.

        Args:
            features: Input feature tensor.
            coherence_analysis: Coherence analysis results; only the
                'mean_coherence' entry is read.

        Returns:
            corrections: Output of the correction head multiplied by
                ``1 - mean_coherence``, so less coherent locations receive
                larger corrections.
        """
        # Weight corrections by coherence deficiency
        coherence_weight = 1.0 - coherence_analysis['mean_coherence']

        # Generate corrections
        corrections = self.correction_head(features)

        # Apply coherence-weighted corrections
        return corrections * coherence_weight

    def reflect(self, features: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Perform self-reflection on the current features.

        Args:
            features: Input feature tensor.

        Returns:
            reflection_result: Dictionary with 'reflected_features',
                'coherence_analysis' (itself the dictionary returned by
                ``analyze_coherence``), 'corrections' and 'original_features'.
                Analysis and corrections are computed from the reflected
                features, not from the raw input.
        """
        # Multi-layer reflection
        reflected_features = features
        for layer in self.reflection_layers:
            reflected_features = layer(reflected_features) + reflected_features  # Residual connection

        # Analyze coherence
        coherence_analysis = self.analyze_coherence(reflected_features)

        # Generate corrections
        corrections = self.generate_corrections(reflected_features, coherence_analysis)

        return {
            'reflected_features': reflected_features,
            'coherence_analysis': coherence_analysis,
            'corrections': corrections,
            'original_features': features
        }

    def forward(self, features: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass performing self-reflection.

        Args:
            features: Input feature tensor.

        Returns:
            reflection_result: Same dictionary as returned by ``reflect``.
        """
        return self.reflect(features)
class MultiPassRefinement(nn.Module):
    """
    Implements multi-pass refinement mechanism for iterative improvement.

    Each pass runs the self-reflection module on the current features and,
    unless a stopping condition is met, feeds the suggested corrections
    through a refinement network whose output is added back to the features.
    """

    def __init__(self, feature_dim: int, max_passes: int = 3, coherence_threshold: float = 0.8):
        """
        Args:
            feature_dim: Number of channels of the feature maps.
            max_passes: Maximum number of refinement steps to apply.
            coherence_threshold: Mean coherence at which refinement stops.
        """
        super().__init__()
        self.feature_dim = feature_dim
        self.max_passes = max_passes
        self.coherence_threshold = coherence_threshold

        # Self-reflection module
        self.reflection_module = SelfReflectionModule(feature_dim)

        # Refinement network
        num_groups = self._resolve_num_groups(feature_dim)
        self.refinement_net = nn.Sequential(
            nn.Conv2d(feature_dim * 2, feature_dim, 3, padding=1),  # features + corrections
            nn.GroupNorm(num_groups, feature_dim),
            nn.SiLU(),
            nn.Conv2d(feature_dim, feature_dim, 3, padding=1),
            nn.GroupNorm(num_groups, feature_dim),
            nn.SiLU(),
            nn.Conv2d(feature_dim, feature_dim, 3, padding=1)
        )

        # History tracking; only the two most recent scores are ever read.
        self.coherence_history = deque(maxlen=max_passes)

    @staticmethod
    def _resolve_num_groups(feature_dim: int, max_groups: int = 8) -> int:
        """Pick a GroupNorm group count that evenly divides ``feature_dim``.

        Widths below ``max_groups`` use a single group. Otherwise the largest
        divisor of ``feature_dim`` not exceeding ``max_groups`` is used, which
        is ``max_groups`` itself whenever it divides the width. GroupNorm
        rejects group counts that do not divide the channel count.
        """
        if feature_dim < max_groups:
            return 1
        return next(g for g in range(max_groups, 0, -1) if feature_dim % g == 0)

    def should_continue_refinement(self, coherence_score: float, pass_num: int) -> bool:
        """
        Determine if refinement should continue.

        Args:
            coherence_score: Current coherence score.
            pass_num: Number of refinement steps applied so far.

        Returns:
            should_continue: False when the score has reached the threshold,
                when ``pass_num`` has reached ``max_passes``, or when the last
                recorded score improved on the previous one by less than 0.01
                (which includes getting worse); True otherwise.
        """
        # Stop if coherence threshold is reached
        if coherence_score >= self.coherence_threshold:
            return False

        # Stop if maximum passes reached
        if pass_num >= self.max_passes:
            return False

        # Stop if coherence is not improving
        if len(self.coherence_history) >= 2:
            recent_improvement = self.coherence_history[-1] - self.coherence_history[-2]
            if recent_improvement < 0.01:  # Minimal improvement threshold
                return False

        return True

    def refine_features(self, features: torch.Tensor, corrections: torch.Tensor) -> torch.Tensor:
        """
        Apply refinement to features using corrections.

        Args:
            features: Input features with ``feature_dim`` channels.
            corrections: Correction suggestions with the same shape.

        Returns:
            refined_features: ``features`` plus the refinement network's
                output for the channel-wise concatenation of both inputs.
        """
        # Concatenate features and corrections
        combined = torch.cat([features, corrections], dim=1)

        # Apply refinement network
        refinement = self.refinement_net(combined)

        # Apply refinement with residual connection
        return features + refinement

    def forward(self, initial_features: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Perform multi-pass refinement.

        Args:
            initial_features: Initial feature tensor.

        Returns:
            refinement_result: Dictionary with
                'final_features': features after the last applied refinement;
                'initial_features': the unmodified input;
                'refinement_history': one entry per reflection pass holding
                    'pass', 'features', 'coherence_score' and
                    'reflection_result';
                'total_passes': number of reflection passes performed, which
                    is one more than the number of refinements applied and
                    can therefore be as large as ``max_passes + 1``;
                'final_coherence': mean coherence (a Python float) measured
                    on 'final_features'.
        """
        current_features = initial_features
        pass_num = 0
        refinement_history = []

        # Clear history for new refinement session
        self.coherence_history.clear()

        while True:
            # Perform self-reflection
            reflection_result = self.reflection_module(current_features)

            # Extract coherence score as a scalar averaged over the batch
            coherence_score = reflection_result['coherence_analysis']['mean_coherence'].mean().item()
            self.coherence_history.append(coherence_score)

            # Store history
            refinement_history.append({
                'pass': pass_num,
                'features': current_features.clone(),
                'coherence_score': coherence_score,
                'reflection_result': reflection_result
            })

            # Check if refinement should continue
            if not self.should_continue_refinement(coherence_score, pass_num):
                break

            # Apply refinement
            corrections = reflection_result['corrections']
            current_features = self.refine_features(current_features, corrections)
            pass_num += 1

        return {
            'final_features': current_features,
            'initial_features': initial_features,
            'refinement_history': refinement_history,
            'total_passes': pass_num + 1,
            'final_coherence': coherence_score
        }
def test_coherence_monitoring():
    """Smoke-test the metrics, the reflection module and multi-pass refinement.

    Each component is run once on random inputs; shapes and mean scores are
    printed for manual inspection.
    """
    print("Testing Coherence Monitoring and Self-Reflection...")

    # One shared shape for every random tensor: (batch, channels, height, width)
    dims = (2, 64, 32, 32)
    num_channels = dims[1]
    sample = torch.randn(*dims)

    # Static metrics on a single feature map
    semantic_coh = CoherenceMetrics.semantic_coherence(sample)
    structural_coh = CoherenceMetrics.structural_coherence(sample)
    print(f"Semantic coherence shape: {semantic_coh.shape}")
    print(f"Structural coherence shape: {structural_coh.shape}")
    print(f"Semantic coherence mean: {semantic_coh.mean().item():.4f}")
    print(f"Structural coherence mean: {structural_coh.mean().item():.4f}")

    # Temporal metric on five independent random frames
    frames = []
    for _ in range(5):
        frames.append(torch.randn(*dims))
    temporal_coh = CoherenceMetrics.temporal_coherence(frames)
    print(f"Temporal coherence shape: {temporal_coh.shape}")
    print(f"Temporal coherence mean: {temporal_coh.mean().item():.4f}")

    # Self-reflection on the same sample
    reflector = SelfReflectionModule(num_channels)
    reflection_result = reflector(sample)
    print(f"Reflected features shape: {reflection_result['reflected_features'].shape}")
    print(f"Corrections shape: {reflection_result['corrections'].shape}")
    print(f"Overall coherence mean: {reflection_result['coherence_analysis']['overall_coherence'].mean().item():.4f}")

    # Multi-pass refinement on the same sample
    refiner = MultiPassRefinement(num_channels, max_passes=3, coherence_threshold=0.9)
    refinement_result = refiner(sample)
    print(f"Final features shape: {refinement_result['final_features'].shape}")
    print(f"Total passes: {refinement_result['total_passes']}")
    print(f"Final coherence: {refinement_result['final_coherence']:.4f}")
    print(f"Refinement history length: {len(refinement_result['refinement_history'])}")

    print("All coherence monitoring tests passed!")


# Run the smoke test when the file is executed as a script.
if __name__ == "__main__":
    test_coherence_monitoring()