Spaces:
Paused
Paused
| from dataclasses import dataclass, field | |
| from typing import Any, List, Optional | |
| import alpha_clip | |
| import torch | |
| import torch.nn as nn | |
| from jaxtyping import Float | |
| from torch import Tensor | |
| from torchvision.transforms import Normalize | |
| from spar3d.models.network import get_activation | |
| from spar3d.models.utils import BaseModule | |
# Per-channel mean / std handed to torchvision's Normalize for the three
# colour channels of the conditioning image in ClipBasedHeadEstimator.forward.
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
@dataclass
class HeadSpec:
    """Configuration of a single prediction head of ClipBasedHeadEstimator.

    Declared as a dataclass so that a spec can be built with keyword
    arguments (``HeadSpec(name=..., out_channels=..., n_hidden_layers=...)``);
    without the decorator the annotated fields would produce no ``__init__``.
    """

    # Key under which the head's modules and outputs are stored.
    name: str
    # NOTE(review): not read by ClipBasedHeadEstimator in this file, whose
    # heads always emit one scalar per distribution parameter - confirm usage.
    out_channels: int
    # Number of Linear + activation pairs in the head's shared trunk.
    n_hidden_layers: int
    # Name passed to get_activation() and applied to the point estimate.
    output_activation: Optional[str] = None
    # Constant added to the raw network outputs before they parameterize
    # the distribution.
    output_bias: float = 0.0
    # When True the head's output is stored under "decoder_<name>" instead
    # of "<name>".
    add_to_decoder_features: bool = False
    # Optional target shape the sampled output is reshaped to.
    shape: Optional[list[int]] = None
    # One of "mean", "mode", "sample_mean"; any other value draws a sample.
    distribution_eval: str = "sample"
