"""CLIP-style dual-encoder model (image/text contrastive alignment)."""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml

from clip.encoders.image_encoder import ImageEncoder
from clip.encoders.text_encoder import TextEncoder
from helper.tokenizer import Tokenizer
class CLIP(nn.Module):
    """Contrastive image-text dual encoder (CLIP-style).

    Builds an image encoder and a text encoder from a YAML config and learns
    a temperature parameter (``logit_scale``) for the symmetric InfoNCE loss.
    """

    def __init__(self, config_path):
        """Construct encoders from a YAML config.

        Args:
            config_path: Path to a YAML file with ``image_encoder`` and
                ``text_encoder`` sections holding the encoders' constructor
                kwargs.
        """
        super().__init__()
        with open(config_path, "r") as file:
            config = yaml.safe_load(file)
        self.image_encoder = ImageEncoder(**config["image_encoder"])
        self.text_encoder = TextEncoder(**config["text_encoder"])
        self.tokenizer = Tokenizer()
        # Learnable temperature, initialized to log(1/0.07) as in the CLIP paper.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        # Weight initialization for common layer types.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def loss(self, image, text):
        """Symmetric contrastive (InfoNCE) loss over a batch of pairs.

        Args:
            image: Batched image input for the image encoder.
            text: Pre-tokenized text input (``tokenize=False`` is used).

        Returns:
            Scalar tensor: mean of image-to-text and text-to-image
            cross-entropy, with matching pairs on the logits diagonal.
        """
        image_features, text_features = self(image, text, tokenize=False)
        # Normalize so the dot product below is cosine similarity.
        image_features = F.normalize(image_features, dim=1)
        text_features = F.normalize(text_features, dim=1)
        # Clamp the learned temperature at 100 (as in the reference CLIP
        # implementation) to prevent logit overflow if the scale drifts up.
        logit_scale = self.logit_scale.exp().clamp(max=100.0)
        logits = torch.matmul(image_features, text_features.t()) * logit_scale
        # The i-th image matches the i-th text, so targets are the diagonal.
        labels = torch.arange(logits.shape[0], dtype=torch.long, device=logits.device)
        loss_i2t = F.cross_entropy(logits, labels)
        loss_t2i = F.cross_entropy(logits.t(), labels)
        return (loss_i2t + loss_t2i) / 2

    def text_encode(self, text, tokenize=True):
        """Encode text, optionally tokenizing raw input first.

        Args:
            text: Raw text (when ``tokenize``) or pre-built tokens.
            tokenize: Whether to run ``self.tokenizer`` on ``text`` first.

        Returns:
            Text feature tensor, promoted to at least 2-D (batch of one).
        """
        tokens = self.tokenizer.tokenize(text) if tokenize else text
        text_features = self.text_encoder(tokens)
        if text_features.dim() < 2:
            text_features = text_features.unsqueeze(0)
        return text_features

    def forward(self, image, text, tokenize=True):
        """Encode a batch of images and texts.

        Returns:
            Tuple ``(image_features, text_features)``, each at least 2-D.
        """
        image_features = self.image_encoder(image)
        # BUG FIX: the original called self.text_encoder(text, tokenize),
        # bypassing text_encode() — so raw text was never tokenized and the
        # boolean flag was passed straight into the encoder module. Route
        # through text_encode() so tokenization and batch promotion apply.
        text_features = self.text_encode(text, tokenize)
        if image_features.dim() < 2:
            image_features = image_features.unsqueeze(0)
        return image_features, text_features