| from torch import nn |
| import torch.nn.functional as F |
|
|
| from image_encoder import ImageEncoder |
| from text_encoder import TextEncoder |
| from projection_head import ProjectionHead |
| from configuration import CFG |
|
|
|
|
class CLIPModel(nn.Module):
    """CLIP-style dual encoder.

    Encodes images and texts with separate backbones, projects both into a
    shared embedding space, and trains with a symmetric contrastive loss
    between the two modalities.
    """

    def __init__(
        self,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
    ):
        super().__init__()
        # Modality-specific backbones (configured in their own modules).
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        # Project each modality's features into a common embedding space.
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        # Softmax temperature used to scale both logits and soft targets.
        self.temperature = temperature

    def forward(self, batch):
        """Return the mean symmetric contrastive loss for a batch.

        Args:
            batch: dict with keys "image" (image tensor batch), "input_ids"
                and "attention_mask" (tokenized text batch).

        Returns:
            Scalar tensor: mean of the text->image and image->text losses.
        """
        image_features = self.image_encoder(batch["image"])
        text_features = self.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )

        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        # Text-vs-image similarity matrix, scaled by 1/temperature.
        logits = (text_embeddings @ image_embeddings.T) / self.temperature

        # Soft targets: softmax over the average of the two within-modality
        # self-similarity matrices.
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        # BUGFIX: the original `/ 2 * self.temperature` MULTIPLIES by the
        # temperature (precedence: (x / 2) * t), while `logits` above
        # DIVIDES by it. Scale the targets by the same 1/temperature
        # factor so logits and targets stay consistent.
        targets = F.softmax(
            (images_similarity + texts_similarity) / (2 * self.temperature), dim=-1
        )

        # Symmetric cross-entropy between logits and soft targets.
        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss = (images_loss + texts_loss) / 2.0
        return loss.mean()
|
|
|
|
def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy between logits and a soft (probability) target matrix.

    Args:
        preds: unnormalized logits, shape (batch, classes).
        targets: target probability distribution per row, same shape.
        reduction: 'none' to return the per-row loss vector, 'mean' to
            return its scalar mean.

    Returns:
        Tensor of shape (batch,) for 'none', scalar tensor for 'mean'.

    Raises:
        ValueError: if `reduction` is not 'none' or 'mean' (the original
            silently returned None in that case).
    """
    # Functional form avoids constructing an nn.LogSoftmax module per call.
    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(1)
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()
    raise ValueError(f"Unsupported reduction: {reduction!r}")