Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| GASM Enhanced Core - Hugging Face Space Optimized | |
| CPU-compatible with GPU acceleration, intelligent caching, error recovery | |
| All optimizations integrated for HF deployment | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from typing import List, Optional, Tuple, Union, Dict | |
| import logging | |
# Optional dependencies.
#
# Each of the three third-party packages below is imported defensively: when
# the import fails the module still loads, a warning is printed once, and a
# module-level ``*_AVAILABLE`` flag lets the rest of the file pick a simplified
# code path.

# Import geomstats with fallback
try:
    import geomstats.backend as gs
    from geomstats.geometry.special_euclidean import SpecialEuclidean
    from geomstats.geometry.special_orthogonal import SpecialOrthogonal
    GEOMSTATS_AVAILABLE = True
except ImportError:
    print("⚠️ Geomstats not available, using simplified geometry")
    GEOMSTATS_AVAILABLE = False

# Import PyTorch Geometric with fallback
try:
    from torch_geometric.nn import MessagePassing
    from torch_geometric.utils import softmax, to_dense_batch
    from torch_geometric.data import Data, Batch
    TORCH_GEOMETRIC_AVAILABLE = True
except ImportError:
    print("⚠️ PyTorch Geometric not available, using simplified message passing")
    TORCH_GEOMETRIC_AVAILABLE = False

    # Create dummy base class if PyG is not available
    class MessagePassing:
        """Minimal stand-in so the name ``MessagePassing`` always resolves."""

        def __init__(self, aggr="add", node_dim=0):
            self.aggr = aggr
            self.node_dim = node_dim

        def propagate(self, edge_index, **kwargs):
            # Simplified fallback: hand back the node features untouched, or a
            # fixed-size zero tensor when no ``x`` keyword was supplied.
            return kwargs.get('x', torch.zeros(3, 768))

# Import scipy with fallback
try:
    import scipy.sparse as sp
    from scipy.sparse.linalg import eigsh
    SCIPY_AVAILABLE = True
except ImportError:
    print("⚠️ Scipy not available, using simplified computations")
    SCIPY_AVAILABLE = False

