import torch
from transformers import CLIPModel

from rewards.base_reward import BaseRewardLoss


class CLIPLoss(BaseRewardLoss):
    """CLIP reward loss function for optimization."""

    def __init__(
        self,
        weighting: float,
        dtype: torch.dtype,
        device: torch.device,
        cache_dir: str,
        tokenizer,
        memsave: bool = False,
    ):
        self.tokenizer = tokenizer
        # Remember the target device so tokenized prompts can be moved to it later.
        self.device = device
        self.clip_model = CLIPModel.from_pretrained(
            "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
            cache_dir=cache_dir,
        )
        if memsave:
            # Optionally swap in memory-saving layers to reduce peak VRAM usage.
            import memsave_torch.nn

            self.clip_model = memsave_torch.nn.convert_to_memory_saving(
                self.clip_model
            )
        self.clip_model = self.clip_model.to(device, dtype=dtype)
        # Freeze all model parameters: CLIP serves only as a fixed reward model.
        self.clip_model.eval()
        self.freeze_parameters(self.clip_model.parameters())
        super().__init__("CLIP", weighting)
        self.clip_model.gradient_checkpointing_enable()

    def get_image_features(self, image: torch.Tensor) -> torch.Tensor:
        clip_img_features = self.clip_model.get_image_features(image)
        return clip_img_features

    def get_text_features(self, prompt: str) -> torch.Tensor:
        # CLIP's text encoder has a 77-token context window, so truncate long prompts.
        prompt_token = self.tokenizer(
            prompt, return_tensors="pt", padding=True, max_length=77, truncation=True
        ).to(self.device)
        clip_text_features = self.clip_model.get_text_features(**prompt_token)
        return clip_text_features

    def compute_loss(
        self, image_features: torch.Tensor, text_features: torch.Tensor
    ) -> torch.Tensor:
        # Mean similarity between image and text features, scaled by CLIP's
        # learned logit scale (features are expected to be L2-normalized, so the
        # dot product is a cosine similarity). Subtracting from 100 turns
        # similarity maximization into loss minimization.
        clip_loss = (
            100
            - (image_features @ text_features.T).mean()
            * self.clip_model.logit_scale.exp()
        )
        return clip_loss
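

# A minimal usage sketch, not part of the original module: it assumes a CUDA
# device is available, that rewards.base_reward.BaseRewardLoss provides
# freeze_parameters and accepts (name, weighting) in its constructor, and that
# features should be L2-normalized before scoring. The dummy image tensor and
# the weighting value are hypothetical.
if __name__ == "__main__":
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained(
        "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
    )
    loss_fn = CLIPLoss(
        weighting=1.0,
        dtype=torch.float32,
        device=torch.device("cuda"),
        cache_dir="./cache",
        tokenizer=tokenizer,
    )
    # Dummy preprocessed image batch in CLIP's expected input shape.
    image = torch.randn(1, 3, 224, 224, device="cuda")
    img_feat = loss_fn.get_image_features(image)
    txt_feat = loss_fn.get_text_features("a photo of a cat")
    # Normalize to unit length so the dot product is a cosine similarity.
    img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
    txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
    print(loss_fn.compute_loss(img_feat, txt_feat).item())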