| | |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch import Tensor, einsum |
| | from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union |
| | from einops import rearrange |
| | import math |
| | import comfy.ops |
| |
|
class LearnedPositionalEmbedding(nn.Module):
    """Learned Fourier-feature embedding for a batch of continuous scalars.

    Each scalar ``t`` is mapped to ``[t, sin(2*pi*w*t), cos(2*pi*w*t)]``, where
    ``w`` is a learned frequency vector of length ``dim // 2``. The output
    therefore carries ``dim + 1`` features per input element.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Sine and cosine halves must be the same size.
        assert dim % 2 == 0
        # torch.empty leaves the values uninitialised.
        # NOTE(review): presumably overwritten by a loaded state dict — confirm.
        self.weights = nn.Parameter(torch.empty(dim // 2))

    def forward(self, x: Tensor) -> Tensor:
        # x is expected to be 1-D (batch,), per the "b -> b 1" pattern.
        column = rearrange(x, "b -> b 1")
        freq_row = rearrange(self.weights, "d -> 1 d")
        phase = column * freq_row * 2 * math.pi
        # Raw input first, then the sine block, then the cosine block.
        return torch.cat((column, phase.sin(), phase.cos()), dim=-1)
| |
|
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
    """Build a learned Fourier embedding followed by a linear projection.

    Args:
        dim: Even number of Fourier features (``dim // 2`` sine and
            ``dim // 2`` cosine terms).
        out_features: Width of the projected output.

    Returns:
        ``nn.Sequential`` of the Fourier embedding (index 0) and a
        ``comfy.ops.manual_cast.Linear`` projection (index 1).
    """
    fourier = LearnedPositionalEmbedding(dim)
    # The Fourier stage prepends the raw input, hence the extra "+ 1" feature.
    projection = comfy.ops.manual_cast.Linear(in_features=dim + 1, out_features=out_features)
    return nn.Sequential(fourier, projection)
| |
|
| |
|
class NumberEmbedder(nn.Module):
    """Embed every scalar of an arbitrarily shaped input into ``features`` values.

    The input is flattened, each element is embedded with
    ``TimePositionalEmbedding``, and the result is reshaped back to the
    original shape with a trailing ``features`` axis.
    """

    def __init__(
        self,
        features: int,
        dim: int = 256,
    ):
        super().__init__()
        self.features = features
        self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)

    def forward(self, x: Union[List[float], Tensor]) -> Tensor:
        if torch.is_tensor(x):
            values = x
        else:
            # Build plain Python sequences on the same device as the parameters.
            param_device = next(self.embedding.parameters()).device
            values = torch.tensor(x, device=param_device)
        assert isinstance(values, Tensor)

        original_shape = values.shape
        flat = rearrange(values, "... -> (...)")
        embedded = self.embedding(flat)
        return embedded.view(*original_shape, self.features)
| |
|
| |
|
class Conditioner(nn.Module):
    """Base class for conditioners producing ``output_dim``-wide features.

    ``proj_out`` is an ``nn.Linear(dim, output_dim)`` when the widths differ or
    ``project_out`` is truthy, and an ``nn.Identity`` otherwise. Subclasses
    must override ``forward``.
    """

    def __init__(
        self,
        dim: int,
        output_dim: int,
        project_out: bool = False
    ):
        super().__init__()

        self.dim = dim
        self.output_dim = output_dim

        needs_projection = project_out or dim != output_dim
        if needs_projection:
            self.proj_out = nn.Linear(dim, output_dim)
        else:
            self.proj_out = nn.Identity()

    def forward(self, x):
        raise NotImplementedError()
| |
|
class NumberConditioner(Conditioner):
    """Embed a batch of scalar values after normalizing them into [0, 1].

    Each value is clamped to ``[min_val, max_val]``, linearly rescaled to
    ``[0, 1]`` and embedded with a ``NumberEmbedder`` of width ``output_dim``.

    Args:
        output_dim: Size of each embedding vector.
        min_val: Lower bound of the accepted input range.
        max_val: Upper bound of the accepted input range; must be strictly
            greater than ``min_val``.

    Raises:
        ValueError: If ``max_val`` is not strictly greater than ``min_val``.
            Equal bounds would make the normalization divide by zero (NaN
            embeddings), and an inverted range collapses every input to a
            single value.
    """

    def __init__(self,
                 output_dim: int,
                 min_val: float = 0,
                 max_val: float = 1
                 ):
        # Reject degenerate ranges up front instead of emitting NaNs later.
        if not max_val > min_val:
            raise ValueError(
                f"NumberConditioner requires max_val > min_val, got min_val={min_val}, max_val={max_val}"
            )

        super().__init__(output_dim, output_dim)

        self.min_val = min_val
        self.max_val = max_val

        self.embedder = NumberEmbedder(features=output_dim)

    def forward(self, floats, device=None):
        """Embed ``floats``.

        Args:
            floats: Iterable of values convertible with ``float()``.
            device: Target device; defaults to the embedder's parameter device.

        Returns:
            ``[float_embeds, ones]``:
            - ``float_embeds`` has shape ``(n, 1, output_dim)``.
            - ``ones`` is a default-dtype tensor of ones with shape ``(n, 1)``.
              NOTE(review): presumably consumed as a mask by callers — confirm.
        """
        floats = [float(x) for x in floats]

        # Look up the embedder's parameters once, for both device and dtype.
        embedder_param = next(self.embedder.parameters())
        if device is None:
            device = embedder_param.device

        # Allocate directly on the target device rather than CPU-then-copy.
        values = torch.tensor(floats, device=device)
        values = values.clamp(self.min_val, self.max_val)

        # Safe: __init__ guarantees max_val > min_val.
        normalized = (values - self.min_val) / (self.max_val - self.min_val)
        normalized = normalized.to(embedder_param.dtype)

        # NumberEmbedder returns (n, output_dim); add a length-1 sequence axis.
        float_embeds = self.embedder(normalized).unsqueeze(1)

        ones = torch.ones(float_embeds.shape[0], 1, device=device)
        return [float_embeds, ones]
| |
|