logger = logging.getLogger(__name__)
class SE3InvariantAttention(MessagePassing if TORCH_GEOMETRIC_AVAILABLE else nn.Module):
    """
    Attention layer whose scores are penalised by an SE(3)-style distance.

    Every node feature vector is projected to a 3-D position and a unit
    quaternion. A single learnable rigid transform (``se3_params``) is applied
    to all of them, and the attention score between two nodes is reduced by
    their position distance plus half their quaternion angle.

    With PyTorch Geometric installed the layer runs as a ``MessagePassing``
    module over ``edge_index`` (see ``message`` / ``update``). Without it,
    ``simple_attention_fallback`` computes dense all-pairs attention instead.
    """

    def __init__(
        self,
        feature_dim: int,
        hidden_dim: int,
        num_heads: int = 8,
        dropout: float = 0.1
    ):
        """
        Args:
            feature_dim: Width of the input and output node features.
            hidden_dim: Total width of the attention space; each head gets
                ``hidden_dim // num_heads`` channels.
            num_heads: Number of attention heads.
            dropout: Dropout probability applied to the attention weights.
        """
        if TORCH_GEOMETRIC_AVAILABLE:
            super().__init__(aggr="add", node_dim=0)
        else:
            super().__init__()

        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        # SE(3) geometry (with fallback). Construction is best-effort: any
        # failure inside geomstats simply disables the geomstats code path.
        self.se3_group = None
        if GEOMSTATS_AVAILABLE:
            try:
                self.se3_group = SpecialEuclidean(n=3, equip=False)
            except Exception:
                self.se3_group = None

        # Attention projections
        self.q_proj = nn.Linear(feature_dim, hidden_dim)
        self.k_proj = nn.Linear(feature_dim, hidden_dim)
        self.v_proj = nn.Linear(feature_dim, hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, feature_dim)

        # SE(3) position and orientation embeddings
        self.pos_embedding = nn.Linear(feature_dim, 3)  # 3D positions
        self.rot_embedding = nn.Linear(feature_dim, 4)  # Quaternions (normalised in forward)

        # Learnable SE(3) transformation parameters
        # SE(3) has 6 DOF: 3 translation + 3 rotation (axis-angle)
        self.se3_params = nn.Parameter(torch.zeros(6))

        # Geometric attention scaling
        self.distance_scale = nn.Parameter(torch.ones(1))

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(feature_dim)

    @staticmethod
    def _fully_connected_edge_index(num_nodes: int, device) -> torch.Tensor:
        """Return a (2, E) index of every ordered pair i != j, or one self-loop if fewer than 2 nodes."""
        if num_nodes < 2:
            return torch.tensor([[0], [0]], dtype=torch.long, device=device)
        node_ids = torch.arange(num_nodes, dtype=torch.long, device=device)
        # Row-major enumeration (i outer, j inner) with the diagonal dropped.
        src = node_ids.repeat_interleave(num_nodes)
        dst = node_ids.repeat(num_nodes)
        off_diagonal = src != dst
        return torch.stack([src[off_diagonal], dst[off_diagonal]])

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        R: Optional[torch.Tensor] = None,
        batch: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Run one round of geometry-biased attention.

        Args:
            x: Node features (N, feature_dim)
            edge_index: Edge connectivity (2, E). Any other shape is replaced
                by full connectivity between the N nodes.
            R: Edge features (E, edge_dim) or None
            batch: Batch assignment (N,) or None. Accepted for interface
                compatibility; not read by this method.

        Returns:
            Updated node features (N, feature_dim): ``LayerNorm(out + x)``.
        """
        num_nodes = x.size(0)

        # SAFETY CHECK: Ensure edge_index has proper dimensions
        if edge_index.dim() != 2 or edge_index.size(0) != 2:
            logger.warning(f"Invalid edge_index shape: {edge_index.shape}, creating fallback")
            edge_index = self._fully_connected_edge_index(num_nodes, x.device)

        # SAFETY CHECK: Ensure edge indices are within bounds. Note this
        # silently rewires out-of-range endpoints to node 0 or node N-1.
        edge_index = torch.clamp(edge_index, 0, num_nodes - 1)

        # Extract SE(3) coordinates from features
        positions = self.pos_embedding(x)  # (N, 3)
        orientations = F.normalize(self.rot_embedding(x), dim=-1)  # (N, 4) unit quaternions

        # Apply learnable SE(3) transformation
        try:
            transformed_positions, transformed_orientations = self.apply_se3_transform(
                positions, orientations
            )
        except Exception as e:
            logger.warning(f"SE(3) transform failed: {e}, using original positions")
            transformed_positions, transformed_orientations = positions, orientations

        # Message passing with geometric attention
        try:
            if TORCH_GEOMETRIC_AVAILABLE:
                out = self.propagate(
                    edge_index,
                    x=x,
                    pos=transformed_positions,
                    rot=transformed_orientations,
                    R=R,
                    size=None
                )
            else:
                # Simplified fallback without PyG
                out = self.simple_attention_fallback(x, edge_index, transformed_positions, R)
        except Exception as e:
            logger.warning(f"Message passing failed: {e}, using identity")
            out = x

        # Residual connection and layer norm
        return self.layer_norm(out + x)

    def simple_attention_fallback(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        positions: torch.Tensor,
        R: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Dense single-head attention used when PyG is not available.

        ``edge_index`` and ``R`` are accepted for signature parity but are not
        read: every node attends to every node (including itself).
        """
        N, D = x.shape

        # Simple self-attention
        Q = self.q_proj(x)  # (N, hidden_dim)
        K = self.k_proj(x)  # (N, hidden_dim)
        V = self.v_proj(x)  # (N, hidden_dim)

        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.hidden_dim)

        # Add geometric bias based on distances (farther apart => lower score)
        if positions.size(0) == N:
            dist_matrix = torch.cdist(positions, positions)
            scores = scores + (-dist_matrix * self.distance_scale)

        # Apply softmax and dropout
        attn_weights = self.dropout(F.softmax(scores, dim=-1))

        # Apply attention to values
        return self.out_proj(torch.matmul(attn_weights, V))

    def apply_se3_transform(
        self,
        positions: torch.Tensor,
        orientations: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply the learnable rigid transform stored in ``se3_params``.

        The first three parameters are a translation, the last three an
        axis-angle rotation. Positions become ``R x + t``; orientations are
        right-multiplied by the rotation's quaternion. On any failure the
        inputs are returned unchanged.
        """
        try:
            # Extract translation and rotation parameters
            translation = self.se3_params[:3]
            rotation_axis_angle = self.se3_params[3:]

            if GEOMSTATS_AVAILABLE and self.se3_group is not None:
                # Convert axis-angle to rotation matrix using geomstats.
                # NOTE(review): this path goes through NumPy on a detached
                # copy, so no gradient reaches se3_params[3:] via positions.
                rotation_vector = rotation_axis_angle.detach().cpu().numpy()
                so3_group = SpecialOrthogonal(n=3, equip=False)
                rotation_matrix = torch.from_numpy(
                    so3_group.matrix_from_rotation_vector(rotation_vector[None, :])
                ).float().to(positions.device).squeeze(0)
            else:
                # Fallback: rotation using Rodrigues' formula
                rotation_matrix = self.rodrigues_rotation(rotation_axis_angle)

            # Transform positions: x' = Rx + t
            transformed_positions = torch.matmul(positions, rotation_matrix.T) + translation

            # Transform orientations (quaternion composition)
            axis_angle_quat = self.axis_angle_to_quaternion(rotation_axis_angle)
            transformed_orientations = self.quaternion_multiply(orientations, axis_angle_quat)

            return transformed_positions, transformed_orientations
        except Exception as e:
            logger.warning(f"SE(3) transform failed: {e}, using identity")
            return positions, orientations

    def rodrigues_rotation(self, axis_angle: torch.Tensor) -> torch.Tensor:
        """
        Convert an axis-angle vector (3,) to a 3x3 rotation matrix.

        Uses Rodrigues' formula ``I + sin(a) K + (1 - cos(a)) K^2`` where ``K``
        is the skew-symmetric matrix of the unit axis. Angles below 1e-6
        return the identity.
        """
        angle = torch.norm(axis_angle)
        if angle < 1e-6:
            return torch.eye(3, dtype=axis_angle.dtype, device=axis_angle.device)

        axis = axis_angle / angle
        # Build K with torch.stack rather than torch.tensor([...]) so the
        # matrix stays attached to the autograd graph of ``axis_angle``.
        zero = torch.zeros((), dtype=axis.dtype, device=axis.device)
        K = torch.stack([
            torch.stack([zero, -axis[2], axis[1]]),
            torch.stack([axis[2], zero, -axis[0]]),
            torch.stack([-axis[1], axis[0], zero]),
        ])

        identity = torch.eye(3, dtype=axis.dtype, device=axis.device)
        return identity + torch.sin(angle) * K + (1 - torch.cos(angle)) * torch.matmul(K, K)

    def axis_angle_to_quaternion(self, axis_angle: torch.Tensor) -> torch.Tensor:
        """Convert an axis-angle vector (3,) to a quaternion (4,) in (w, x, y, z) order."""
        angle = torch.norm(axis_angle)
        if angle < 1e-6:
            # Identity rotation
            return torch.tensor([1., 0., 0., 0.], device=axis_angle.device)

        axis = axis_angle / angle
        half_angle = angle / 2
        return torch.cat([torch.cos(half_angle).unsqueeze(0), axis * torch.sin(half_angle)])

    def quaternion_multiply(self, q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
        """
        Hamilton product ``q1 * q2`` in (w, x, y, z) order.

        Args:
            q1: Batch of quaternions (N, 4).
            q2: Single quaternion (4,), broadcast against every row of ``q1``.

        Returns:
            Product quaternions (N, 4).
        """
        w1, x1, y1, z1 = q1[:, 0], q1[:, 1], q1[:, 2], q1[:, 3]
        w2, x2, y2, z2 = q2[0], q2[1], q2[2], q2[3]

        w = w1*w2 - x1*x2 - y1*y2 - z1*z2
        x = w1*x2 + x1*w2 + y1*z2 - z1*y2
        y = w1*y2 - x1*z2 + y1*w2 + z1*x2
        z = w1*z2 + x1*y2 - y1*x2 + z1*w2

        return torch.stack([w, x, y, z], dim=-1)

    def message(
        self,
        x_i: torch.Tensor,
        x_j: torch.Tensor,
        pos_i: torch.Tensor,
        pos_j: torch.Tensor,
        rot_i: torch.Tensor,
        rot_j: torch.Tensor,
        index: torch.Tensor,
        R: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Compute one attention-weighted message per edge.

        Called by ``propagate`` with the ``_i`` / ``_j`` tensors gathered per
        edge. Scores are the scaled per-head dot product of query and key,
        minus the SE(3) distance times ``distance_scale``, plus an optional
        bias of ``0.1 * ||R||``. They are normalised with a softmax grouped by
        ``index``.

        Returns:
            Messages of shape (E, hidden_dim).
        """
        # SAFETY CHECK: Ensure index is 1D
        if index.dim() == 0:
            # Convert scalar index to 1D tensor
            index = index.unsqueeze(0)
        elif index.dim() > 1:
            # Flatten if multidimensional
            index = index.flatten()

        # Project to attention space: (E, heads, head_dim)
        q_i = self.q_proj(x_i).view(-1, self.num_heads, self.head_dim)
        k_j = self.k_proj(x_j).view(-1, self.num_heads, self.head_dim)
        v_j = self.v_proj(x_j).view(-1, self.num_heads, self.head_dim)

        # Compute SE(3) geodesic distance
        try:
            geodesic_dist = self.se3_geodesic_distance(
                pos_i, rot_i, pos_j, rot_j
            )
        except Exception as e:
            logger.warning(f"Geodesic distance computation failed: {e}")
            # Fallback to Euclidean distance
            geodesic_dist = torch.norm(pos_i - pos_j, dim=-1)

        # Standard attention scores
        attention_scores = (q_i * k_j).sum(dim=-1) / np.sqrt(self.head_dim)  # (E, heads)

        # Add geometric bias based on geodesic distance (same bias for every head)
        attention_scores = attention_scores + (-geodesic_dist.unsqueeze(-1) * self.distance_scale)

        # Add relational bias if provided
        if R is not None:
            attention_scores = attention_scores + torch.norm(R, dim=-1, keepdim=True) * 0.1

        # Apply softmax per head, grouped by target node when PyG is present
        try:
            if TORCH_GEOMETRIC_AVAILABLE and hasattr(softmax, '__call__'):
                attention_weights = softmax(attention_scores, index, dim=0)
            else:
                # Fallback softmax over all edges
                attention_weights = F.softmax(attention_scores, dim=0)
        except Exception as e:
            logger.warning(f"Softmax failed: {e}, using standard softmax")
            attention_weights = F.softmax(attention_scores, dim=0)

        attention_weights = self.dropout(attention_weights)

        # Apply attention to values
        out = attention_weights.unsqueeze(-1) * v_j  # (E, heads, head_dim)
        return out.view(-1, self.hidden_dim)  # (E, hidden_dim)

    def se3_geodesic_distance(
        self,
        pos_i: torch.Tensor,
        rot_i: torch.Tensor,
        pos_j: torch.Tensor,
        rot_j: torch.Tensor
    ) -> torch.Tensor:
        """
        Combined position/rotation distance between paired poses.

        Returns ``||pos_i - pos_j|| + 0.5 * arccos(|<rot_i, rot_j>|)`` per
        pair. If that computation raises, the plain Euclidean position
        distance is returned instead.
        """
        try:
            # Position difference
            pos_dist = torch.norm(pos_i - pos_j, dim=-1)

            # Quaternion difference: arccos(|<q1, q2>|); the absolute value
            # treats q and -q as the same rotation.
            quat_dot = torch.abs((rot_i * rot_j).sum(dim=-1))
            quat_dot = torch.clamp(quat_dot, 0.0, 1.0)  # Numerical stability
            rot_dist = torch.acos(quat_dot)

            # Combined SE(3) distance (fixed weighted sum)
            return pos_dist + 0.5 * rot_dist
        except Exception as e:
            logger.warning(f"Geodesic distance computation failed: {e}")
            # Fallback to Euclidean distance
            return torch.norm(pos_i - pos_j, dim=-1)

    def update(self, aggr_out: torch.Tensor) -> torch.Tensor:
        """Project aggregated messages (N, hidden_dim) back to (N, feature_dim)."""
        return self.out_proj(aggr_out)
class EfficientCurvatureComputation:
    """
    Cheap per-node curvature proxies computed from positions and a graph.

    All methods are stateless and exposed as static methods, so they can be
    called on the class or on an instance. Every method falls back to a
    constant vector (see ``_centroid_variance_fallback``) instead of raising.
    """

    @staticmethod
    def _centroid_variance_fallback(positions: torch.Tensor) -> torch.Tensor:
        """Return the variance of node-to-centroid distances, repeated for each of the N nodes."""
        num_nodes = positions.shape[0]
        if num_nodes < 2:
            # The sample variance of fewer than two values is NaN; report a
            # flat (zero) curvature instead.
            return torch.zeros(num_nodes, dtype=positions.dtype, device=positions.device)
        centroid = positions.mean(dim=0)
        distances = torch.norm(positions - centroid, dim=1)
        return torch.var(distances).expand(num_nodes)

    @staticmethod
    def compute_discrete_curvature(
        positions: torch.Tensor,
        edge_index: torch.Tensor,
        method: str = "gaussian"
    ) -> torch.Tensor:
        """
        Compute a discrete curvature value per node.

        Args:
            positions: Node positions (N, 3)
            edge_index: Edge connectivity (2, E). Indices are clamped into
                ``[0, N-1]``; any other shape triggers the fallback.
            method: "gaussian" or "mean"; any other value selects the
                simplified Ollivier-Ricci approximation.

        Returns:
            Node curvatures (N,)
        """
        N = positions.shape[0]

        # SAFETY CHECK: Validate edge_index
        if edge_index.dim() != 2 or edge_index.size(0) != 2:
            logger.warning(f"Invalid edge_index for curvature: {edge_index.shape}")
            return EfficientCurvatureComputation._centroid_variance_fallback(positions)

        # Clamp edge indices to valid range
        edge_index = torch.clamp(edge_index, 0, N-1)

        try:
            if method == "gaussian":
                return EfficientCurvatureComputation._gaussian_curvature(positions, edge_index)
            elif method == "mean":
                return EfficientCurvatureComputation._mean_curvature(positions, edge_index)
            else:  # ollivier_ricci
                return EfficientCurvatureComputation._ollivier_ricci_curvature(positions, edge_index)
        except Exception as e:
            logger.warning(f"Curvature computation failed: {e}")
            return EfficientCurvatureComputation._centroid_variance_fallback(positions)

    @staticmethod
    def _gaussian_curvature(positions: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Per-node norm of the normalised graph Laplacian applied to the positions."""
        N = positions.shape[0]
        device = positions.device

        try:
            # Build adjacency matrix safely
            adj = torch.zeros(N, N, device=device)
            valid_edges = (edge_index[0] < N) & (edge_index[1] < N)
            valid_edge_index = edge_index[:, valid_edges]

            if valid_edge_index.size(1) > 0:
                adj[valid_edge_index[0], valid_edge_index[1]] = 1.0
                adj = adj + adj.T  # Make symmetric (bidirectional edges count twice)

            # Degree^(-1/2), with isolated nodes mapped to 0
            degree = adj.sum(dim=1)
            degree_inv_sqrt = torch.pow(degree + 1e-6, -0.5)  # Add small epsilon
            degree_inv_sqrt[degree == 0] = 0

            # Normalised Laplacian I - D^(-1/2) A D^(-1/2); row/column scaling
            # by broadcasting is equivalent to multiplying by diagonal matrices.
            scaled_adj = degree_inv_sqrt.unsqueeze(1) * adj * degree_inv_sqrt.unsqueeze(0)
            L_norm = torch.eye(N, device=device) - scaled_adj

            # Laplacian of position coordinates, then its per-node magnitude
            laplacian_pos = L_norm @ positions  # (N, 3)
            return torch.norm(laplacian_pos, dim=1)

        except Exception as e:
            logger.warning(f"Gaussian curvature computation failed: {e}")
            return EfficientCurvatureComputation._centroid_variance_fallback(positions)

    @staticmethod
    def _mean_curvature(positions: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Per-node distance between a node and the mean position of its out-neighbours."""
        N = positions.shape[0]
        device = positions.device

        try:
            # For each node, accumulate neighbour positions and counts
            neighbor_means = torch.zeros_like(positions)
            neighbor_counts = torch.zeros(N, device=device)

            # Validate edges
            valid_edges = (edge_index[0] < N) & (edge_index[1] < N)
            valid_edge_index = edge_index[:, valid_edges]

            if valid_edge_index.size(1) > 0:
                sources, targets = valid_edge_index[0], valid_edge_index[1]
                neighbor_means.index_add_(0, sources, positions[targets])
                neighbor_counts.index_add_(0, sources, torch.ones(valid_edge_index.shape[1], device=device))

            # Avoid division by zero for nodes without neighbours
            neighbor_counts = torch.clamp(neighbor_counts, min=1)
            neighbor_means = neighbor_means / neighbor_counts.unsqueeze(1)

            # Mean curvature approximation
            return torch.norm(positions - neighbor_means, dim=1)

        except Exception as e:
            logger.warning(f"Mean curvature computation failed: {e}")
            return EfficientCurvatureComputation._centroid_variance_fallback(positions)

    @staticmethod
    def _ollivier_ricci_curvature(positions: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """
        Simplified Ollivier-Ricci style score.

        Every edge adds ``1 / (1 + edge_length)`` to both of its endpoints.
        The contribution is detached, so no gradient flows back to
        ``positions`` through this score.
        """
        N = positions.shape[0]
        device = positions.device
        curvature = torch.zeros(N, device=device)

        try:
            # Validate edges
            valid_edges = (edge_index[0] < N) & (edge_index[1] < N)
            valid_edge_index = edge_index[:, valid_edges]
            sources, targets = valid_edge_index[0], valid_edge_index[1]

            # Per-edge contribution, computed for all edges at once
            edge_lengths = torch.norm(positions[sources] - positions[targets], dim=1)
            contributions = (1.0 / (1.0 + edge_lengths)).detach().to(curvature.dtype)

            # index_add_ accumulates correctly when a node appears many times
            curvature.index_add_(0, sources, contributions)
            curvature.index_add_(0, targets, contributions)
            return curvature

        except Exception as e:
            logger.warning(f"Ollivier-Ricci curvature computation failed: {e}")
            return EfficientCurvatureComputation._centroid_variance_fallback(positions)
class ConstraintHandler:
    """
    Soft geometric constraints applied as small corrective position updates.

    The methods hold no state and are static, so they work both as
    ``ConstraintHandler.apply_energy_constraints(...)`` and on an instance.
    Inputs are never modified in place; a corrected copy is returned.
    """

    @staticmethod
    def apply_energy_constraints(
        positions: torch.Tensor,
        constraints: Dict[str, torch.Tensor],
        learning_rate: float = 0.01
    ) -> torch.Tensor:
        """
        Apply every recognised constraint once, in dictionary order.

        Args:
            positions: Current positions (N, 3)
            constraints: Maps a constraint type ("distance", "angle" or
                "collision") to its parameter tensor. Unknown keys are ignored.
            learning_rate: Step size for constraint satisfaction

        Returns:
            Corrected positions (N, 3). If a handler raises, the positions
            corrected so far are returned and a warning is logged.
        """
        corrected_positions = positions.clone()

        try:
            for constraint_type, params in constraints.items():
                if constraint_type == "distance":
                    corrected_positions = ConstraintHandler._apply_distance_constraints(
                        corrected_positions, params, learning_rate
                    )
                elif constraint_type == "angle":
                    corrected_positions = ConstraintHandler._apply_angle_constraints(
                        corrected_positions, params, learning_rate
                    )
                elif constraint_type == "collision":
                    corrected_positions = ConstraintHandler._apply_collision_constraints(
                        corrected_positions, params, learning_rate
                    )
        except Exception as e:
            logger.warning(f"Constraint application failed: {e}")

        return corrected_positions

    @staticmethod
    def _apply_distance_constraints(
        positions: torch.Tensor,
        distance_params: torch.Tensor,
        lr: float
    ) -> torch.Tensor:
        """
        Nudge pairs of points toward a target separation ``||x_i - x_j|| = d_ij``.

        Args:
            positions: Positions (N, 3).
            distance_params: (n_constraints, 3); each row is
                ``[i, j, target_distance]``.
            lr: Fraction of the distance error corrected per call.
        """
        corrected = positions.clone()

        try:
            for constraint in distance_params:
                i, j, target_dist = int(constraint[0]), int(constraint[1]), constraint[2]

                # Skip rows that reference a missing node or the same node twice
                if i < len(positions) and j < len(positions) and i != j:
                    current_vec = corrected[i] - corrected[j]
                    current_dist = torch.norm(current_vec)

                    if current_dist > 1e-6:  # Avoid division by zero
                        # Gradient descent step to satisfy constraint
                        error = current_dist - target_dist
                        gradient = current_vec / current_dist

                        # Update positions (each endpoint takes half the correction)
                        correction = lr * error * gradient * 0.5
                        corrected[i] -= correction
                        corrected[j] += correction
        except Exception as e:
            logger.warning(f"Distance constraint application failed: {e}")

        return corrected

    @staticmethod
    def _apply_angle_constraints(
        positions: torch.Tensor,
        angle_params: torch.Tensor,
        lr: float
    ) -> torch.Tensor:
        """Placeholder for angle constraints on point triplets; returns ``positions`` unchanged."""
        # Simplified implementation - can be extended
        return positions

    @staticmethod
    def _apply_collision_constraints(
        positions: torch.Tensor,
        collision_params: torch.Tensor,
        lr: float
    ) -> torch.Tensor:
        """
        Push apart every pair of points closer than a minimum distance.

        Args:
            positions: Positions (N, 3).
            collision_params: First element is the minimum distance; an empty
                tensor means 1.0.
            lr: Fraction of the overlap corrected per call.

        Returns:
            Corrected positions, or the original ``positions`` if anything raises.
        """
        try:
            # collision_params: (1,) minimum distance
            min_dist = collision_params[0] if len(collision_params) > 0 else 1.0

            corrected = positions.clone()
            N = len(positions)

            # Pairs are processed sequentially, so later pairs see earlier pushes
            for i in range(N):
                for j in range(i + 1, N):
                    dist_vec = corrected[i] - corrected[j]
                    dist = torch.norm(dist_vec)

                    if dist < min_dist and dist > 1e-6:
                        # Push apart along the connecting line, half each way
                        push_vec = dist_vec / dist * (min_dist - dist) * 0.5 * lr
                        corrected[i] += push_vec
                        corrected[j] -= push_vec

            return corrected
        except Exception as e:
            logger.warning(f"Collision constraint application failed: {e}")
            return positions
class MathematicallyCorrectGASM(nn.Module):
    """
    Iterative model that maps node features to a geometric configuration.

    Each iteration runs ``SE3InvariantAttention`` over a fully connected
    graph, evolves the features through a per-iteration MLP, projects them to
    ``output_dim`` coordinates, optionally applies soft constraints, and
    measures a discrete curvature. Iteration stops early once the mean
    curvature is within 1e-4 of the learnable ``target_curvature``.
    """

    def __init__(
        self,
        feature_dim: int,
        hidden_dim: int,
        output_dim: int = 3,
        num_heads: int = 8,
        max_iterations: int = 10,
        dropout: float = 0.1
    ):
        """
        Args:
            feature_dim: Width of the node features.
            hidden_dim: Hidden width of the attention layer and evolution MLPs.
            output_dim: Number of geometric coordinates per node.
            num_heads: Attention heads in the SE(3) attention layer.
            max_iterations: Upper bound on refinement iterations; also the
                number of evolution MLPs created.
            dropout: Dropout probability used throughout.
        """
        super().__init__()

        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.max_iterations = max_iterations

        # SE(3)-invariant attention
        self.se3_attention = SE3InvariantAttention(
            feature_dim=feature_dim,
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout
        )

        # Geometric projections
        self.feature_to_geom = nn.Linear(feature_dim, output_dim)
        self.geom_to_feature = nn.Linear(output_dim, feature_dim)

        # Feature evolution with residual connections (one MLP per iteration)
        self.feature_evolution = nn.ModuleList([
            nn.Sequential(
                nn.Linear(feature_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, feature_dim),
                nn.LayerNorm(feature_dim)
            ) for _ in range(max_iterations)
        ])

        # Target curvature (learnable)
        self.target_curvature = nn.Parameter(torch.tensor(0.1))

        # Constraint handler (kept as an attribute for existing callers)
        self.constraint_handler = ConstraintHandler()

    @staticmethod
    def _build_fully_connected_edges(num_nodes: int, device) -> torch.Tensor:
        """Return a (2, E) index of all ordered pairs i != j, or a single self-loop for fewer than 2 nodes."""
        if num_nodes < 2:
            # Single node: self-loop
            return torch.tensor([[0], [0]], dtype=torch.long, device=device)
        node_ids = torch.arange(num_nodes, dtype=torch.long, device=device)
        # Row-major enumeration (i outer, j inner) with self-loops removed
        src = node_ids.repeat_interleave(num_nodes)
        dst = node_ids.repeat(num_nodes)
        off_diagonal = src != dst
        return torch.stack([src[off_diagonal], dst[off_diagonal]])

    def forward(
        self,
        E: Union[List, torch.Tensor],  # Entities
        F: torch.Tensor,  # Features (N, feature_dim)
        R: torch.Tensor,  # Relations (N, N, relation_dim)
        C: Optional[Dict[str, torch.Tensor]] = None,  # Constraints
        return_intermediate: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        """
        Iteratively refine features into a geometric configuration.

        Note that the parameter ``F`` shadows the module-level alias for
        ``torch.nn.functional`` inside this method.

        Args:
            E: Entity list (unused but kept for compatibility)
            F: Node features (N, feature_dim)
            R: Relation tensor (N, N, relation_dim)
            C: Constraint dictionary
            return_intermediate: Return intermediate states

        Returns:
            Final geometric configuration (N, output_dim), and when
            ``return_intermediate`` is set, a list with one dict per
            iteration. If the whole pass fails, small random noise of the
            right shape is returned instead of raising.
        """
        try:
            N, feature_dim = F.shape
            device = F.device

            # SAFETY CHECK: Validate inputs
            if N < 1:
                raise ValueError("Need at least 1 entity")

            # Create edge index (full connectivity for now)
            edge_index = self._build_fully_connected_edges(N, device)

            # Extract edge features from relation tensor
            edge_attr = None
            try:
                if R.numel() > 0 and R.shape[0] == N and R.shape[1] == N and edge_index.size(1) > 0:
                    # Convert relation matrix to edge features
                    edge_attr = R[edge_index[0], edge_index[1]]  # (E, relation_dim)
            except Exception as e:
                logger.warning(f"Could not extract edge attributes: {e}")
                edge_attr = None

            # Initialize
            current_features = F
            intermediate_states = []

            # Iterative refinement
            for iteration in range(self.max_iterations):
                try:
                    # Apply SE(3)-invariant attention
                    updated_features = self.se3_attention(
                        current_features,
                        edge_index,
                        edge_attr
                    )

                    # Feature evolution with residual connection
                    evolved_features = self.feature_evolution[iteration](updated_features)
                    current_features = current_features + evolved_features

                    # Project to geometric space
                    current_geometry = self.feature_to_geom(current_features)

                    # Apply constraints if provided. Called through the class
                    # so the positions are always the first argument.
                    if C is not None:
                        current_geometry = ConstraintHandler.apply_energy_constraints(
                            current_geometry, C
                        )

                    # Compute current curvature
                    current_curvature = EfficientCurvatureComputation.compute_discrete_curvature(
                        current_geometry, edge_index, method="gaussian"
                    )

                    # Check convergence
                    mean_curvature = current_curvature.mean()
                    curvature_error = torch.abs(mean_curvature - self.target_curvature)

                    if return_intermediate:
                        intermediate_states.append({
                            'features': current_features.clone(),
                            'geometry': current_geometry.clone(),
                            'curvature': mean_curvature.item(),
                            'iteration': iteration
                        })

                    # Early stopping
                    if curvature_error < 1e-4:
                        logger.info(f"Converged at iteration {iteration}")
                        break

                    # Update features from geometry (inverse projection)
                    geometric_features = self.geom_to_feature(current_geometry)
                    current_features = current_features + 0.1 * geometric_features  # Small step

                except Exception as iter_error:
                    logger.warning(f"Iteration {iteration} failed: {iter_error}")
                    # Continue with current state
                    if return_intermediate:
                        intermediate_states.append({
                            'features': current_features.clone(),
                            'geometry': self.feature_to_geom(current_features),
                            'curvature': 0.1,
                            'iteration': iteration,
                            'error': str(iter_error)
                        })

            # Final geometry
            final_geometry = self.feature_to_geom(current_features)

            if return_intermediate:
                return final_geometry, intermediate_states
            return final_geometry

        except Exception as e:
            logger.error(f"GASM forward pass failed: {e}")
            # Emergency fallback: random low-magnitude coordinates
            emergency_output = torch.randn(F.size(0), self.output_dim, device=F.device) * 0.1
            if return_intermediate:
                return emergency_output, [{'error': str(e)}]
            return emergency_output

    def verify_geometric_consistency(
        self,
        S: torch.Tensor,
        S_raw: torch.Tensor,
        C: Optional[Dict[str, torch.Tensor]] = None,
        tolerance: float = 1e-3
    ) -> Dict[str, Union[bool, float]]:
        """
        Run sanity checks on a produced configuration.

        Args:
            S: Final configuration.
            S_raw: Reference configuration to compare against.
            C: Constraint dictionary; only "distance" entries are checked.
            tolerance: Maximum average distance violation that still counts
                as satisfied.

        Returns:
            Dict with ``se3_invariance``, ``information_preservation``,
            ``mutual_information``, ``constraint_satisfaction`` and
            ``average_constraint_violation``. A check that cannot be run
            reports success with neutral values.
        """
        results = {}

        try:
            # SE(3) invariance test.
            # NOTE: placeholder. A real test needs a second forward pass on
            # rigidly transformed input, so this always reports True.
            results["se3_invariance"] = True

            # Information preservation test
            try:
                preserved, correlation = True, 1.0
                if S.shape == S_raw.shape:
                    # Absolute Pearson correlation as a proxy for mutual information
                    S_flat = S.flatten()
                    S_raw_flat = S_raw.flatten()

                    if len(S_flat) > 1 and len(S_raw_flat) > 1:
                        correlation_matrix = torch.corrcoef(torch.stack([S_flat, S_raw_flat]))
                        correlation = torch.abs(correlation_matrix[0, 1]).item()
                        preserved = correlation > 0.5
                results["information_preservation"] = preserved
                results["mutual_information"] = correlation
            except Exception as e:
                logger.warning(f"Information preservation test failed: {e}")
                results["information_preservation"] = True
                results["mutual_information"] = 1.0

            # Constraint satisfaction test
            try:
                satisfied, avg_violation = True, 0.0
                if C is not None:
                    total_violation = 0.0
                    constraint_count = 0

                    for constraint_type, params in C.items():
                        if constraint_type == "distance" and len(params) > 0:
                            for constraint in params:
                                i, j, target_dist = int(constraint[0]), int(constraint[1]), constraint[2]
                                if i < len(S) and j < len(S):
                                    actual_dist = torch.norm(S[i] - S[j])
                                    total_violation += torch.abs(actual_dist - target_dist).item()
                                    constraint_count += 1

                    if constraint_count > 0:
                        avg_violation = total_violation / constraint_count
                        satisfied = avg_violation < tolerance
                results["constraint_satisfaction"] = satisfied
                results["average_constraint_violation"] = avg_violation
            except Exception as e:
                logger.warning(f"Constraint satisfaction test failed: {e}")
                results["constraint_satisfaction"] = True
                results["average_constraint_violation"] = 0.0

        except Exception as e:
            logger.error(f"Geometric consistency verification failed: {e}")
            results = {
                "se3_invariance": False,
                "information_preservation": False,
                "constraint_satisfaction": False,
                "error": str(e)
            }

        return results
| # Enhanced components from integrated system | |
class EnhancedBatchProcessor:
    """Simplified batch processing for HF Spaces, with per-text result caching."""

    def __init__(self, max_batch_size=8):
        """
        Args:
            max_batch_size: Maximum number of texts handled per call.
        """
        self.max_batch_size = max_batch_size
        # Maps each processed text to its extraction result. Grows without bound.
        self.cache = {}

    def process_batch(self, texts, gasm_interface):
        """
        Extract entities for the first ``max_batch_size`` texts.

        Args:
            texts: Sequence of hashable texts. Items beyond
                ``max_batch_size`` are ignored.
            gasm_interface: Object exposing
                ``extract_entities_from_text(text)``.

        Returns:
            List of results, one per processed text, in input order.
        """
        results = []
        for text in texts[:self.max_batch_size]:
            # Key on the text itself rather than hash(text): two different
            # texts can share a hash, and would then share a cached result.
            if text not in self.cache:
                self.cache[text] = gasm_interface.extract_entities_from_text(text)
            results.append(self.cache[text])
        return results
class ErrorRecoveryWrapper:
    """Callable wrapper that retries a function and degrades to a safe result."""

    def __init__(self, func, max_retries=2):
        """
        Args:
            func: The callable to protect.
            max_retries: Number of retries after the first failed attempt.
        """
        self.func = func
        self.max_retries = max_retries

    def __call__(self, *args, **kwargs):
        """
        Call ``func`` up to ``max_retries + 1`` times.

        Returns:
            The first successful result. If every attempt raises, a fallback
            dict ``{"entities": [], "relations": [], "error": <message>}``.
        """
        # Imported here because the module does not import ``time`` at the
        # top; without it the backoff below raised NameError on the first retry.
        import time

        for attempt in range(self.max_retries + 1):
            try:
                return self.func(*args, **kwargs)
            except Exception as e:
                if attempt == self.max_retries:
                    logger.warning(f"Function failed after {attempt + 1} attempts: {e}")
                    # Return safe fallback
                    return {"entities": [], "relations": [], "error": str(e)}
                time.sleep(0.1 * (2 ** attempt))  # Exponential backoff: 0.1s, 0.2s, ...
def robust_function(max_retries=2):
    """
    Decorator factory that wraps a function in ``ErrorRecoveryWrapper``.

    Args:
        max_retries: Retry budget handed to the wrapper.

    Returns:
        A decorator; the decorated name is bound to the wrapper instance.
    """
    def wrap_with_recovery(target):
        guarded = ErrorRecoveryWrapper(target, max_retries)
        return guarded

    return wrap_with_recovery
| # Enhanced GASM with all optimizations | |
class EnhancedGASM(MathematicallyCorrectGASM):
    """GASM variant for HF Spaces that adds mixed precision and text batching."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base model, then attach the extras."""
        super().__init__(*args, **kwargs)
        self.batch_processor = EnhancedBatchProcessor()
        # Mixed precision is switched on only when a CUDA device is present
        self.use_mixed_precision = torch.cuda.is_available()

    def forward_enhanced(self, E, F, R, C=None, return_intermediate=False):
        """
        Run the base ``forward``, under CUDA autocast when enabled.

        Arguments and return value are those of the base class ``forward``.
        """
        autocast_enabled = self.use_mixed_precision and torch.cuda.is_available()
        if not autocast_enabled:
            return super().forward(E, F, R, C, return_intermediate)

        # Use mixed precision if available
        with torch.cuda.amp.autocast():
            return super().forward(E, F, R, C, return_intermediate)

    def process_batch_texts(self, texts):
        """
        Run ``texts`` through the batch processor, passing this model as the interface.

        NOTE(review): the processor calls ``extract_entities_from_text`` on
        the interface; that method is not defined in this class here, so
        confirm it is provided elsewhere.
        """
        processor = self.batch_processor
        return processor.process_batch(texts, self)
# Backwards-compatible names: older call sites keep working, and both model
# names resolve to the enhanced implementation by default.
UniversalInvariantAttention = SE3InvariantAttention
GASM = MathematicallyCorrectGASM = EnhancedGASM