| import torch | |
| from typing import Tuple, Callable | |
| def hacer_nada(x: torch.Tensor, modo: str = None): | |
| return x | |
| def brujeria_mps(entrada, dim, indice): | |
| if entrada.shape[-1] == 1: | |
| return torch.gather(entrada.unsqueeze(-1), dim - 1 if dim < 0 else dim, indice.unsqueeze(-1)).squeeze(-1) | |
| else: | |
| return torch.gather(entrada, dim, indice) | |
| def emparejamiento_suave_aleatorio_2d( | |
| metrica: torch.Tensor, | |
| ancho: int, | |
| alto: int, | |
| paso_x: int, | |
| paso_y: int, | |
| radio: int, | |
| sin_aleatoriedad: bool = False, | |
| generador: torch.Generator = None | |
| ) -> Tuple[Callable, Callable]: | |
| lote, num_nodos, _ = metrica.shape | |
| if radio <= 0: | |
| return hacer_nada, hacer_nada | |
| recopilar = brujeria_mps if metrica.device.type == "mps" else torch.gather | |
| with torch.no_grad(): | |
| alto_paso_y, ancho_paso_x = alto // paso_y, ancho // paso_x | |
| if sin_aleatoriedad: | |
| indice_aleatorio = torch.zeros(alto_paso_y, ancho_paso_x, 1, device=metrica.device, dtype=torch.int64) | |
| else: | |
| indice_aleatorio = torch.randint(paso_y * paso_x, size=(alto_paso_y, ancho_paso_x, 1), device=generador.device, generator=generador).to(metrica.device) | |
| vista_buffer_indice = torch.zeros(alto_paso_y, ancho_paso_x, paso_y * paso_x, device=metrica.device, dtype=torch.int64) | |
| vista_buffer_indice.scatter_(dim=2, index=indice_aleatorio, src=-torch.ones_like(indice_aleatorio, dtype=indice_aleatorio.dtype)) | |
| vista_buffer_indice = vista_buffer_indice.view(alto_paso_y, ancho_paso_x, paso_y, paso_x).transpose(1, 2).reshape(alto_paso_y * paso_y, ancho_paso_x * paso_x) | |
| if (alto_paso_y * paso_y) < alto or (ancho_paso_x * paso_x) < ancho: | |
| buffer_indice = torch.zeros(alto, ancho, device=metrica.device, dtype=torch.int64) | |
| buffer_indice[:(alto_paso_y * paso_y), :(ancho_paso_x * paso_x)] = vista_buffer_indice | |
| else: | |
| buffer_indice = vista_buffer_indice | |
| indice_aleatorio = buffer_indice.reshape(1, -1, 1).argsort(dim=1) | |
| del buffer_indice, vista_buffer_indice | |
| num_destino = alto_paso_y * ancho_paso_x | |
| indices_a = indice_aleatorio[:, num_destino:, :] | |
| indices_b = indice_aleatorio[:, :num_destino, :] | |
| def dividir(x): | |
| canales = x.shape[-1] | |
| origen = recopilar(x, dim=1, index=indices_a.expand(lote, num_nodos - num_destino, canales)) | |
| destino = recopilar(x, dim=1, index=indices_b.expand(lote, num_destino, canales)) | |
| return origen, destino | |
| metrica = metrica / metrica.norm(dim=-1, keepdim=True) | |
| a, b = dividir(metrica) | |
| puntuaciones = a @ b.transpose(-1, -2) | |
| radio = min(a.shape[1], radio) | |
| nodo_max, nodo_indice = puntuaciones.max(dim=-1) | |
| indice_borde = nodo_max.argsort(dim=-1, descending=True)[..., None] | |
| indice_no_emparejado = indice_borde[..., radio:, :] | |
| indice_origen = indice_borde[..., :radio, :] | |
| indice_destino = recopilar(nodo_indice[..., None], dim=-2, index=indice_origen) | |
| def fusionar(x: torch.Tensor, modo="mean") -> torch.Tensor: | |
| origen, destino = dividir(x) | |
| n, t1, c = origen.shape | |
| no_emparejado = recopilar(origen, dim=-2, index=indice_no_emparejado.expand(n, t1 - radio, c)) | |
| origen = recopilar(origen, dim=-2, index=indice_origen.expand(n, radio, c)) | |
| destino = destino.scatter_reduce(-2, indice_destino.expand(n, radio, c), origen, reduce=modo) | |
| return torch.cat([no_emparejado, destino], dim=1) | |
| def desfusionar(x: torch.Tensor) -> torch.Tensor: | |
| longitud_no_emparejado = indice_no_emparejado.shape[1] | |
| no_emparejado, destino = x[..., :longitud_no_emparejado, :], x[..., longitud_no_emparejado:, :] | |
| _, _, c = no_emparejado.shape | |
| origen = recopilar(destino, dim=-2, index=indice_destino.expand(lote, radio, c)) | |
| salida = torch.zeros(lote, num_nodos, c, device=x.device, dtype=x.dtype) | |
| salida.scatter_(dim=-2, index=indices_b.expand(lote, num_destino, c), src=destino) | |
| salida.scatter_(dim=-2, index=recopilar(indices_a.expand(lote, indices_a.shape[1], 1), dim=1, index=indice_no_emparejado).expand(lote, longitud_no_emparejado, c), src=no_emparejado) | |
| salida.scatter_(dim=-2, index=recopilar(indices_a.expand(lote, indices_a.shape[1], 1), dim=1, index=indice_origen).expand(lote, radio, c), src=origen) | |
| return salida | |
| return fusionar, desfusionar |