"""
Style vector utilities.

Helper functions for manipulating, comparing, and persisting style vectors.
"""

import torch
import torch.nn.functional as F
from typing import List, Optional


def cosine_similarity(
    vec_a: torch.Tensor,
    vec_b: torch.Tensor,
    *,
    eps: float = 1e-8,
) -> float:
    """Compute the cosine similarity between two style vectors.

    1-D inputs are promoted to a batch of one. The similarity is taken
    along the last dimension, and the two inputs are broadcast against
    each other by ``F.cosine_similarity``.

    Args:
        vec_a: First style vector.
        vec_b: Second style vector.
        eps: Small value forwarded to ``F.cosine_similarity`` to avoid
            division by zero. The default matches PyTorch's own default.

    Returns:
        The similarity as a Python float.

    Raises:
        RuntimeError: If the inputs broadcast to more than one similarity
            value (e.g. two multi-row batches), because ``.item()`` can only
            convert a single-element tensor to a float.
    """
    if vec_a.dim() == 1:
        vec_a = vec_a.unsqueeze(0)
    if vec_b.dim() == 1:
        vec_b = vec_b.unsqueeze(0)
    sim = F.cosine_similarity(vec_a, vec_b, dim=-1, eps=eps)
    return sim.item()


def average_style_vectors(
    vectors: List[torch.Tensor],
    *,
    normalize: bool = True,
) -> torch.Tensor:
    """Compute the mean of several style vectors, L2-normalised by default.

    Args:
        vectors: Style vectors of identical shape. Any iterable of tensors
            (e.g. a generator) is accepted; it is materialised into a list.
        normalize: When True (the default, matching the historical
            behaviour), the mean is L2-normalised along the last dimension.
            Pass False to get the raw arithmetic mean.

    Returns:
        The (optionally normalised) mean, with the shape of a single input
        vector.

    Raises:
        ValueError: If ``vectors`` is empty.
    """
    # Materialise once so generators work and the emptiness check is reliable.
    vectors = list(vectors)
    if not vectors:
        raise ValueError("Cannot average empty list of vectors")
    mean_vec = torch.stack(vectors, dim=0).mean(dim=0)
    if not normalize:
        return mean_vec
    # F.normalize divides by max(norm, eps), so an all-zero mean stays zero
    # rather than turning into NaN.
    return F.normalize(mean_vec, p=2, dim=-1)


def save_style_vector(vector: torch.Tensor, path: str) -> None:
    """Persist a style vector to disk with ``torch.save``.

    The tensor is detached from any autograd graph and moved to CPU before
    saving, so the file does not reference the original device.
    """
    torch.save(vector.detach().cpu(), path)


def load_style_vector(path: str) -> torch.Tensor:
    """Load a style vector previously written by ``save_style_vector``.

    The tensor is mapped to CPU. ``weights_only=True`` restricts unpickling
    to tensors and primitive containers, so loading a file cannot execute
    arbitrary pickled code.
    """
    return torch.load(path, map_location="cpu", weights_only=True)