import pytorch_lightning as L
import torch
import torch.nn as nn
from hydra.utils import instantiate
import copy
import numpy as np
from utils.manifolds import Sphere
from torch.func import jacrev, vjp, vmap
from torchdiffeq import odeint
from geoopt import ProductManifold, Euclidean
from models.samplers.riemannian_flow_sampler import ode_riemannian_flow_sampler
|
|
|
|
| class DiffGeolocalizer(L.LightningModule): |
| def __init__(self, cfg): |
| super().__init__() |
| self.cfg = cfg |
| self.network = instantiate(cfg.network) |
| |
| self.input_dim = cfg.network.input_dim |
| self.train_noise_scheduler = instantiate(cfg.train_noise_scheduler) |
| self.inference_noise_scheduler = instantiate(cfg.inference_noise_scheduler) |
| self.data_preprocessing = instantiate(cfg.data_preprocessing) |
| self.cond_preprocessing = instantiate(cfg.cond_preprocessing) |
| self.preconditioning = instantiate(cfg.preconditioning) |
|
|
| self.ema_network = copy.deepcopy(self.network).requires_grad_(False) |
| self.ema_network.eval() |
| self.postprocessing = instantiate(cfg.postprocessing) |
| self.val_sampler = instantiate(cfg.val_sampler) |
| self.test_sampler = instantiate(cfg.test_sampler) |
| self.loss = instantiate(cfg.loss)( |
| self.train_noise_scheduler, |
| ) |
| self.val_metrics = instantiate(cfg.val_metrics) |
| self.test_metrics = instantiate(cfg.test_metrics) |
| self.manifold = instantiate(cfg.manifold) if hasattr(cfg, "manifold") else None |
|
|
| self.interpolant = cfg.interpolant |
|
|
| def training_step(self, batch, batch_idx): |
| with torch.no_grad(): |
| batch = self.data_preprocessing(batch) |
| batch = self.cond_preprocessing(batch) |
| batch_size = batch["x_0"].shape[0] |
| loss = self.loss(self.preconditioning, self.network, batch).mean() |
| self.log( |
| "train/loss", |
| loss, |
| sync_dist=True, |
| on_step=True, |
| on_epoch=True, |
| batch_size=batch_size, |
| ) |
| return loss |
|
|
| def on_before_optimizer_step(self, optimizer): |
| if self.global_step == 0: |
| no_grad = [] |
| for name, param in self.network.named_parameters(): |
| if param.grad is None: |
| no_grad.append(name) |
| if len(no_grad) > 0: |
| print("Parameters without grad:") |
| print(no_grad) |
|
|
| def on_validation_start(self): |
| self.validation_generator = torch.Generator(device=self.device).manual_seed( |
| 3407 |
| ) |
| self.validation_generator_ema = torch.Generator(device=self.device).manual_seed( |
| 3407 |
| ) |
|
|
| def validation_step(self, batch, batch_idx): |
| batch = self.data_preprocessing(batch) |
| batch = self.cond_preprocessing(batch) |
| batch_size = batch["x_0"].shape[0] |
| loss = self.loss( |
| self.preconditioning, |
| self.network, |
| batch, |
| generator=self.validation_generator, |
| ).mean() |
| self.log( |
| "val/loss", |
| loss, |
| sync_dist=True, |
| on_step=False, |
| on_epoch=True, |
| batch_size=batch_size, |
| ) |
        # NOTE: `ema_model` is a method defined below, so this check is always
        # true; the loss with the EMA weights is logged on every validation step.
        if hasattr(self, "ema_model"):
| loss_ema = self.loss( |
| self.preconditioning, |
| self.ema_network, |
| batch, |
| generator=self.validation_generator_ema, |
| ).mean() |
| self.log( |
| "val/loss_ema", |
| loss_ema, |
| sync_dist=True, |
| on_step=False, |
| on_epoch=True, |
                batch_size=batch_size,
            )


    def on_test_start(self):
| self.test_generator = torch.Generator(device=self.device).manual_seed(3407) |
|
|
| def test_step_simple(self, batch, batch_idx): |
| batch = self.data_preprocessing(batch) |
| batch = self.cond_preprocessing(batch) |
| batch_size = batch["x_0"].shape[0] |
| if isinstance(self.manifold, Sphere): |
| x_N = self.manifold.random_base( |
| batch_size, |
| self.input_dim, |
| device=self.device, |
| ) |
| x_N = x_N.reshape(batch_size, self.input_dim) |
| else: |
| x_N = torch.randn( |
| batch_size, |
| self.input_dim, |
| device=self.device, |
| generator=self.test_generator, |
| ) |
| cond = batch[self.cfg.cond_preprocessing.output_key] |
|
|
        # NOTE: stage="val" selects the validation sampler configuration.
        samples = self.sample(
            x_N=x_N,
            cond=cond,
            stage="val",
            generator=self.test_generator,
            cfg=self.cfg.cfg_rate,
        )
| self.test_metrics.update({"gps": samples}, batch) |
| if self.cfg.compute_nll: |
| nll = -self.compute_exact_loglikelihood(batch, cfg=0).mean() |
| self.log( |
| "test/NLL", |
| nll, |
| sync_dist=True, |
| on_step=False, |
| on_epoch=True, |
| batch_size=batch_size, |
| ) |
|
|
| def test_best_nll(self, batch, batch_idx): |
| batch = self.data_preprocessing(batch) |
| batch = self.cond_preprocessing(batch) |
| batch_size = batch["x_0"].shape[0] |
| num_sample_per_cond = 32 |
| if isinstance(self.manifold, Sphere): |
| x_N = self.manifold.random_base( |
| batch_size * num_sample_per_cond, |
| self.input_dim, |
| device=self.device, |
| ) |
| x_N = x_N.reshape(batch_size * num_sample_per_cond, self.input_dim) |
| else: |
| x_N = torch.randn( |
| batch_size * num_sample_per_cond, |
| self.input_dim, |
| device=self.device, |
| generator=self.test_generator, |
| ) |
| cond = ( |
| batch[self.cfg.cond_preprocessing.output_key] |
| .unsqueeze(1) |
| .repeat(1, num_sample_per_cond, 1) |
| .view(-1, batch[self.cfg.cond_preprocessing.output_key].shape[-1]) |
| ) |
| samples = self.sample_distribution( |
| x_N, |
| cond, |
| sampling_batch_size=32768, |
| stage="val", |
| generator=self.test_generator, |
| cfg=0, |
| ) |
| samples = samples.view(batch_size * num_sample_per_cond, -1) |
| batch_swarm = {"gps": samples, "emb": cond} |
| nll_batch = -self.compute_exact_loglikelihood(batch_swarm, cfg=0) |
| nll_batch = nll_batch.view(batch_size, num_sample_per_cond, -1) |
| nll_best = nll_batch[ |
| torch.arange(batch_size), nll_batch.argmin(dim=1).squeeze(1) |
| ] |
| self.log( |
| "test/best_nll", |
| nll_best.mean(), |
| sync_dist=True, |
| on_step=False, |
| on_epoch=True, |
| ) |
| samples = samples.view(batch_size, num_sample_per_cond, -1)[ |
| torch.arange(batch_size), nll_batch.argmin(dim=1).squeeze(1) |
| ] |
| self.test_metrics.update({"gps": samples}, batch) |
|
|
| def test_step(self, batch, batch_idx): |
| if self.cfg.compute_swarms: |
| self.test_best_nll(batch, batch_idx) |
| else: |
| self.test_step_simple(batch, batch_idx) |
|
|
| def on_test_epoch_end(self): |
| metrics = self.test_metrics.compute() |
| for metric_name, metric_value in metrics.items(): |
| self.log( |
| f"test/{metric_name}", |
| metric_value, |
| sync_dist=True, |
| on_step=False, |
| on_epoch=True, |
| ) |
|
|
| def configure_optimizers(self): |
| if self.cfg.optimizer.exclude_ln_and_biases_from_weight_decay: |
| parameters_names_wd = get_parameter_names(self.network, [nn.LayerNorm]) |
| parameters_names_wd = [ |
| name for name in parameters_names_wd if "bias" not in name |
| ] |
| optimizer_grouped_parameters = [ |
| { |
| "params": [ |
| p |
| for n, p in self.network.named_parameters() |
| if n in parameters_names_wd |
| ], |
| "weight_decay": self.cfg.optimizer.optim.weight_decay, |
| "layer_adaptation": True, |
| }, |
| { |
| "params": [ |
| p |
| for n, p in self.network.named_parameters() |
| if n not in parameters_names_wd |
| ], |
| "weight_decay": 0.0, |
| "layer_adaptation": False, |
| }, |
| ] |
| optimizer = instantiate( |
| self.cfg.optimizer.optim, optimizer_grouped_parameters |
| ) |
| else: |
| optimizer = instantiate(self.cfg.optimizer.optim, self.network.parameters()) |
| if "lr_scheduler" in self.cfg: |
| scheduler = instantiate(self.cfg.lr_scheduler)(optimizer) |
| return [optimizer], [{"scheduler": scheduler, "interval": "step"}] |
| else: |
| return optimizer |
|
|
    def lr_scheduler_step(self, scheduler, metric):
        # Step schedulers by global step so schedules are defined in optimizer
        # steps rather than epochs.
        scheduler.step(self.global_step)
|
|
| def sample( |
| self, |
| batch_size=None, |
| cond=None, |
| x_N=None, |
| num_steps=None, |
| stage="test", |
| cfg=0, |
| generator=None, |
| return_trajectories=False, |
| postprocessing=True, |
| ): |
| if x_N is None: |
| assert batch_size is not None |
| if isinstance(self.manifold, Sphere): |
| x_N = self.manifold.random_base( |
| batch_size, self.input_dim, device=self.device |
| ) |
| x_N = x_N.reshape(batch_size, self.input_dim) |
| else: |
| x_N = torch.randn(batch_size, self.input_dim, device=self.device) |
| batch = {"y": x_N} |
| if stage == "val": |
| sampler = self.val_sampler |
| elif stage == "test": |
| sampler = self.test_sampler |
| else: |
| raise ValueError(f"Unknown stage {stage}") |
| batch[self.cfg.cond_preprocessing.input_key] = cond |
| batch = self.cond_preprocessing(batch, device=self.device) |
        sampler_kwargs = {} if num_steps is None else {"num_steps": num_steps}
        output = sampler(
            self.ema_model,
            batch,
            conditioning_keys=self.cfg.cond_preprocessing.output_key,
            scheduler=self.inference_noise_scheduler,
            cfg_rate=cfg,
            generator=generator,
            return_trajectories=return_trajectories,
            **sampler_kwargs,
        )
| if return_trajectories: |
| return ( |
| self.postprocessing(output[0]) if postprocessing else output[0], |
| [ |
| self.postprocessing(frame) if postprocessing else frame |
| for frame in output[1] |
| ], |
| ) |
| else: |
| return self.postprocessing(output) if postprocessing else output |
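
    # Usage sketch (hypothetical shapes): given conditioning embeddings `emb`
    # of shape (B, d_cond), `self.sample(batch_size=B, cond=emb, stage="test")`
    # draws the base noise internally and returns postprocessed coordinates;
    # `return_trajectories=True` additionally returns the intermediate frames.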
|
|
| def sample_distribution( |
| self, |
| x_N, |
| cond, |
| sampling_batch_size=2048, |
| num_steps=None, |
| stage="test", |
| cfg=0, |
| generator=None, |
| return_trajectories=False, |
| ): |
        if return_trajectories:
            x_0 = []
            trajectories = []
            i = -1
            for i in range(x_N.shape[0] // sampling_batch_size):
                x_N_batch = x_N[i * sampling_batch_size : (i + 1) * sampling_batch_size]
                cond_batch = cond[
                    i * sampling_batch_size : (i + 1) * sampling_batch_size
                ]
                out, traj = self.sample(
                    cond=cond_batch,
                    x_N=x_N_batch,
                    num_steps=num_steps,
                    stage=stage,
                    cfg=cfg,
                    generator=generator,
                    return_trajectories=return_trajectories,
                )
                x_0.append(out)
                trajectories.append(traj)
            if x_N.shape[0] % sampling_batch_size != 0:
                x_N_batch = x_N[(i + 1) * sampling_batch_size :]
                cond_batch = cond[(i + 1) * sampling_batch_size :]
                out, traj = self.sample(
                    cond=cond_batch,
                    x_N=x_N_batch,
                    num_steps=num_steps,
                    stage=stage,
                    cfg=cfg,
                    generator=generator,
                    return_trajectories=return_trajectories,
                )
                x_0.append(out)
                trajectories.append(traj)
            # Concatenate the per-chunk outputs along the batch dimension,
            # frame by frame for the trajectories.
            x_0 = torch.cat(x_0, dim=0)
            trajectories = [
                torch.cat(frames, dim=0) for frames in zip(*trajectories)
            ]
            return x_0, trajectories
| else: |
| x_0 = [] |
| i = -1 |
| for i in range(x_N.shape[0] // sampling_batch_size): |
| x_N_batch = x_N[i * sampling_batch_size : (i + 1) * sampling_batch_size] |
| cond_batch = cond[ |
| i * sampling_batch_size : (i + 1) * sampling_batch_size |
| ] |
| out = self.sample( |
| cond=cond_batch, |
| x_N=x_N_batch, |
| num_steps=num_steps, |
| stage=stage, |
| cfg=cfg, |
| generator=generator, |
| return_trajectories=return_trajectories, |
| ) |
| x_0.append(out) |
| if x_N.shape[0] % sampling_batch_size != 0: |
| x_N_batch = x_N[(i + 1) * sampling_batch_size :] |
| cond_batch = cond[(i + 1) * sampling_batch_size :] |
| out = self.sample( |
| cond=cond_batch, |
| x_N=x_N_batch, |
| num_steps=num_steps, |
| stage=stage, |
| cfg=cfg, |
| generator=generator, |
| return_trajectories=return_trajectories, |
| ) |
| x_0.append(out) |
| x_0 = torch.cat(x_0, dim=0) |
| return x_0 |
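
    # Chunking sketch: with x_N of 10_000 rows and sampling_batch_size=2048,
    # the loop above runs 4 full chunks (8192 rows) plus one remainder chunk
    # of 1808 rows, and the outputs are concatenated along the batch dimension.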
|
|
| def model(self, *args, **kwargs): |
| return self.preconditioning(self.network, *args, **kwargs) |
|
|
| def ema_model(self, *args, **kwargs): |
| return self.preconditioning(self.ema_network, *args, **kwargs) |
|
|
| def compute_exact_loglikelihood( |
| self, |
| batch=None, |
| x_1=None, |
| cond=None, |
| t1=1.0, |
| num_steps=1000, |
| rademacher=False, |
| data_preprocessing=True, |
| cfg=0, |
| ): |
| nfe = [0] |
| if batch is None: |
| batch = {"x_0": x_1, "emb": cond} |
| if data_preprocessing: |
| batch = self.data_preprocessing(batch) |
| batch = self.cond_preprocessing(batch) |
| with torch.inference_mode(mode=False): |
|
|
            def odefunc(t, tensor):
                # Augmented state: the first `input_dim` channels hold x, the
                # last channel accumulates the log-determinant of the Jacobian.
                nfe[0] += 1
                t = t.to(tensor)
                gamma = self.inference_noise_scheduler(t)
                x = tensor[..., : self.input_dim]
                y = batch["emb"]
|
|
                def vecfield(x, y):
                    # Classifier-free guidance: extrapolate from the
                    # unconditional toward the conditional prediction
                    # (cfg=1 gives 2 * cond - uncond).
                    if cfg > 0:
| batch_vecfield = { |
| "y": x, |
| "emb": y, |
| "gamma": gamma.reshape(-1), |
| } |
| model_output_cond = self.ema_model(batch_vecfield) |
| batch_vecfield_uncond = { |
| "y": x, |
| "emb": torch.zeros_like(y), |
| "gamma": gamma.reshape(-1), |
| } |
| model_output_uncond = self.ema_model(batch_vecfield_uncond) |
| model_output = model_output_cond + cfg * ( |
| model_output_cond - model_output_uncond |
| ) |
|
|
| else: |
| batch_vecfield = { |
| "y": x, |
| "emb": y, |
| "gamma": gamma.reshape(-1), |
| } |
| model_output = self.ema_model(batch_vecfield) |
|
|
                    if self.interpolant == "flow_matching":
                        # Chain rule through the noise schedule turns the
                        # model output into the ODE velocity.
                        d_gamma = self.inference_noise_scheduler.derivative(
                            t
                        ).reshape(-1, 1)
                        return d_gamma * model_output
                    elif self.interpolant == "diffusion":
                        # Probability-flow drift of the diffusion interpolant.
                        alpha_t = self.inference_noise_scheduler.alpha(t).reshape(
                            -1, 1
                        )
                        return -1 / 2 * (
                            alpha_t * x - torch.abs(alpha_t) * model_output
                        )
                    else:
                        raise ValueError(f"Unknown interpolant {self.interpolant}")
|
|
                if rademacher:
                    # Rademacher probe for the Hutchinson divergence estimator.
                    v = torch.randint_like(x, 2) * 2 - 1
                else:
                    # v=None falls back to the exact Jacobian-trace divergence.
                    v = None
                dx, div = output_and_div(vecfield, x, y, v=v)
                div = div.reshape(-1, 1)
                del t, x
                return torch.cat([dx, div], dim=-1)
|
|
            # Augment the data point with a zero-initialized log-det channel.
            x_1 = batch["x_0"]
            state1 = torch.cat([x_1, torch.zeros_like(x_1[..., :1])], dim=-1)
| with torch.no_grad(): |
                # NOTE: the Riemannian flow sampler branch is deliberately
                # disabled ("if False and ..."); the generic ODE solver below
                # is used even on the sphere.
                if False and isinstance(self.manifold, Sphere):
| print("Riemannian flow sampler") |
| product_man = ProductManifold( |
| (self.manifold, self.input_dim), (Euclidean(), 1) |
| ) |
| state0 = ode_riemannian_flow_sampler( |
| odefunc, |
| state1, |
| manifold=product_man, |
| scheduler=self.inference_noise_scheduler, |
| num_steps=num_steps, |
| ) |
| else: |
| print("ODE solver") |
| state0 = odeint( |
| odefunc, |
| state1, |
| t=torch.linspace(0, t1, 2).to(batch["x_0"]), |
| atol=1e-6, |
| rtol=1e-6, |
| method="dopri5", |
| options={"min_step": 1e-5}, |
| )[-1] |
| x_0, logdetjac = state0[..., : self.input_dim], state0[..., -1] |
| if self.manifold is not None: |
| x_0 = self.manifold.projx(x_0) |
| logp0 = self.manifold.base_logprob(x_0) |
| else: |
| logp0 = ( |
| -1 / 2 * (x_0**2).sum(dim=-1) |
| - self.input_dim |
| * torch.log(torch.tensor(2 * np.pi, device=x_0.device)) |
| / 2 |
| ) |
| print(f"nfe: {nfe[0]}") |
            logp1 = logp0 + logdetjac
            # Convert from nats to bits per dimension.
            logp1 = logp1 / (self.input_dim * np.log(2))
            return logp1
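

# Hedged sanity sketch (not part of the pipeline): the augmented ODE above
# integrates d[x, logdet]/dt = [v(x), div v(x)]. For a hypothetical linear
# field v(x) = a * x with div v = a, integrating from t=0 to t=1 leaves the
# logdet channel at ~a, which the same `odeint` call reproduces.
def _demo_logdet_ode(a=0.5):
    def augmented(t, state):
        x = state[..., :1]
        return torch.cat([a * x, torch.full_like(x, a)], dim=-1)

    state_init = torch.tensor([[1.0, 0.0]])  # x = 1, logdet = 0
    state_final = odeint(
        augmented, state_init, t=torch.linspace(0, 1, 2), method="dopri5"
    )[-1]
    return state_final  # ≈ [exp(a), a]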
|
|
|
|
| def get_parameter_names(model, forbidden_layer_types): |
| """ |
| Returns the names of the model parameters that are not inside a forbidden layer. |
| Taken from HuggingFace transformers. |
| """ |
| result = [] |
| for name, child in model.named_children(): |
| result += [ |
| f"{name}.{n}" |
| for n in get_parameter_names(child, forbidden_layer_types) |
| if not isinstance(child, tuple(forbidden_layer_types)) |
| ] |
| |
| result += list(model._parameters.keys()) |
| return result |
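

# Minimal usage sketch: on a toy module, parameters inside a forbidden layer
# type are filtered out, so only the Linear parameters survive here.
def _demo_get_parameter_names():
    toy = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
    return get_parameter_names(toy, [nn.LayerNorm])  # ['0.weight', '0.bias']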
|
|
|
|
| |
def div_fn(u):
    """Exact divergence of a vector field u: R^D -> R^D, computed as the trace
    of its Jacobian via `jacrev`."""
    J = jacrev(u, argnums=0)
    return lambda x, y: torch.trace(J(x, y).squeeze(0))
|
|
|
|
def output_and_div(vecfield, x, y, v=None):
    if v is None:
        # Exact divergence via the Jacobian trace (one Jacobian per sample).
        dx = vecfield(x, y)
        div = vmap(div_fn(vecfield))(x, y)
    else:
        # Hutchinson estimator: div ≈ v^T J v for a random probe v, obtained
        # from a single vector-Jacobian product.
        vecfield_x = lambda x: vecfield(x, y)
        dx, vjpfunc = vjp(vecfield_x, x)
        vJ = vjpfunc(v)[0]
        div = torch.sum(vJ * v, dim=-1)
    return dx, div
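

# Hedged sketch (illustrative only): for a hypothetical linear field
# u(x) = x @ A.T the exact divergence is trace(A); the Hutchinson estimate
# v^T J v with a Rademacher probe matches it in expectation (one probe is
# noisy), which is what `output_and_div` computes when v is supplied.
def _demo_divergence():
    torch.manual_seed(0)
    A = torch.randn(3, 3)
    u = lambda x, y: x @ A.T  # toy field; ignores the conditioning y
    x = torch.randn(1, 3)
    v = torch.randint_like(x, 2) * 2 - 1  # Rademacher probe in {-1, +1}
    _, div_est = output_and_div(u, x, y=None, v=v)
    return div_est, torch.trace(A)  # single-probe estimate vs exact value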
|
|
|
|
| class VonFisherGeolocalizer(L.LightningModule): |
| def __init__(self, cfg): |
| super().__init__() |
| self.cfg = cfg |
| self.network = instantiate(cfg.network) |
| |
| self.input_dim = cfg.network.input_dim |
| self.data_preprocessing = instantiate(cfg.data_preprocessing) |
| self.cond_preprocessing = instantiate(cfg.cond_preprocessing) |
| self.preconditioning = instantiate(cfg.preconditioning) |
|
|
| self.ema_network = copy.deepcopy(self.network).requires_grad_(False) |
| self.ema_network.eval() |
| self.postprocessing = instantiate(cfg.postprocessing) |
| self.val_sampler = instantiate(cfg.val_sampler) |
| self.test_sampler = instantiate(cfg.test_sampler) |
| self.loss = instantiate(cfg.loss)() |
| self.val_metrics = instantiate(cfg.val_metrics) |
| self.test_metrics = instantiate(cfg.test_metrics) |
|
|
| def training_step(self, batch, batch_idx): |
| with torch.no_grad(): |
| batch = self.data_preprocessing(batch) |
| batch = self.cond_preprocessing(batch) |
| batch_size = batch["x_0"].shape[0] |
| loss = self.loss(self.preconditioning, self.network, batch).mean() |
| self.log( |
| "train/loss", |
| loss, |
| sync_dist=True, |
| on_step=True, |
| on_epoch=True, |
| batch_size=batch_size, |
| ) |
| return loss |
|
|
| def on_before_optimizer_step(self, optimizer): |
| if self.global_step == 0: |
| no_grad = [] |
| for name, param in self.network.named_parameters(): |
| if param.grad is None: |
| no_grad.append(name) |
| if len(no_grad) > 0: |
| print("Parameters without grad:") |
| print(no_grad) |
|
|
| def on_validation_start(self): |
| self.validation_generator = torch.Generator(device=self.device).manual_seed( |
| 3407 |
| ) |
| self.validation_generator_ema = torch.Generator(device=self.device).manual_seed( |
| 3407 |
| ) |
|
|
| def validation_step(self, batch, batch_idx): |
| batch = self.data_preprocessing(batch) |
| batch = self.cond_preprocessing(batch) |
| batch_size = batch["x_0"].shape[0] |
| loss = self.loss( |
| self.preconditioning, |
| self.network, |
| batch, |
| generator=self.validation_generator, |
| ).mean() |
| self.log( |
| "val/loss", |
| loss, |
| sync_dist=True, |
| on_step=False, |
| on_epoch=True, |
| batch_size=batch_size, |
| ) |
        # NOTE: `ema_model` is a method defined below, so this check is always
        # true; the loss with the EMA weights is logged on every validation step.
        if hasattr(self, "ema_model"):
| loss_ema = self.loss( |
| self.preconditioning, |
| self.ema_network, |
| batch, |
| generator=self.validation_generator_ema, |
| ).mean() |
| self.log( |
| "val/loss_ema", |
| loss_ema, |
| sync_dist=True, |
| on_step=False, |
| on_epoch=True, |
| batch_size=batch_size, |
| ) |
|
|
| def on_test_start(self): |
| self.test_generator = torch.Generator(device=self.device).manual_seed(3407) |
|
|
| def test_step(self, batch, batch_idx): |
| batch = self.data_preprocessing(batch) |
| batch = self.cond_preprocessing(batch) |
| batch_size = batch["x_0"].shape[0] |
| cond = batch[self.cfg.cond_preprocessing.output_key] |
|
|
| samples = self.sample(cond=cond, stage="test") |
| self.test_metrics.update({"gps": samples}, batch) |
| nll = -self.compute_exact_loglikelihood(batch).mean() |
| self.log( |
| "test/NLL", |
| nll, |
| sync_dist=True, |
| on_step=False, |
| on_epoch=True, |
| batch_size=batch_size, |
| ) |
|
|
| def on_test_epoch_end(self): |
| metrics = self.test_metrics.compute() |
| for metric_name, metric_value in metrics.items(): |
| self.log( |
| f"test/{metric_name}", |
| metric_value, |
| sync_dist=True, |
| on_step=False, |
| on_epoch=True, |
| ) |
|
|
| def configure_optimizers(self): |
| if self.cfg.optimizer.exclude_ln_and_biases_from_weight_decay: |
| parameters_names_wd = get_parameter_names(self.network, [nn.LayerNorm]) |
| parameters_names_wd = [ |
| name for name in parameters_names_wd if "bias" not in name |
| ] |
| optimizer_grouped_parameters = [ |
| { |
| "params": [ |
| p |
| for n, p in self.network.named_parameters() |
| if n in parameters_names_wd |
| ], |
| "weight_decay": self.cfg.optimizer.optim.weight_decay, |
| "layer_adaptation": True, |
| }, |
| { |
| "params": [ |
| p |
| for n, p in self.network.named_parameters() |
| if n not in parameters_names_wd |
| ], |
| "weight_decay": 0.0, |
| "layer_adaptation": False, |
| }, |
| ] |
| optimizer = instantiate( |
| self.cfg.optimizer.optim, optimizer_grouped_parameters |
| ) |
| else: |
| optimizer = instantiate(self.cfg.optimizer.optim, self.network.parameters()) |
| if "lr_scheduler" in self.cfg: |
| scheduler = instantiate(self.cfg.lr_scheduler)(optimizer) |
| return [optimizer], [{"scheduler": scheduler, "interval": "step"}] |
| else: |
| return optimizer |
|
|
| def lr_scheduler_step(self, scheduler, metric): |
| scheduler.step(self.global_step) |
|
|
| def sample( |
| self, |
| batch_size=None, |
| cond=None, |
| postprocessing=True, |
| stage="val", |
| ): |
| batch = {} |
| if stage == "val": |
| sampler = self.val_sampler |
| elif stage == "test": |
| sampler = self.test_sampler |
| else: |
| raise ValueError(f"Unknown stage {stage}") |
| batch[self.cfg.cond_preprocessing.input_key] = cond |
| batch = self.cond_preprocessing(batch, device=self.device) |
| output = sampler( |
| self.ema_model, |
| batch, |
| ) |
| return self.postprocessing(output) if postprocessing else output |
|
|
| def model(self, *args, **kwargs): |
| return self.preconditioning(self.network, *args, **kwargs) |
|
|
| def ema_model(self, *args, **kwargs): |
| return self.preconditioning(self.ema_network, *args, **kwargs) |
|
|
    def compute_exact_loglikelihood(
        self,
        batch=None,
    ):
        # The von Mises-Fisher loss is already a pointwise negative
        # log-likelihood, so the exact log-likelihood is its negation.
        batch = self.data_preprocessing(batch)
        batch = self.cond_preprocessing(batch)
        return -self.loss(self.preconditioning, self.ema_network, batch)
|
|
|
|
| class RandomGeolocalizer(L.LightningModule): |
| def __init__(self, cfg): |
| super().__init__() |
| self.cfg = cfg |
| self.test_metrics = instantiate(cfg.test_metrics) |
| self.data_preprocessing = instantiate(cfg.data_preprocessing) |
| self.cond_preprocessing = instantiate(cfg.cond_preprocessing) |
| self.postprocessing = instantiate(cfg.postprocessing) |
|
|
| def test_step(self, batch, batch_idx): |
| batch = self.data_preprocessing(batch) |
| batch = self.cond_preprocessing(batch) |
| batch_size = batch["x_0"].shape[0] |
        # Normalizing an isotropic Gaussian yields a uniform sample on the
        # sphere; this is the chance-level baseline.
        samples = torch.randn(batch_size, 3, device=self.device)
        samples = samples / samples.norm(dim=-1, keepdim=True)
        samples = self.postprocessing(samples)
| self.test_metrics.update({"gps": samples}, batch) |
|
|
| def on_test_epoch_end(self): |
| metrics = self.test_metrics.compute() |
| for metric_name, metric_value in metrics.items(): |
| self.log( |
| f"test/{metric_name}", |
| metric_value, |
| sync_dist=True, |
| on_step=False, |
| on_epoch=True, |
| ) |
|
|