Spaces:
Sleeping
Sleeping
| from abc import ABC, abstractmethod | |
| import torch | |
class BaseRewardLoss(ABC):
    """
    Base class for reward functions implementing a differentiable reward
    function for optimization.

    Subclasses implement the three hooks ``get_image_features``,
    ``get_text_features`` and ``compute_loss``. Calling an instance runs
    image/prompt -> features -> L2-normalised features -> loss.
    The hooks are declared abstract, so only subclasses that implement all
    three can be instantiated.

    Attributes:
        name: Identifier given to this reward loss.
        weighting: Scalar stored for callers. It is not applied anywhere in
            this class. NOTE(review): presumably used to weight this loss
            against others when several are combined -- confirm at call site.
    """

    def __init__(self, name: str, weighting: float):
        self.name = name
        self.weighting = weighting

    @staticmethod
    def freeze_parameters(params: torch.nn.ParameterList):
        """Disable gradient tracking for every parameter in ``params``.

        Declared ``@staticmethod`` because the function takes no ``self``.
        Without the decorator, ``self.freeze_parameters(...)`` would bind the
        instance to ``params`` and raise ``TypeError``. Calling it through the
        class (``BaseRewardLoss.freeze_parameters(...)``) works as before.

        Args:
            params: Any iterable of tensors/parameters; the body only iterates
                it, so e.g. a ``ParameterList`` or ``module.parameters()``.
        """
        for param in params:
            param.requires_grad = False

    @abstractmethod
    def get_image_features(self, image: torch.Tensor) -> torch.Tensor:
        """Encode ``image`` into a feature tensor (implemented by subclasses)."""

    @abstractmethod
    def get_text_features(self, prompt: str) -> torch.Tensor:
        """Encode ``prompt`` into a feature tensor (implemented by subclasses)."""

    @abstractmethod
    def compute_loss(
        self, image_features: torch.Tensor, text_features: torch.Tensor
    ) -> torch.Tensor:
        """Return the loss for the normalised image and text features.

        Implemented by subclasses. ``__call__`` passes in the outputs of
        ``process_features``.
        """

    def process_features(self, features: torch.Tensor) -> torch.Tensor:
        """Scale ``features`` to unit L2 norm along the last dimension.

        Note: a vector whose norm is zero is divided 0/0 and becomes NaN.
        """
        features_normed = features / features.norm(dim=-1, keepdim=True)
        return features_normed

    def __call__(self, image: torch.Tensor, prompt: str) -> torch.Tensor:
        """Compute the loss of ``image`` with respect to ``prompt``.

        Extracts features for both inputs, normalises each with
        ``process_features``, and hands them to ``compute_loss``.
        """
        image_features = self.get_image_features(image)
        text_features = self.get_text_features(prompt)

        image_features_normed = self.process_features(image_features)
        text_features_normed = self.process_features(text_features)

        loss = self.compute_loss(image_features_normed, text_features_normed)
        return loss