class ClipBasedHeadEstimator(BaseModule):
    """Predict one scalar distribution per configured head from an RGBA image.

    A frozen Alpha-CLIP visual encoder embeds the conditioning image together
    with its alpha mask. Each head runs a shared MLP trunk on that embedding
    followed by two small MLPs that each emit a single scalar; the two scalars
    parameterize a Normal or Beta distribution (``cfg.distribution``).
    """

    @dataclass
    class Config(BaseModule.Config):
        # Model name handed to alpha_clip.load().
        model: str = "ViT-L/14@336px"

        # Distribution family built in forward(): "normal" or "beta".
        distribution: str = "beta"

        # ["mean", "mode", "sample", "sample_mean"]
        # NOTE(review): forward() reads HeadSpec.distribution_eval, not this
        # field - confirm whether this global setting is still consumed.
        distribution_eval: str = "mode"

        # Activation used inside the head MLPs: "relu" or "silu".
        activation: str = "relu"
        # Width of every hidden Linear layer in the heads.
        hidden_features: int = 512
        heads: List[HeadSpec] = field(default_factory=list)

    cfg: Config

    def configure(self):
        """Load the frozen backbone and build the per-head MLP stacks.

        Raises:
            AssertionError: if ``cfg.heads`` is empty.
            NotImplementedError: if ``cfg.activation`` is not supported.
        """
        self.model, _ = alpha_clip.load(
            self.cfg.model,
        )  # change to your own ckpt path
        self.model.eval()

        if not hasattr(self.model.visual, "input_resolution"):
            self.img_size = 224
        else:
            self.img_size = self.model.visual.input_resolution
            # Check if img_size is subscribable and pick the first element
            if hasattr(self.img_size, "__getitem__"):
                self.img_size = self.img_size[0]

        # Do not add the weights in self.model to the optimizer
        for param in self.model.parameters():
            param.requires_grad = False

        assert len(self.cfg.heads) > 0, "at least one head must be configured"

        hidden = self.cfg.hidden_features
        heads = {}
        for head in self.cfg.heads:
            # Track the running feature width so that a head with zero hidden
            # layers still gets distribution heads matching the backbone
            # output width instead of assuming `hidden` inputs.
            in_features = self.model.visual.output_dim
            trunk_layers = []
            for _ in range(head.n_hidden_layers):
                trunk_layers += [
                    nn.Linear(in_features, hidden),
                    self.make_activation(self.cfg.activation),
                ]
                in_features = hidden

            # Index 0: shared trunk; indices 1 and 2: one scalar-output MLP
            # per distribution parameter (unpacked in that order in forward).
            head_layers = [nn.Sequential(*trunk_layers)]
            head_layers += [
                nn.Sequential(
                    nn.Linear(in_features, hidden),
                    self.make_activation(self.cfg.activation),
                    nn.Linear(hidden, 1),
                )
                for _ in range(2)
            ]
            heads[head.name] = nn.ModuleList(head_layers)
        self.heads = nn.ModuleDict(heads)

    def make_activation(self, activation):
        """Return a fresh in-place activation module for ``activation``.

        Raises:
            NotImplementedError: for any name other than "relu" or "silu".
        """
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError

    def forward(
        self,
        cond_image: Float[Tensor, "B 1 H W 4"],
        sample: bool = True,
    ) -> dict[str, Any]:
        """Estimate every configured head for a batch of conditioning images.

        Args:
            cond_image: channels-last image whose fourth channel is used as
                the alpha mask; the two leading dimensions are flattened.
            sample: when True, each ``outputs[name]`` is a point estimate
                chosen by the head's ``distribution_eval`` (passed through
                its output activation) and the distribution itself is stored
                under ``"<name>_dist"``. When False, ``outputs[name]`` is the
                distribution object.

        Returns:
            Dict keyed by head name; heads with ``add_to_decoder_features``
            are stored under ``"decoder_<name>"`` instead.

        Raises:
            NotImplementedError: if ``cfg.distribution`` is unsupported.
            ValueError: if a head requests a reshape while ``sample`` is False.
        """
        # Run the model
        # Resize cond_image to the backbone's input resolution
        cond_image = cond_image.flatten(0, 1)
        cond_image = nn.functional.interpolate(
            cond_image.permute(0, 3, 1, 2),
            size=(self.img_size, self.img_size),
            mode="bilinear",
            align_corners=False,
        )
        mask = cond_image[:, 3:4]
        # Multiply the colour channels by the alpha channel before normalizing.
        cond_image = cond_image[:, :3] * mask
        cond_image = Normalize(
            mean=OPENAI_DATASET_MEAN,
            std=OPENAI_DATASET_STD,
        )(cond_image)
        mask = Normalize(0.5, 0.26)(mask).half()
        image_features = self.model.visual(cond_image.half(), mask).float()

        # Run the heads
        outputs = {}

        for head_dict in self.cfg.heads:
            head_name = head_dict.name
            shared_head, d1_h, d2_h = self.heads[head_name]
            shared_features = shared_head(image_features)
            # Each parameter head ends in Linear(..., 1); drop that axis.
            d1, d2 = [head(shared_features).squeeze(-1) for head in [d1_h, d2_h]]
            if self.cfg.distribution == "normal":
                mean = d1
                var = d2
                # NOTE(review): after the squeeze the last axis of `mean` is
                # presumably the batch axis, and softplus(var) is used as a
                # scale for Normal but as a covariance diagonal for
                # MultivariateNormal - confirm this is intended.
                if mean.shape[-1] == 1:
                    outputs[head_name] = torch.distributions.Normal(
                        mean + head_dict.output_bias,
                        torch.nn.functional.softplus(var),
                    )
                else:
                    outputs[head_name] = torch.distributions.MultivariateNormal(
                        mean + head_dict.output_bias,
                        torch.nn.functional.softplus(var).diag_embed(),
                    )
            elif self.cfg.distribution == "beta":
                outputs[head_name] = torch.distributions.Beta(
                    torch.nn.functional.softplus(d1 + head_dict.output_bias),
                    torch.nn.functional.softplus(d2 + head_dict.output_bias),
                )
            else:
                raise NotImplementedError

        if sample:
            for head_dict in self.cfg.heads:
                head_name = head_dict.name
                dist = outputs[head_name]

                if head_dict.distribution_eval == "mean":
                    out = dist.mean
                elif head_dict.distribution_eval == "mode":
                    out = dist.mode
                elif head_dict.distribution_eval == "sample_mean":
                    # sample([10]) prepends the sample axis, so the draws are
                    # averaged over dim 0; averaging over -1 would collapse
                    # the data axis and leave the 10 draws un-averaged.
                    out = dist.sample([10]).mean(0)
                else:
                    # use rsample if gradient is needed
                    out = dist.rsample() if self.training else dist.sample()

                outputs[head_name] = get_activation(head_dict.output_activation)(out)
                outputs[f"{head_name}_dist"] = dist

        for head in self.cfg.heads:
            if head.shape:
                if not sample:
                    raise ValueError(
                        "Cannot reshape non-sampled probabilisitic outputs"
                    )
                outputs[head.name] = outputs[head.name].reshape(*head.shape)

            if head.add_to_decoder_features:
                # Move the entry so it is only exposed under the decoder key.
                outputs[f"decoder_{head.name}"] = outputs[head.name]
                del outputs[head.name]

        return outputs