from transformers import AutoModel, AutoProcessor
import torch
import torch.nn as nn
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# The five quality levels map to scores 1..5; `phase` is the order in which
# the five quality dimensions are reported by forward().
qualitys = ['bad', 'poor', 'fair', 'good', 'perfect']
phase = ["overall", "spatial", "alignment", "temporal", "aesthetic"]


class ViTbSiglip(torch.nn.Module):
    def __init__(self, feat_len=8, dropout_sp=0.2, dropout_tp=0.2,
                 dropout_pp=0.2, use_ali=True, spatial='resnet'):
        super().__init__()
        self.siglip_model = AutoModel.from_pretrained(
            "google/siglip2-base-patch16-naflex")
        self.siglip_processor = AutoProcessor.from_pretrained(
            "google/siglip2-base-patch16-naflex")
        # Freeze SigLIP's similarity-calibration parameters; the rest of the
        # backbone stays trainable.
        for param_name, param in self.siglip_model.named_parameters():
            if param_name in ('logit_scale', 'logit_bias'):
                param.requires_grad = False

        self.feat_len = feat_len
        self.dropout_sp = dropout_sp
        self.dropout_tp = dropout_tp
        self.dropout_pp = dropout_pp
        self.use_ali = use_ali
        self.tr = True  # enable the temporal rectifier
        self.sr = True  # enable the spatial rectifier
        self.ar = True  # enable the alignment rectifier

        # One (alpha, beta) rectifier head per quality dimension.
        if spatial == 'resnet':
            spa_in = 5 * 256 * self.feat_len
        else:
            spa_in = 1024 * self.feat_len
        self.spatialRec1 = self.spatial_rectifier(spa_in, self.dropout_sp)
        self.spatialRec2 = self.spatial_rectifier(spa_in, self.dropout_sp)
        self.spatialRec3 = self.spatial_rectifier(spa_in, self.dropout_sp)
        self.spatialRec4 = self.spatial_rectifier(spa_in, self.dropout_sp)
        self.spatialRec5 = self.spatial_rectifier(spa_in, self.dropout_sp)

        tem_in = 256 * self.feat_len  # SlowFast features: Fast pathway 256 (Slow would be 2048)
        self.temporalRec1 = self.temporal_rectifier(tem_in, self.dropout_tp)
        self.temporalRec2 = self.temporal_rectifier(tem_in, self.dropout_tp)
        self.temporalRec3 = self.temporal_rectifier(tem_in, self.dropout_tp)
        self.temporalRec4 = self.temporal_rectifier(tem_in, self.dropout_tp)
        self.temporalRec5 = self.temporal_rectifier(tem_in, self.dropout_tp)

        mix_in = 768 * self.feat_len
        self.mixRec1 = self.mix_rectifier(mix_in, self.dropout_pp)
        self.mixRec2 = self.mix_rectifier(mix_in, self.dropout_pp)
        self.mixRec3 = self.mix_rectifier(mix_in, self.dropout_pp)
        self.mixRec4 = self.mix_rectifier(mix_in, self.dropout_pp)
        self.mixRec5 = self.mix_rectifier(mix_in, self.dropout_pp)

    def spatial_rectifier(self, in_channels, dropout_sp):
        '''Two-layer MLP that outputs (alpha, beta); shape: batch_size x 2.'''
        return nn.Sequential(
            nn.Linear(in_channels, 128),
            nn.ReLU(),
            nn.Linear(128, 2),
            nn.Dropout(p=dropout_sp),
        )

    def temporal_rectifier(self, in_channels, dropout_tp):
        '''Two-layer MLP that outputs (alpha, beta); shape: batch_size x 2.'''
        return nn.Sequential(
            nn.Linear(in_channels, 128),
            nn.ReLU(),
            nn.Linear(128, 2),
            nn.Dropout(p=dropout_tp),
        )

    def mix_rectifier(self, in_channels, dropout_pp):
        '''Two-layer MLP that outputs (alpha, beta); shape: batch_size x 2.'''
        return nn.Sequential(
            nn.Linear(in_channels, 128),
            nn.ReLU(),
            nn.Linear(128, 2),
            nn.Dropout(p=dropout_pp),
        )

    def forward(self, x, tem_feat, spa_feat, t2i_feat, prmt):
        b_s = len(x)
        device = tem_feat.device
        # Prompts are laid out in this order along the text axis of the
        # similarity matrix; note it differs from the reporting order `phase`
        # (aesthetic and temporal are swapped).
        text_dims = ["overall", "spatial", "alignment", "aesthetic", "temporal"]
        weights = torch.arange(1, 6, device=device, dtype=torch.float32)  # bad=1 .. perfect=5
        pred_y = []
        for i in range(b_s):
            # 25 prompts per video: 5 quality levels x 5 dimensions.
            texts = [f"a photo with {s} {dim} quality, matching {prmt[i]}"
                     for dim in text_dims for s in qualitys]
            inputs = self.siglip_processor(
                text=texts,
                images=x[i],
                max_num_patches=576,
                return_tensors="pt",
                max_length=64,
                padding="max_length",
                truncation=True,  # truncate prompts that exceed max_length
            ).to(device)
            outputs = self.siglip_model(**inputs)
            logits_per_image = outputs.logits_per_image  # (num_frames, 25)
            scores = []
            for dim in phase:
                # 5-way logit block for this dimension, in prompt order.
                j = text_dims.index(dim)
                probs = torch.softmax(logits_per_image[:, j * 5:(j + 1) * 5], dim=1)
                # Expected quality level in [1, 5], averaged over frames.
                scores.append((probs * weights).sum(dim=1).mean().view(1))
            pred_y.append(torch.cat(scores))
        pred_y = torch.stack(pred_y)  # (b_s, 5)

        y_pred1 = pred_y[:, 0].unsqueeze(1)  # overall
        y_pred2 = pred_y[:, 1].unsqueeze(1)  # spatial
        y_pred3 = pred_y[:, 2].unsqueeze(1)  # alignment
        y_pred4 = pred_y[:, 3].unsqueeze(1)  # temporal
        y_pred5 = pred_y[:, 4].unsqueeze(1)  # aesthetic
        preds = (y_pred1, y_pred2, y_pred3, y_pred4, y_pred5)
        ones = torch.ones_like(y_pred1)
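        # Each rectifier head predicts a residual slope and an offset for a
        # first-order correction q = |1 + alpha| * y + beta, so an untrained
        # head (alpha ~ 0, beta ~ 0) leaves the SigLIP-based score roughly
        # unchanged.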
        def rectify(recs, feat):
            # Apply the five (alpha, beta) heads in `recs` to the flattened
            # feature and linearly rectify the corresponding base scores.
            feat = feat.view(feat.shape[0], -1)
            alphas, betas, rectified = [], [], []
            for rec, y in zip(recs, preds):
                alpha, beta = torch.chunk(rec(feat), 2, dim=1)  # a*x + b
                alpha = alpha + ones  # centre the slope at 1
                alphas.append(alpha)
                betas.append(beta)
                rectified.append((torch.abs(alpha) * y + beta).squeeze(1))
            return alphas, betas, rectified

        # spatial rectifier
        if self.sr:
            alphaS, betaS, qs_ys = rectify(
                (self.spatialRec1, self.spatialRec2, self.spatialRec3,
                 self.spatialRec4, self.spatialRec5), spa_feat)

        # temporal rectifier
        if self.tr:
            alphaT, betaT, qt_ys = rectify(
                (self.temporalRec1, self.temporalRec2, self.temporalRec3,
                 self.temporalRec4, self.temporalRec5), tem_feat)

        # alignment rectifier
        if self.ar:
            alphaA, betaA, qa_ys = rectify(
                (self.mixRec1, self.mixRec2, self.mixRec3,
                 self.mixRec4, self.mixRec5), t2i_feat)
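        # The fused scores below combine per-branch corrections by taking the
        # geometric mean of the slope magnitudes and the arithmetic mean of
        # the offsets, which keeps the composed correction linear in the base
        # score and on the same scale as a single branch.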
        if self.sr and self.tr and self.ar:
            qsta_ys, qta_ys = [], []
            for aS, bS, aT, bT, aA, bA, y in zip(
                    alphaS, betaS, alphaT, betaT, alphaA, betaA, preds):
                # spatial + temporal + alignment fusion
                sta_a = torch.pow(torch.abs(aS * aT * aA), 1 / 3)
                sta_b = (bS + bT + bA) / 3
                qsta_ys.append((sta_a * y + sta_b).squeeze(1))
                # temporal + alignment fusion
                ta_a = torch.pow(torch.abs(aT * aA), 1 / 2)
                ta_b = (bT + bA) / 2
                qta_ys.append((ta_a * y + ta_b).squeeze(1))

            y_finals = []
            for y, qt, qs, qa, qsta, qta in zip(
                    preds, qt_ys, qs_ys, qa_ys, qsta_ys, qta_ys):
                # Columns: raw, temporal-, spatial-, alignment-, STA-, TA-
                # rectified, plus the alignment-rectified score repeated to
                # keep the original seven-column layout.
                y_finals.append(torch.stack(
                    (y.squeeze(1), qt, qs, qa, qsta, qta, qa), dim=1))

            # Single-sample inference returns the scores on the CPU.
            if b_s == 1:
                y_finals = [y.to('cpu') for y in y_finals]
            # (b_s, 7 score variants, 5 quality dimensions)
            return torch.stack(y_finals, dim=2)
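
# Minimal smoke-test sketch (not part of the model itself). It assumes the
# default feat_len=8 / spatial='resnet' configuration, random uint8 frames,
# and dummy rectifier features whose flattened sizes match the input widths
# above (5*256*8 spatial, 256*8 temporal, 768*8 text-to-image). Running it
# downloads the SigLIP2 checkpoint.
if __name__ == "__main__":
    import numpy as np

    model = ViTbSiglip(feat_len=8).eval()
    b_s, n_frames = 1, 2
    frames = [np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
              for _ in range(n_frames)]       # one video = a list of frames
    x = [frames] * b_s
    tem_feat = torch.randn(b_s, 8, 256)       # SlowFast fast-pathway stand-in
    spa_feat = torch.randn(b_s, 8, 5 * 256)   # spatial (resnet) stand-in
    t2i_feat = torch.randn(b_s, 8, 768)       # text-to-image feature stand-in
    prmt = ["a cat playing piano"]            # one prompt per video
    with torch.no_grad():
        out = model(x, tem_feat, spa_feat, t2i_feat, prmt)
    print(out.shape)  # expected: torch.Size([1, 7, 5])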