import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

import pyrootutils

pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True)

from src.train.dist_utils import concat_all_gather
def cosine_loss(rec, target):
    """Mean cosine distance between reconstructed and target embeddings."""
    target = target / target.norm(dim=-1, keepdim=True)
    rec = rec / rec.norm(dim=-1, keepdim=True)
    rec_loss = (1 - (target * rec).sum(-1)).mean()
    return rec_loss
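
# A minimal sanity check (illustrative only, not part of the training pipeline):
# identical inputs give a loss of ~0, and the loss is bounded in [0, 2].
#
#   x = torch.randn(4, 32, 768)
#   assert cosine_loss(x, x).item() < 1e-6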
def contrastive_loss(image_feats, text_feats, logit_scale):
    """InfoNCE loss gathered across all GPUs. Note that `logit_scale` is used as
    a temperature: similarities are divided by it (unlike CLIP's multiplicative scale)."""
    image_feats = image_feats.unsqueeze(1).contiguous()
    image_feats_all = concat_all_gather(image_feats)  # [batch_size*num_gpu, 1, embed_dim]
    text_feats_all = concat_all_gather(text_feats)  # [batch_size*num_gpu, embed_dim]

    # note: in BLIP-2 these similarities are aggregated over query tokens (max/mean);
    # here each sample contributes a single feature vector, so no aggregation is needed.
    # image-to-text similarity: [batch_size, batch_size*num_gpu]
    sim_i2t = torch.matmul(image_feats.unsqueeze(1), text_feats_all.unsqueeze(-1)).squeeze()
    sim_i2t = sim_i2t / logit_scale

    # text-to-image similarity: [batch_size, batch_size*num_gpu]
    sim_t2i = torch.matmul(text_feats.unsqueeze(1).unsqueeze(1), image_feats_all.permute(0, 2, 1)).squeeze()
    sim_t2i = sim_t2i / logit_scale

    # the positive pair for sample i on this rank sits at column rank * bs + i
    rank = dist.get_rank()
    bs = image_feats.size(0)
    targets = torch.arange(rank * bs, rank * bs + bs, dtype=torch.long, device=image_feats.device)

    loss_itc = (F.cross_entropy(sim_i2t, targets, label_smoothing=0.1) +
                F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)) / 2

    i2t_acc = (sim_i2t.argmax(-1) == targets).sum() / len(sim_i2t)
    t2i_acc = (sim_t2i.argmax(-1) == targets).sum() / len(sim_t2i)
    return loss_itc, i2t_acc, t2i_acc
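
# Minimal single-process usage sketch (assumes torch.distributed has been
# initialised; with world_size == 1, concat_all_gather reduces to the identity,
# so both similarity matrices are [bs, bs]):
#
#   dist.init_process_group('gloo', init_method='tcp://127.0.0.1:29500',
#                           rank=0, world_size=1)
#   img = F.normalize(torch.randn(8, 256), dim=-1)
#   txt = F.normalize(torch.randn(8, 256), dim=-1)
#   loss, i2t_acc, t2i_acc = contrastive_loss(img, txt, logit_scale=torch.tensor(0.07))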
class DiscreteModleOnlyDistill(nn.Module):

    def __init__(self,
                 qformer,
                 quantizer,
                 distiller=None,
                 loss_type='cosine',
                 scale_commit_loss=1.0,
                 freeze_qformer=False) -> None:
        super().__init__()
        self.qformer = qformer
        self.quantizer = quantizer
        self.distiller = distiller
        self.loss_type = loss_type
        self.scale_commit_loss = scale_commit_loss
        self.freeze_qformer = freeze_qformer
        if freeze_qformer:
            self.qformer.requires_grad_(False)
    def forward(self, image_embeds, input_ids=None, text_attention_mask=None, text_embeds=None):
        if self.freeze_qformer:
            with torch.no_grad():
                qformer_embeds = self.qformer(image_embeds=image_embeds)
        else:
            qformer_embeds = self.qformer(image_embeds=image_embeds)

        quantizer_output = self.quantizer(qformer_embeds)
        recon_embeds = self.distiller(quantizer_output['quant_embeds'])

        if self.loss_type == 'cosine':
            distill_loss = cosine_loss(recon_embeds, image_embeds)
        else:
            raise NotImplementedError

        total_loss = distill_loss + self.scale_commit_loss * quantizer_output['commit_loss']
        return {
            'total_loss': total_loss,
            'distill_loss': distill_loss,
            'commit_loss': quantizer_output['commit_loss'],
            'indices': quantizer_output['indices'],
        }
    def encode_image_embeds(self, image_embeds):
        qformer_embeds = self.qformer(image_embeds=image_embeds)
        quantizer_output = self.quantizer(qformer_embeds)
        output_embeds = quantizer_output['quant_embeds']
        if self.distiller is not None:
            output_embeds = self.distiller(output_embeds)
        return output_embeds
    @classmethod
    def from_pretrained(cls, qformer, quantizer, distiller=None, pretrained_model_path=None, **kwargs):
        model = cls(qformer=qformer, quantizer=quantizer, distiller=distiller, **kwargs)
        if pretrained_model_path is not None:
            ckpt = torch.load(pretrained_model_path, map_location='cpu')
            missing, unexpected = model.load_state_dict(ckpt, strict=False)
            print('missing keys:', len(missing), 'unexpected keys:', len(unexpected))
        return model
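
# Hypothetical construction sketch (the q-former/quantizer/distiller builders live
# elsewhere in src; the builder names below are placeholders, not the repo's API):
#
#   model = DiscreteModleOnlyDistill.from_pretrained(
#       qformer=build_qformer(...),          # placeholder builder
#       quantizer=build_quantizer(...),      # placeholder builder
#       distiller=build_distiller(...),      # placeholder builder
#       pretrained_model_path='ckpt.pt',     # optional state_dict checkpoint
#   )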
class DiscreteModleIdentity(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Identity()

    def forward(self, image_embeds, input_ids=None, text_attention_mask=None, text_embeds=None):
        # the identity model has no trainable objective
        return None

    def encode_image_embeds(self, image_embeds):
        return self.model(image_embeds)
class DiscreteModleStageOneContrastive(nn.Module):

    def __init__(self, qformer, quantizer=None, distiller=None, projection_dim=1024,
                 image_cls_token_type='last') -> None:
        super().__init__()
        self.qformer = qformer
        self.quantizer = quantizer
        self.distiller = distiller
        self.image_cls_token_type = image_cls_token_type
        # used as a temperature: similarities are divided by it in contrastive_loss
        self.logit_scale = nn.Parameter(0.07 * torch.ones([]))
        self.image_proj = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
        self.text_proj = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)

    def forward(self, image_embeds, input_ids=None, text_attention_mask=None, text_embeds=None):
        image_embeds = self.qformer(image_embeds=image_embeds)
        if self.image_cls_token_type == 'last':
            image_embeds = image_embeds[:, -1, :]
        else:
            raise NotImplementedError

        text_embeds = self.qformer(input_ids=input_ids, text_attention_mask=text_attention_mask)
        text_embeds = text_embeds[:, 0, :]

        image_embeds = F.normalize(self.image_proj(image_embeds), dim=-1)
        text_embeds = F.normalize(self.text_proj(text_embeds), dim=-1)

        contrast_loss, i2t_acc, t2i_acc = contrastive_loss(image_feats=image_embeds,
                                                           text_feats=text_embeds,
                                                           logit_scale=self.logit_scale)
        return {
            'total_loss': contrast_loss,
            'i2t_acc': i2t_acc,
            't2i_acc': t2i_acc,
        }

    def encode_image_embeds(self, image_embeds):
        # stage one returns the continuous q-former features (no quantization)
        image_embeds = self.qformer(image_embeds=image_embeds)
        return image_embeds
    @classmethod
    def from_pretrained(cls, qformer, quantizer, distiller=None, pretrained_model_path=None, **kwargs):
        model = cls(qformer=qformer, quantizer=quantizer, distiller=distiller, **kwargs)
        if pretrained_model_path is not None:
            ckpt = torch.load(pretrained_model_path, map_location='cpu')
            missing, unexpected = model.load_state_dict(ckpt, strict=False)
            print('missing keys:', len(missing), 'unexpected keys:', len(unexpected))
        return model
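
# Sketch of a stage-one training step (assuming `model` is a
# DiscreteModleStageOneContrastive and `batch` provides the fields below;
# the field names are illustrative, not tied to a specific dataloader):
#
#   out = model(image_embeds=batch['image_embeds'],
#               input_ids=batch['input_ids'],
#               text_attention_mask=batch['attention_mask'])
#   out['total_loss'].backward()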
class DiscreteModleStageTwoContrastiveDistill(nn.Module):

    def __init__(self,
                 qformer,
                 quantizer=None,
                 distiller=None,
                 contrast_head=None,
                 projection_dim=1024,
                 distill_loss_type='cosine',
                 freeze_qformer=True,
                 image_cls_token_type='last',
                 scale_commit_loss=1.0,
                 scale_contrast_loss=1.0,
                 scale_distill_loss=1.0) -> None:
        super().__init__()
        self.qformer = qformer
        self.quantizer = quantizer
        self.distiller = distiller
        self.contrast_head = contrast_head
        self.distill_loss_type = distill_loss_type
        self.image_cls_token_type = image_cls_token_type
        if self.contrast_head is not None:
            self.logit_scale = nn.Parameter(0.07 * torch.ones([]))
            self.image_proj = nn.Linear(contrast_head.perceiver.config.projection_dim, projection_dim, bias=False)
            self.text_proj = nn.Linear(contrast_head.perceiver.config.projection_dim, projection_dim, bias=False)
        self.freeze_qformer = freeze_qformer
        if freeze_qformer:
            self.qformer.requires_grad_(False)
        self.scale_commit_loss = scale_commit_loss
        self.scale_contrast_loss = scale_contrast_loss
        self.scale_distill_loss = scale_distill_loss
    def forward(self, image_embeds, input_ids=None, text_attention_mask=None, text_embeds=None):
        if self.freeze_qformer:
            with torch.no_grad():
                qformer_embeds = self.qformer(image_embeds=image_embeds)
        else:
            qformer_embeds = self.qformer(image_embeds=image_embeds)

        quantizer_output = self.quantizer(qformer_embeds)

        output_state = {}
        output_state['indices'] = quantizer_output['indices']
        output_state['commit_loss'] = quantizer_output['commit_loss']
        output_state['total_loss'] = self.scale_commit_loss * quantizer_output['commit_loss']

        if self.distiller is not None:
            recon_embeds = self.distiller(quantizer_output['quant_embeds'])
            if self.distill_loss_type == 'cosine':
                distill_loss = cosine_loss(recon_embeds, image_embeds)
            else:
                raise NotImplementedError
            output_state['distill_loss'] = distill_loss
            output_state['total_loss'] += self.scale_distill_loss * distill_loss

        if self.contrast_head is not None:
            text_embeds = self.qformer(input_ids=input_ids, text_attention_mask=text_attention_mask)
            text_embeds = text_embeds[:, 0, :]

            image_embeds = self.contrast_head(quantizer_output['quant_embeds'])
            if self.image_cls_token_type == 'last':
                image_embeds = image_embeds[:, -1, :]
            else:
                raise NotImplementedError

            image_embeds = F.normalize(self.image_proj(image_embeds), dim=-1)
            text_embeds = F.normalize(self.text_proj(text_embeds), dim=-1)

            contrast_loss, i2t_acc, t2i_acc = contrastive_loss(image_feats=image_embeds,
                                                               text_feats=text_embeds,
                                                               logit_scale=self.logit_scale)
            output_state['contrast_loss'] = contrast_loss
            output_state['total_loss'] += self.scale_contrast_loss * contrast_loss
            output_state['i2t_acc'] = i2t_acc
            output_state['t2i_acc'] = t2i_acc

        return output_state
    def encode_image_embeds(self, image_embeds):
        # not implemented for the stage-two model; fail loudly rather than
        # silently returning None
        raise NotImplementedError
    @classmethod
    def from_pretrained(cls, qformer, quantizer, distiller=None, contrast_head=None, pretrained_model_path=None,
                        **kwargs):
        model = cls(qformer=qformer, quantizer=quantizer, distiller=distiller, contrast_head=contrast_head, **kwargs)
        if pretrained_model_path is not None:
            ckpt = torch.load(pretrained_model_path, map_location='cpu')
            missing, unexpected = model.load_state_dict(ckpt, strict=False)
            print('missing keys:', len(missing), 'unexpected keys:', len(unexpected))
        return model
class DiscreteModleDistillWithDoubleContrastive(nn.Module):

    def __init__(
            self,
            qformer,
            quantizer=None,
            distiller=None,
            contrast_head=None,
            projection_dim=1024,
            distill_loss_type='cosine',
            share_contrast_head=True,  # reuse the distiller as the contrastive head
            quantize_cls_token=False,
            rec_qformer=False,
            has_contrast=False,
            freeze_qformer=False,
            scale_commit_loss=1.0,
            scale_contrast_loss=1.0,
            scale_distill_loss=1.0) -> None:
        super().__init__()
        self.qformer = qformer
        self.quantizer = quantizer
        self.distiller = distiller
        self.contrast_head = contrast_head
        self.distill_loss_type = distill_loss_type
        self.quantize_cls_token = quantize_cls_token
        self.rec_qformer = rec_qformer
        self.has_contrast = has_contrast
        if freeze_qformer:
            self.qformer.requires_grad_(False)
        else:
            self.logit_scale_qformer = nn.Parameter(0.07 * torch.ones([]))
            self.image_proj_qformer = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
            self.text_proj_qformer = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
            self.cls_norm_qformer = nn.LayerNorm(qformer.perceiver.config.projection_dim)
        if self.contrast_head is not None:
            self.logit_scale_head = nn.Parameter(0.07 * torch.ones([]))
            self.image_proj_head = nn.Linear(contrast_head.perceiver.config.projection_dim, projection_dim, bias=False)
            self.text_proj_head = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
            self.cls_norm_head = nn.LayerNorm(contrast_head.perceiver.config.projection_dim)
        if share_contrast_head and distiller is not None:
            self.logit_scale_head = nn.Parameter(0.07 * torch.ones([]))
            self.image_proj_head = nn.Linear(distiller.perceiver.config.projection_dim, projection_dim, bias=False)
            self.text_proj_head = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
            self.cls_norm_head = nn.LayerNorm(distiller.perceiver.config.projection_dim)
        self.scale_commit_loss = scale_commit_loss
        self.scale_contrast_loss = scale_contrast_loss
        self.scale_distill_loss = scale_distill_loss
        self.share_contrast_head = share_contrast_head
        self.freeze_qformer = freeze_qformer
        # a separate contrast_head and a shared one are mutually exclusive
        assert int(self.share_contrast_head) + int(contrast_head is not None) <= 1
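
    # Loss composition (as implemented in forward below): total_loss accumulates
    #   scale_contrast_loss * q-former contrastive loss  (trainable q-former + has_contrast)
    #   scale_commit_loss   * quantizer commit loss      (quantizer + distiller present)
    #   scale_distill_loss  * cosine reconstruction loss (target: q-former or image embeds)
    #   scale_contrast_loss * head contrastive loss      (separate or shared head)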
    def forward(self, image_embeds, input_ids=None, text_attention_mask=None, text_embeds=None):
        if self.freeze_qformer:
            with torch.no_grad():
                qformer_embeds = self.qformer(image_embeds=image_embeds)
        else:
            qformer_embeds = self.qformer(image_embeds=image_embeds)

        # the last token of the q-former output acts as the image [CLS] token
        qformer_cls_embeds = qformer_embeds[:, -1, :]
        if not self.quantize_cls_token:
            qformer_embeds = qformer_embeds[:, :-1, :]

        if self.has_contrast:
            text_embeds = self.qformer(input_ids=input_ids, text_attention_mask=text_attention_mask)
            text_cls_embeds = text_embeds[:, 0, :]

        output_state = {}
        output_state['total_loss'] = 0.0

        if not self.freeze_qformer and self.has_contrast:
            qformer_cls_embeds = self.cls_norm_qformer(qformer_cls_embeds)
            qformer_image_embeds = F.normalize(self.image_proj_qformer(qformer_cls_embeds), dim=-1)
            qformer_text_embeds = F.normalize(self.text_proj_qformer(text_cls_embeds), dim=-1)
            qformer_contrast_loss, \
                qformer_i2t_acc, \
                qformer_t2i_acc = contrastive_loss(image_feats=qformer_image_embeds,
                                                   text_feats=qformer_text_embeds,
                                                   logit_scale=self.logit_scale_qformer)
            output_state['qformer_contrast_loss'] = qformer_contrast_loss
            output_state['total_loss'] += self.scale_contrast_loss * qformer_contrast_loss
            output_state['qformer_i2t_acc'] = qformer_i2t_acc
            output_state['qformer_t2i_acc'] = qformer_t2i_acc

        if self.quantizer is not None and self.distiller is not None:
            quantizer_output = self.quantizer(qformer_embeds)
            recon_embeds = self.distiller(quantizer_output['quant_embeds'])

            if self.share_contrast_head:
                contrast_head_cls_embeds = recon_embeds[:, -1, :]
                contrast_head_cls_embeds = self.cls_norm_head(contrast_head_cls_embeds)
                recon_embeds = recon_embeds[:, :-1, :]
            if self.contrast_head is not None:
                contrast_head_embeds = self.contrast_head(quantizer_output['quant_embeds'])
                contrast_head_cls_embeds = contrast_head_embeds[:, -1, :]
                contrast_head_cls_embeds = self.cls_norm_head(contrast_head_cls_embeds)

            output_state['indices'] = quantizer_output['indices']
            output_state['commit_loss'] = quantizer_output['commit_loss']
            output_state['total_loss'] += self.scale_commit_loss * quantizer_output['commit_loss']

            # reconstruct either the q-former output or the original image embeds
            target_embeds = qformer_embeds if self.rec_qformer else image_embeds
            if self.distill_loss_type == 'cosine':
                distill_loss = cosine_loss(recon_embeds, target_embeds)
            else:
                raise NotImplementedError
            output_state['distill_loss'] = distill_loss
            output_state['total_loss'] += self.scale_distill_loss * distill_loss

            if self.contrast_head is not None or self.share_contrast_head:
                head_image_embeds = F.normalize(self.image_proj_head(contrast_head_cls_embeds), dim=-1)
                head_text_embeds = F.normalize(self.text_proj_head(text_cls_embeds), dim=-1)
                head_contrast_loss, head_i2t_acc, head_t2i_acc = contrastive_loss(image_feats=head_image_embeds,
                                                                                  text_feats=head_text_embeds,
                                                                                  logit_scale=self.logit_scale_head)
                output_state['head_contrast_loss'] = head_contrast_loss
                output_state['total_loss'] += self.scale_contrast_loss * head_contrast_loss
                output_state['head_i2t_acc'] = head_i2t_acc
                output_state['head_t2i_acc'] = head_t2i_acc

        return output_state
    def encode_image_embeds(self, image_embeds):
        qformer_embeds = self.qformer(image_embeds=image_embeds)
        return qformer_embeds
    @classmethod
    def from_pretrained(cls, qformer, quantizer=None, distiller=None, contrast_head=None, pretrained_model_path=None,
                        **kwargs):
        model = cls(qformer=qformer, quantizer=quantizer, distiller=distiller, contrast_head=contrast_head, **kwargs)
        if pretrained_model_path is not None:
            ckpt = torch.load(pretrained_model_path, map_location='cpu')
            missing, unexpected = model.load_state_dict(ckpt, strict=False)
            print('missing keys:', len(missing), 'unexpected keys:', len(unexpected))
        return model
    @classmethod
    def from_pretrained_stage1_yuying(cls,
                                      qformer,
                                      quantizer=None,
                                      distiller=None,
                                      contrast_head=None,
                                      pretrained_model_path=None,
                                      **kwargs):
        model = cls(qformer=qformer, quantizer=quantizer, distiller=distiller, contrast_head=contrast_head, **kwargs)
        if pretrained_model_path is not None:
            ckpt = torch.load(pretrained_model_path, map_location='cpu')
            ckpt = ckpt['model']
            # remap BLIP-2 style checkpoint keys to this module's naming scheme
            new_ckpt = {}
            new_ckpt['qformer.embed_module.query'] = ckpt['query_tokens'].squeeze(0)
            new_ckpt['qformer.norm.weight'] = ckpt['ln_vision.weight']
            new_ckpt['qformer.norm.bias'] = ckpt['ln_vision.bias']
            for key in ckpt.keys():
                if key.startswith('Qformer'):
                    new_key = key.replace('Qformer', 'qformer.perceiver')
                    new_ckpt[new_key] = ckpt[key]
            del ckpt
            missing, unexpected = model.load_state_dict(new_ckpt, strict=False)
            print('missing keys:', len(missing), 'unexpected keys:', len(unexpected))
            print(missing)
            print(unexpected)
        return model
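

# Minimal smoke test (a sketch that exercises only the dependency-free loss
# helper; the model classes need the q-former, quantizer, and distiller modules
# from src, plus an initialised process group for contrastive_loss):
if __name__ == '__main__':
    x = torch.randn(2, 32, 768)
    y = torch.randn(2, 32, 768)
    print('cosine_loss(x, x) =', cosine_loss(x, x).item())  # ~0.0
    print('cosine_loss(x, y) =', cosine_loss(x, y).item())  # ~1.0 for random inputs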