|
|
|
|
|
""" |
|
|
Utility Functions |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
def generate_shifted_matrix(n, device=None, granularities=None): |
|
|
|
|
|
matrix_columns = [] |
|
|
if granularities is None: |
|
|
granularities = [2, 4] |
|
|
|
|
|
for granularity in granularities: |
|
|
if granularity > n: |
|
|
continue |
|
|
|
|
|
|
|
|
step_size = max(1, granularity // 2) |
|
|
max_start = n - granularity |
|
|
|
|
|
for start in range(0, max_start + 1, step_size): |
|
|
column = torch.zeros(n, dtype=torch.int, device=device) |
|
|
column[start:start + granularity] = 1 |
|
|
matrix_columns.append(column) |
|
|
|
|
|
|
|
|
if max_start >= 0 and (max_start % step_size) != 0: |
|
|
column = torch.zeros(n, dtype=torch.int, device=device) |
|
|
column[-granularity:] = 1 |
|
|
matrix_columns.append(column) |
|
|
|
|
|
if not matrix_columns: |
|
|
column = torch.ones(n, dtype=torch.int, device=device) |
|
|
matrix_columns.append(column) |
|
|
|
|
|
result = torch.stack(matrix_columns, dim=1).unsqueeze(0).expand(1, -1, -1) |
|
|
return result |
|
|
|
|
|
def create_attention_mask(shift_matrix: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Create attention mask from shift matrix |
|
|
|
|
|
Args: |
|
|
shift_matrix (torch.Tensor): shift matrix, shape [num_chunks, seq_len] |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: attention mask, shape [1, num_chunks, seq_len, seq_len] |
|
|
""" |
|
|
|
|
|
attention_mask = shift_matrix.transpose(0, 1) |
|
|
attention_mask = torch.where(attention_mask == 1.0, 0.0, float('-inf')) |
|
|
|
|
|
|
|
|
attention_mask = attention_mask.unsqueeze(0).unsqueeze(0) |
|
|
|
|
|
return attention_mask |
|
|
|
|
|
def normalize_embeddings(embeddings: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: |
|
|
""" |
|
|
L2 normalize embeddings |
|
|
|
|
|
Args: |
|
|
embeddings (torch.Tensor): Embeddings |
|
|
eps (float): Small value to prevent division by zero |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: Normalized embeddings |
|
|
""" |
|
|
norm = torch.norm(embeddings, dim=-1, keepdim=True) |
|
|
return embeddings / (norm + eps) |
|
|
|
|
|
def cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Calculate cosine similarity |
|
|
|
|
|
Args: |
|
|
a (torch.Tensor): Vector A |
|
|
b (torch.Tensor): Vector B |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: Cosine similarity |
|
|
""" |
|
|
a_norm = normalize_embeddings(a) |
|
|
b_norm = normalize_embeddings(b) |
|
|
return torch.sum(a_norm * b_norm, dim=-1) |
|
|
|
|
|
def batch_cosine_similarity(embeddings1: torch.Tensor, embeddings2: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Calculate batch cosine similarity |
|
|
|
|
|
Args: |
|
|
embeddings1 (torch.Tensor): Embeddings group 1, shape [N, dim] |
|
|
embeddings2 (torch.Tensor): Embeddings group 2, shape [M, dim] |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: Similarity matrix, shape [N, M] |
|
|
""" |
|
|
embeddings1_norm = normalize_embeddings(embeddings1) |
|
|
embeddings2_norm = normalize_embeddings(embeddings2) |
|
|
|
|
|
return torch.matmul(embeddings1_norm, embeddings2_norm.transpose(0, 1)) |
|
|
|
|
|
def split_embeddings_by_shift_matrix(embeddings: torch.Tensor, shift_matrix: torch.Tensor) -> list: |
|
|
""" |
|
|
Split embeddings based on shift matrix |
|
|
|
|
|
Args: |
|
|
embeddings (torch.Tensor): Embeddings, shape [seq_len, hidden_dim] |
|
|
shift_matrix (torch.Tensor): shift matrix, shape [num_chunks, seq_len] |
|
|
|
|
|
Returns: |
|
|
list: List of split embeddings |
|
|
""" |
|
|
split_embeddings = [] |
|
|
num_chunks, seq_len = shift_matrix.shape |
|
|
|
|
|
for chunk_idx in range(num_chunks): |
|
|
mask = shift_matrix[chunk_idx] |
|
|
indices = torch.nonzero(mask, as_tuple=True)[0] |
|
|
|
|
|
if len(indices) > 0: |
|
|
chunk_embeddings = embeddings[indices] |
|
|
split_embeddings.append(chunk_embeddings) |
|
|
|
|
|
return split_embeddings |
|
|
|
|
|
def pool_embeddings(embeddings: torch.Tensor, method: str = 'mean') -> torch.Tensor: |
|
|
""" |
|
|
Pool embeddings |
|
|
|
|
|
Args: |
|
|
embeddings (torch.Tensor): Embeddings, shape [seq_len, hidden_dim] |
|
|
method (str): Pooling method, optional 'mean', 'max', 'first', 'last' |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: Pooled vector, shape [hidden_dim] |
|
|
""" |
|
|
if method == 'mean': |
|
|
return torch.mean(embeddings, dim=0) |
|
|
elif method == 'max': |
|
|
return torch.max(embeddings, dim=0)[0] |
|
|
elif method == 'first': |
|
|
return embeddings[0] |
|
|
elif method == 'last': |
|
|
return embeddings[-1] |
|
|
else: |
|
|
raise ValueError(f"Unknown pooling method: {method}") |
|
|
|
|
|
def aggregate_chunk_embeddings(split_embeddings: list, method: str = 'mean') -> torch.Tensor: |
|
|
""" |
|
|
Aggregate chunk embeddings |
|
|
|
|
|
Args: |
|
|
split_embeddings (list): List of split embeddings |
|
|
method (str): Aggregation method |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: Aggregated embeddings, shape [num_chunks, hidden_dim] |
|
|
""" |
|
|
if not split_embeddings: |
|
|
return torch.tensor([]) |
|
|
|
|
|
aggregated = [] |
|
|
for chunk_embeddings in split_embeddings: |
|
|
pooled = pool_embeddings(chunk_embeddings, method) |
|
|
aggregated.append(pooled) |
|
|
|
|
|
return torch.stack(aggregated) |
|
|
|
|
|
def safe_tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray: |
|
|
""" |
|
|
Safely convert tensor to numpy array |
|
|
|
|
|
Args: |
|
|
tensor (torch.Tensor): Input tensor |
|
|
|
|
|
Returns: |
|
|
np.ndarray: Numpy array |
|
|
""" |
|
|
if tensor.requires_grad: |
|
|
tensor = tensor.detach() |
|
|
|
|
|
if tensor.is_cuda: |
|
|
tensor = tensor.cpu() |
|
|
|
|
|
return tensor.numpy() |
|
|
|
|
|
def ensure_tensor_on_device(tensor: torch.Tensor, device: torch.device) -> torch.Tensor: |
|
|
""" |
|
|
Ensure tensor is on specified device |
|
|
|
|
|
Args: |
|
|
tensor (torch.Tensor): Input tensor |
|
|
device (torch.device): Target device |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: Tensor on target device |
|
|
""" |
|
|
if tensor.device != device: |
|
|
tensor = tensor.to(device) |
|
|
return tensor |
|
|
|
|
|
def get_available_device() -> torch.device: |
|
|
""" |
|
|
Get available device |
|
|
|
|
|
Returns: |
|
|
torch.device: Available device |
|
|
""" |
|
|
if torch.cuda.is_available(): |
|
|
return torch.device('cuda') |
|
|
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): |
|
|
return torch.device('mps') |
|
|
else: |
|
|
return torch.device('cpu') |
|
|
|
|
|
def print_tensor_info(tensor: torch.Tensor, name: str = "tensor"): |
|
|
""" |
|
|
Print tensor info |
|
|
|
|
|
Args: |
|
|
tensor (torch.Tensor): Input tensor |
|
|
name (str): Tensor name |
|
|
""" |
|
|
print(f"{name}:") |
|
|
print(f" Shape: {tensor.shape}") |
|
|
print(f" Data Type: {tensor.dtype}") |
|
|
print(f" Device: {tensor.device}") |
|
|
print(f" Requires Grad: {tensor.requires_grad}") |
|
|
if tensor.numel() > 0: |
|
|
print(f" Value Range: [{tensor.min().item():.6f}, {tensor.max().item():.6f}]") |
|
|
print(f" Mean: {tensor.mean().item():.6f}") |
|
|
print(f" Std Dev: {tensor.std().item():.6f}") |
|
|
|