""" Tensor operations for distributed computing. """
import torch
import numpy as np
from typing import Dict, List, Optional, Union, Tuple


class TensorOps:
    """Utility class for distributed tensor operations.

    Every method is a ``@staticmethod``; the class only groups them under one
    namespace.
    """

    @staticmethod
    def split_tensor(tensor: torch.Tensor, num_parts: int) -> List[torch.Tensor]:
        """Split a tensor into multiple parts along dim 0 for distributed processing.

        Args:
            tensor: Tensor to split.
            num_parts: Requested number of chunks.

        Returns:
            A list of views into ``tensor``. As with ``torch.chunk``, the last
            chunk may be smaller, and fewer than ``num_parts`` chunks may be
            returned when the dimension is not large enough.
        """
        # torch.chunk returns a tuple; convert so the result matches the annotation.
        return list(torch.chunk(tensor, num_parts))

    @staticmethod
    def merge_tensors(tensors: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
        """Merge multiple tensors back into a single tensor.

        Args:
            tensors: Tensors to concatenate; shapes must agree except along ``dim``.
            dim: Dimension to concatenate along (default 0).

        Returns:
            The result of ``torch.cat(tensors, dim=dim)``.
        """
        return torch.cat(tensors, dim=dim)

    @staticmethod
    def average_gradients(gradients: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """Average gradients from multiple workers.

        Args:
            gradients: One ``{name: tensor}`` dict per worker. The key set is
                taken from the first dict; every dict must contain those keys
                (otherwise ``KeyError``), and extra keys in later dicts are ignored.

        Returns:
            A new dict mapping each key to the element-wise mean across workers.

        Raises:
            IndexError: If ``gradients`` is empty.
        """
        avg_gradients = {}
        for key in gradients[0].keys():
            # Stack along a new leading axis, then reduce it away.
            avg_gradients[key] = torch.mean(torch.stack([g[key] for g in gradients]), dim=0)
        return avg_gradients

    @staticmethod
    def serialize_tensor(tensor: torch.Tensor) -> Dict[str, Union[List, str]]:
        """Serialize a tensor for storage/transmission.

        Args:
            tensor: Tensor on any device, with or without ``requires_grad``.

        Returns:
            A dict with ``'data'`` (nested Python list, or a scalar for 0-d
            tensors), ``'shape'`` (list of ints) and ``'dtype'`` (e.g.
            ``'torch.float32'``).
        """
        # Tensor.tolist() supports every torch dtype (numpy() rejects bfloat16),
        # and detach() avoids the error numpy() raises on tensors requiring grad.
        return {
            'data': tensor.detach().cpu().tolist(),
            'shape': list(tensor.shape),
            'dtype': str(tensor.dtype),
        }

    @staticmethod
    def deserialize_tensor(tensor_dict: Dict[str, Union[List, str]]) -> torch.Tensor:
        """Deserialize a tensor from storage/transmission format.

        Args:
            tensor_dict: A dict in the format produced by ``serialize_tensor``.

        Returns:
            A CPU tensor with the recorded dtype and shape.

        Raises:
            AttributeError: If the dtype name is not an attribute of ``torch``.
            ValueError: If the dtype name resolves to something that is not a
                ``torch.dtype``.
        """
        shape = tensor_dict['shape']
        # 'torch.float32' -> 'float32' -> torch.float32
        dtype = getattr(torch, str(tensor_dict['dtype']).split('.')[-1])
        if not isinstance(dtype, torch.dtype):
            # The dtype string comes from serialized data; never accept an
            # arbitrary torch attribute (function, module, ...) in its place.
            raise ValueError(f"Not a torch dtype: {tensor_dict['dtype']!r}")
        # Build straight from the nested list: skips a float64 numpy round trip
        # and works for dtypes numpy lacks (e.g. bfloat16).
        return torch.tensor(tensor_dict['data'], dtype=dtype).reshape(shape)

    @staticmethod
    def gradient_clipping(gradients: Dict[str, torch.Tensor], max_norm: float) -> Dict[str, torch.Tensor]:
        """Apply gradient clipping to prevent exploding gradients.

        The global L2 norm over all non-``None`` gradients is computed. If it
        exceeds ``max_norm``, every gradient is scaled in place by
        ``max_norm / (total_norm + 1e-6)``, matching the coefficient used by
        ``torch.nn.utils.clip_grad_norm_``. ``None`` entries are left untouched.

        Args:
            gradients: Mapping of name to gradient tensor (or ``None``).
            max_norm: Maximum allowed global L2 norm.

        Returns:
            The same ``gradients`` dict, whose tensors were modified in place.
        """
        present = {k: v for k, v in gradients.items() if v is not None}
        if not present:
            return gradients
        total_norm = TensorOps.compute_parameter_norm(present)
        # Epsilon guards against division by zero for all-zero gradients.
        clip_coef = max_norm / (total_norm + 1e-6)
        if clip_coef < 1.0:
            # no_grad allows in-place scaling even of leaf tensors that require grad.
            with torch.no_grad():
                for grad in present.values():
                    grad.mul_(clip_coef)
        return gradients

    @staticmethod
    def reduce_precision(tensor: torch.Tensor, bits: int = 16) -> torch.Tensor:
        """Reduce tensor precision for efficient transmission.

        Args:
            tensor: Tensor to convert.
            bits: 16 converts to ``float16`` via ``half()``; 32 converts to
                ``float32`` via ``float()``.

        Returns:
            The converted tensor.

        Raises:
            ValueError: For any other ``bits`` value.
        """
        if bits == 16:
            return tensor.half()
        elif bits == 32:
            return tensor.float()
        else:
            raise ValueError("Unsupported precision bits")

    @staticmethod
    def shard_tensor(tensor: torch.Tensor, shard_size: int) -> List[torch.Tensor]:
        """Shard a tensor into smaller pieces for distributed processing.

        Args:
            tensor: Tensor with at least one dimension; sharded along dim 0.
            shard_size: Number of rows per shard; the last shard may be shorter.

        Returns:
            A list of slice views of ``tensor`` (empty if ``tensor`` has 0 rows).

        Raises:
            ValueError: If ``shard_size`` is not positive.
        """
        if shard_size <= 0:
            # A zero step breaks range() and a negative one silently yields [].
            raise ValueError("shard_size must be a positive integer")
        return [tensor[start:start + shard_size] for start in range(0, tensor.size(0), shard_size)]

    @staticmethod
    def compute_parameter_norm(parameters: Dict[str, torch.Tensor]) -> float:
        """Compute the total norm of all parameters.

        Args:
            parameters: Mapping of name to tensor; every value must be a tensor.

        Returns:
            The global L2 norm ``sqrt(sum(||p||^2))`` as a Python float
            (0.0 for an empty dict). Each tensor's norm is read with
            ``.item()``, one host transfer per tensor.
        """
        total_norm = 0.0
        for param in parameters.values():
            total_norm += param.norm().item() ** 2
        return total_norm ** 0.5