# Source: DiffICM / 1_feature_extractor / models_proteus_clip.py
# Last commit (Qiyp, 1633fcc): code of stage1 & 3, remove large files
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed.nn
import torch.distributed as dist
from torch.nn.init import trunc_normal_
from torch.nn.utils import weight_norm
import models_clip
class MetaArch(nn.Module):
    """Student/teacher wrapper that trains a student by token-level distillation.

    Both backbones are built from factories looked up by name on the
    ``models_clip`` module. The teacher is frozen (eval mode, no gradients);
    the student's ``"x_norm_clstoken"`` output is projected to the teacher's
    embedding width by ``token_head`` and regressed onto the teacher's
    ``"x_norm_clstoken"`` output with an MSE loss.

    Args:
        cfg: Configuration object. The attributes read here are
            ``target_model`` and ``teacher_model`` (factory names on
            ``models_clip``), ``teacher_path`` (forwarded to the teacher
            factory), ``batch_size`` and ``lambda_token`` (loss weight).
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # Backbone factories are resolved by name on the models_clip module.
        student = getattr(models_clip, cfg.target_model)()
        teacher_backbone = getattr(models_clip, cfg.teacher_model)(
            teacher_path=cfg.teacher_path
        )
        teacher_backbone.eval()

        embed_dim = student.embed_dim
        teacher_embed_dim = teacher_backbone.embed_dim
        self.embed_dim = embed_dim

        # initialize parameters and checks
        self.total_n_global_crops = cfg.batch_size

        self.student = nn.ModuleDict({"backbone": student})
        self.teacher = nn.ModuleDict({"backbone": teacher_backbone})

        # Projects student class tokens into the teacher's embedding space so
        # the two can be compared directly.
        self.token_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, teacher_embed_dim),
        )

        self.soft_criterion = torch.nn.MSELoss()

        # The teacher only provides targets; it is never optimised.
        for param in self.teacher.backbone.parameters():
            param.requires_grad = False

    def train(self, mode: bool = True):
        """Set the training mode while keeping the frozen teacher in eval mode.

        ``nn.Module.train`` recurses into every registered submodule, so
        without this override a plain ``model.train()`` call would switch the
        teacher back to training mode.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        self.teacher.eval()
        return self

    # We explicitly remove the patch and feature learning objectives for CLIP
    # training, following the original design.
    def forward(self, inputs):
        """Compute the distillation losses for one batch.

        Args:
            inputs: Mapping whose ``"collated_global_crops"`` entry is the
                batch fed to both backbones.

        Returns:
            Dict with keys ``"patch_loss"`` and ``"fea_loss"`` (constant zero
            placeholders), ``"token_loss"`` (``cfg.lambda_token`` times the
            MSE between projected student and teacher class tokens) and
            ``"loss"`` (the total, equal to ``"token_loss"``).
        """
        global_crops = inputs["collated_global_crops"]

        # Teacher targets are computed without tracking gradients.
        with torch.no_grad():
            teacher_cls_tokens = self.teacher.backbone(global_crops)["x_norm_clstoken"]

        student_cls_token = self.student.backbone(global_crops)["x_norm_clstoken"]

        # projection head
        student_cls_token = self.token_head(student_cls_token)

        # token objective, scaled by its coefficient
        distillation_loss_token = self.soft_criterion(student_cls_token, teacher_cls_tokens)
        token_loss = self.cfg.lambda_token * distillation_loss_token

        # compute the total loss
        total_loss = token_loss

        # Placeholders live on the same device as the real losses so the whole
        # dict can be reduced or logged uniformly; two distinct tensors are
        # created so an in-place op on one cannot affect the other.
        device = token_loss.device
        return {
            "patch_loss": torch.zeros((), device=device),
            "fea_loss": torch.zeros((), device=device),
            "token_loss": token_loss,
            "loss": total_loss,
        }