# Author: yermandy
# Commit: c29babb ("init")
import torch.nn.functional as F
from torch import nn
from .adapters.adapter import Adapter
from .attn import RecAttnClip
from .clip.clip import load
from .layer import MaskPostXrayProcess, PostClipProcess
class DS(nn.Module):
    """Frozen CLIP backbone combined with a trainable ``Adapter`` branch.

    A CLIP model is obtained from ``load`` and every parameter whose name
    contains ``"clip_model"`` is frozen in ``_freeze``. The adapter produces
    attention biases that are fed, together with the CLIP features, into
    ``RecAttnClip``. Its output is turned into classification logits by
    ``PostClipProcess``.

    Args:
        clip_name: Model name/path forwarded to ``load``.
        adapter_vit_name: Forwarded to ``Adapter(vit_name=...)``.
        num_quires: Query count shared by ``Adapter``, ``RecAttnClip``,
            ``MaskPostXrayProcess`` (as ``in_c``) and ``PostClipProcess``.
            The spelling matches those modules' keyword arguments.
        fusion_map: Forwarded to ``Adapter``. Its ``.values()`` are passed to
            ``clip_model.extract_features`` in ``forward``.
        mlp_dim, mlp_out_dim, head_num: Forwarded to ``Adapter``.
        mode: Stored on ``self.mode``; not read anywhere in this class.
        embed_dim: Feature width given to ``PostClipProcess``. The default
            of 768 is the value that used to be hard-coded.
        download_root: Directory passed to ``load`` for the CLIP weights.
    """

    # Spatial size images are bilinearly resized to before entering CLIP.
    CLIP_INPUT_SIZE = (224, 224)
    # Weights of the individual terms in the overall loss of ``get_losses``.
    CLS_LOSS_WEIGHT = 10
    XRAY_LOSS_WEIGHT = 200
    INTRA_LOSS_WEIGHT = 20
    CLIP_LOSS_WEIGHT = 10

    def __init__(
        self,
        clip_name,
        adapter_vit_name,
        num_quires,
        fusion_map,
        mlp_dim,
        mlp_out_dim,
        head_num,
        mode="video",
        embed_dim=768,
        download_root="weights/forensics_adapter",
    ):
        super().__init__()
        self.clip_model, self.processor = load(clip_name, download_root=download_root)
        self.adapter = Adapter(
            vit_name=adapter_vit_name,
            num_quires=num_quires,
            fusion_map=fusion_map,
            mlp_dim=mlp_dim,
            mlp_out_dim=mlp_out_dim,
            head_num=head_num,
        )
        # Author's note: all of RecAttnClip's parameters are frozen. The
        # weights it shares with clip_model.visual are frozen by _freeze
        # below. Any parameters of its own are presumably frozen inside
        # RecAttnClip (verify there).
        self.rec_attn_clip = RecAttnClip(self.clip_model.visual, num_quires)
        # Currently unused by forward(); kept so existing checkpoints still load.
        self.masked_xray_post_process = MaskPostXrayProcess(in_c=num_quires)
        self.clip_post_process = PostClipProcess(num_quires=num_quires, embed_dim=embed_dim)
        self.mode = mode
        self._freeze()

    def _freeze(self):
        """Disable gradients for every parameter registered under ``clip_model``."""
        for name, param in self.named_parameters():
            if "clip_model" in name:
                param.requires_grad = False

    def get_losses(self, data_dict, pred_dict):
        """Compute the weighted training losses for one batch.

        Args:
            data_dict: Batch dict with ``"label"`` (class-index targets for
                cross-entropy) and an optional ``"xray"`` target. The ``"xray"``
                key may be missing or ``None``.
            pred_dict: Model outputs. Logits are read from ``"cls"`` if present,
                otherwise from ``"logits"``, which is the key ``forward``
                returns. When an x-ray target is given, ``"xray_pred"``,
                ``"loss_intra"`` and ``"loss_clip"`` are also required.

        Returns:
            A dict with ``"cls"`` and ``"overall"``. When an x-ray target is
            given it also contains ``"xray"``, ``"intra"`` and ``"loss_clip"``,
            and ``"overall"`` is the weighted sum of all four terms.

        Raises:
            KeyError: If an x-ray target is given but ``pred_dict`` lacks one
                of the auxiliary outputs listed above.
        """
        label = data_dict["label"]
        xray = data_dict.get("xray")
        # forward() emits "logits"; older callers may still supply "cls".
        pred = pred_dict["cls"] if "cls" in pred_dict else pred_dict["logits"]
        loss_cls = F.cross_entropy(pred.float(), label)

        if xray is None:
            return {"cls": loss_cls, "overall": loss_cls}

        missing = [key for key in ("xray_pred", "loss_intra", "loss_clip") if key not in pred_dict]
        if missing:
            raise KeyError(f"x-ray target given but pred_dict is missing {missing}")

        loss_intra = pred_dict["loss_intra"]
        loss_clip = pred_dict["loss_clip"]
        # squeeze() drops singleton dims, presumably (N, 1, H, W) -> (N, H, W).
        # It assumes prediction and target squeeze to the same shape.
        loss_mse = F.mse_loss(pred_dict["xray_pred"].squeeze().float(), xray.squeeze().float())
        loss = (
            self.CLS_LOSS_WEIGHT * loss_cls
            + self.XRAY_LOSS_WEIGHT * loss_mse
            + self.INTRA_LOSS_WEIGHT * loss_intra
            + self.CLIP_LOSS_WEIGHT * loss_clip
        )
        return {
            "cls": loss_cls,
            "xray": loss_mse,
            "intra": loss_intra,
            "loss_clip": loss_clip,
            "overall": loss,
        }

    def forward(self, data_dict, inference=False):
        """Return classification logits for ``data_dict["image"]``.

        Args:
            data_dict: Batch dict. ``"image"`` must be accepted by
                ``F.interpolate`` in bilinear mode (a 4-D tensor). The whole
                dict is also passed on to the adapter and ``RecAttnClip``.
            inference: Forwarded unchanged to the adapter and ``RecAttnClip``.

        Returns:
            ``{"logits": ...}``, the ``PostClipProcess`` output on the float
            CLIP features. This is presumably (N, 2), per the original "N2"
            note.
        """
        clip_images = F.interpolate(
            data_dict["image"],
            size=self.CLIP_INPUT_SIZE,
            mode="bilinear",
            align_corners=False,
        )
        clip_features = self.clip_model.extract_features(clip_images, self.adapter.fusion_map.values())
        # The adapter's x-ray predictions and intra loss, and RecAttnClip's
        # loss, are computed but not returned. get_losses therefore only
        # applies the classification term to this output.
        attn_biases, _xray_preds, _loss_adapter_intra = self.adapter(data_dict, clip_features, inference)
        # Only the last attention bias is used.
        clip_output, _loss_clip = self.rec_attn_clip(
            data_dict, clip_features, attn_biases[-1], inference, normalize=True
        )
        return {"logits": self.clip_post_process(clip_output.float())}