import torch
import numpy as np
def eps_from_v(z_0, z_t, sigma_t):
    # Invert z_t = z_0 + sigma_t * eps to recover the noise eps.
    return (z_t - z_0) / sigma_t


def v_to_eps(v, t, x_t):
    """
    Compute the epsilon parametrization from the velocity field,
    assuming x_t = t * x_0 + (1 - t) * x_1 with x_0 ~ N(0, I),
    so that v = x_0 - x_1 and eps = x_t + (1 - t) * v.
    """
    eps_t = (1 - t) * v + x_t
    return eps_t
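
# Hypothetical sanity check (added for illustration, not part of the original
# module): under x_t = t * x_0 + (1 - t) * x_1 with x_0 ~ N(0, I), the
# velocity is v = x_0 - x_1, so v_to_eps should recover x_0 exactly.
def _demo_v_to_eps():
    t = 0.3
    x_0 = torch.randn(4, 8)        # noise sample (the "eps" we want back)
    x_1 = torch.randn(4, 8)        # data sample
    x_t = t * x_0 + (1 - t) * x_1  # linear interpolation path
    v = x_0 - x_1                  # velocity of that path
    assert torch.allclose(v_to_eps(v, t, x_t), x_0, atol=1e-5)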
def clip_gradients(gradients, clip_value):
    # Per-position norm clipping, assuming a [batch, seq, dim] layout:
    # positions whose norm along the last dimension exceeds clip_value
    # are rescaled (in place) to have norm exactly clip_value.
    grad_norm = gradients.norm(dim=2)
    mask = grad_norm > clip_value
    mask_exp = mask[:, :, None].expand_as(gradients)
    gradients[mask_exp] = (
        gradients[mask_exp]
        / grad_norm[:, :, None].expand_as(gradients)[mask_exp]
        * clip_value
    )
    return gradients
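
# Hypothetical usage sketch (not in the original): clip per-position gradient
# norms of a [batch, seq, dim] tensor to at most 1.0.
def _demo_clip_gradients():
    grads = torch.randn(2, 16, 64) * 5.0  # deliberately large gradients
    clipped = clip_gradients(grads, 1.0)
    assert clipped.norm(dim=2).max() <= 1.0 + 1e-4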
class Adam:
    """Minimal functional Adam for a single parameter tensor; `step` returns
    the updated parameters instead of mutating them in place."""

    def __init__(self, parameters, lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.t = 0
        # `parameters` is only used here to shape the moment buffers.
        self.m = torch.zeros_like(parameters)
        self.v = torch.zeros_like(parameters)

    def step(self, params, grad) -> torch.Tensor:
        self.t += 1
        # Update biased first and second moment estimates.
        self.m = self.beta1 * self.m + (1 - self.beta1) * grad
        self.v = self.beta2 * self.v + (1 - self.beta2) * grad**2
        # Bias correction.
        m_hat = self.m / (1 - self.beta1**self.t)
        v_hat = self.v / (1 - self.beta2**self.t)
        # lr may be a schedule: a callable mapping step index -> learning rate.
        if callable(self.lr):
            lr = self.lr(self.t - 1)
        else:
            lr = self.lr
        update = lr * m_hat / (torch.sqrt(v_hat) + self.epsilon)
        return params - update
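
# Hypothetical usage sketch (not in the original): minimize ||x - target||^2
# with the functional Adam above; note params are updated out of place, and
# the final error tolerance below is a loose empirical assumption.
def _demo_adam():
    target = torch.ones(8)
    params = torch.zeros(8)
    opt = Adam(params, lr=0.1)
    for _ in range(500):
        grad = 2 * (params - target)  # analytic gradient of the quadratic
        params = opt.step(params, grad)
    assert (params - target).abs().max() < 0.1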
def make_cosine_decay_schedule(
    init_value: float,
    total_steps: int,
    alpha: float = 0.0,
    exponent: float = 1.0,
    warmup_steps=0,
):
    def schedule(count):
        if count < warmup_steps:
            # linear up
            return (init_value / warmup_steps) * count
        else:
            # half cosine down
            decay_steps = total_steps - warmup_steps
            count = min(count - warmup_steps, decay_steps)
            cosine_decay = 0.5 * (1 + np.cos(np.pi * count / decay_steps))
            decayed = (1 - alpha) * cosine_decay**exponent + alpha
            return init_value * decayed

    return schedule
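
# Hypothetical usage sketch (not in the original): endpoint checks for the
# cosine schedule, and passing it as a callable lr to the Adam class above.
def _demo_cosine_schedule():
    schedule = make_cosine_decay_schedule(
        init_value=1e-3, total_steps=1000, warmup_steps=100
    )
    assert schedule(0) == 0.0                 # warmup starts at zero
    assert schedule(100) == 1e-3              # peak right after warmup
    assert abs(schedule(1000)) < 1e-12        # fully decayed (alpha=0)
    opt = Adam(torch.zeros(8), lr=schedule)   # lr is looked up each step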
def make_linear_decay_schedule(
    init_value: float, total_steps: int, final_value: float = 0, warmup_steps=0
):
    def schedule(count):
        if count < warmup_steps:
            # linear up
            return (init_value / warmup_steps) * count
        else:
            # linear down
            decay_steps = total_steps - warmup_steps
            count = min(count - warmup_steps, decay_steps)
            return init_value - (init_value - final_value) * count / decay_steps

    return schedule
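
# Hypothetical endpoint check (not in the original): the schedule ramps up
# over warmup_steps, then decays linearly and clamps at final_value.
def _demo_linear_schedule():
    schedule = make_linear_decay_schedule(
        init_value=1e-3, total_steps=1000, final_value=1e-5, warmup_steps=100
    )
    assert schedule(100) == 1e-3                  # peak right after warmup
    assert abs(schedule(1000) - 1e-5) < 1e-12     # clamped at final_value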
def clip_norm_(tensor, max_norm):
    # In-place global norm clipping (trailing underscore: mutates its input).
    norm = tensor.norm()
    if norm > max_norm:
        tensor.mul_(max_norm / norm)


def lr_warmup(step, warmup_steps):
    # Linear warmup multiplier in [0, 1]; safe for warmup_steps == 0.
    return min(1.0, step / max(warmup_steps, 1))
def linear_decay_lambda(step, warmup_steps, decay_steps, total_steps):
    # Multiplier for LambdaLR-style schedulers: linear warmup, then linear
    # decay to zero over decay_steps. total_steps is unused but kept for
    # signature compatibility with the schedule factories above.
    if step < warmup_steps:
        # linear up
        return min(1.0, step / max(warmup_steps, 1))
    else:
        # linear down
        count = min(step - warmup_steps, decay_steps)
        return 1 - count / decay_steps
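
# Hypothetical usage sketch (not in the original): linear_decay_lambda returns
# a multiplier on the optimizer's base lr, so it plugs directly into PyTorch's
# torch.optim.lr_scheduler.LambdaLR.
def _demo_lambda_lr():
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: linear_decay_lambda(
            step, warmup_steps=100, decay_steps=900, total_steps=1000
        ),
    )
    for _ in range(5):
        optimizer.step()   # (gradient computation omitted in this sketch)
        scheduler.step()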