# VQualA_GenAI_track2/modular/modular_model.py
import os
from itertools import product

import torch
import torch.nn as nn
from transformers import AutoModel, AutoProcessor

# Silence fork-related warnings from the HF fast tokenizers.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Five-level quality scale; a score is the expectation of these levels (1-5)
# under the softmax over the matching prompts.
qualitys = ['bad', 'poor', 'fair', 'good', 'perfect']
# Quality dimensions, in the order of the per-dimension scores in forward().
phase = ["overall", "spatial", "alignment", "temporal", "aesthetic"]
class ViTbSiglip(torch.nn.Module):
    """SigLIP2-based zero-shot quality scorer with learned rectifiers.

    Predicts five quality dimensions (overall, spatial, alignment, temporal,
    aesthetic) from frames and a text prompt, then rectifies each score with
    small regression heads driven by spatial, temporal, and text-to-image
    features.
    """

    def __init__(self, feat_len=8, dropout_sp=0.2, dropout_tp=0.2,
                 dropout_pp=0.2, use_ali=True, spatial='resnet'):
        super().__init__()
self.siglip_model = AutoModel.from_pretrained(
"google/siglip2-base-patch16-naflex")
self.siglip_processor = AutoProcessor.from_pretrained(
"google/siglip2-base-patch16-naflex")
for param_name, param in self.siglip_model.named_parameters():
if param_name in ['logit_scale', 'logit_bias']:
param.requires_grad = False
self.feat_len = feat_len
self.dropout_sp = dropout_sp
self.dropout_tp = dropout_tp
self.dropout_pp = dropout_pp
self.use_ali = use_ali
        self.tr = True  # apply the temporal rectifier
        self.sr = True  # apply the spatial rectifier
        self.ar = True  # apply the alignment rectifier
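        # Five rectifier heads per feature stream, one per quality dimension.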
if spatial == 'resnet':
self.spatialRec1 = self.spatial_rectifier(
5*256*self.feat_len, self.dropout_sp)
self.spatialRec2 = self.spatial_rectifier(
5*256*self.feat_len, self.dropout_sp)
self.spatialRec3 = self.spatial_rectifier(
5*256*self.feat_len, self.dropout_sp)
self.spatialRec4 = self.spatial_rectifier(
5*256*self.feat_len, self.dropout_sp)
self.spatialRec5 = self.spatial_rectifier(
5*256*self.feat_len, self.dropout_sp)
else:
self.spatialRec1 = self.spatial_rectifier(
1024*self.feat_len, self.dropout_sp)
self.spatialRec2 = self.spatial_rectifier(
1024*self.feat_len, self.dropout_sp)
self.spatialRec3 = self.spatial_rectifier(
1024*self.feat_len, self.dropout_sp)
self.spatialRec4 = self.spatial_rectifier(
1024*self.feat_len, self.dropout_sp)
self.spatialRec5 = self.spatial_rectifier(
1024*self.feat_len, self.dropout_sp)
self.temporalRec1 = self.temporal_rectifier(
(256)*self.feat_len, self.dropout_tp) # Fast:256 Slow:2048
self.temporalRec2 = self.temporal_rectifier(
(256)*self.feat_len, self.dropout_tp) # Fast:256 Slow:2048
self.temporalRec3 = self.temporal_rectifier(
(256)*self.feat_len, self.dropout_tp) # Fast:256 Slow:2048
self.temporalRec4 = self.temporal_rectifier(
(256)*self.feat_len, self.dropout_tp) # Fast:256 Slow:2048
self.temporalRec5 = self.temporal_rectifier(
(256)*self.feat_len, self.dropout_tp) # Fast:256 Slow:2048
self.mixRec1 = self.mix_rectifier(
768*self.feat_len, self.dropout_pp)
self.mixRec2 = self.mix_rectifier(
768*self.feat_len, self.dropout_pp)
self.mixRec3 = self.mix_rectifier(
768*self.feat_len, self.dropout_pp)
self.mixRec4 = self.mix_rectifier(
768*self.feat_len, self.dropout_pp)
self.mixRec5 = self.mix_rectifier(
768*self.feat_len, self.dropout_pp)
    def spatial_rectifier(self, in_channels, dropout_sp):
        '''
        Maps flattened spatial features to two values per sample, i.e. output
        shape (batch_size, 2): a scale offset and an additive shift.
        '''
regression_block = nn.Sequential(
nn.Linear(in_channels, 128),
nn.ReLU(),
nn.Linear(128, 2),
nn.Dropout(p=dropout_sp),
)
return regression_block
def temporal_rectifier(self, in_channels, dropout_tp):
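        '''
        Same structure as spatial_rectifier, applied to flattened temporal
        features.
        '''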
regression_block = nn.Sequential(
nn.Linear(in_channels, 128),
nn.ReLU(),
nn.Linear(128, 2),
nn.Dropout(p=dropout_tp),
)
return regression_block
def mix_rectifier(self, in_channels, dropout_pp):
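        '''
        Same structure as spatial_rectifier, applied to flattened
        text-to-image (alignment) features.
        '''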
regression_block = nn.Sequential(
nn.Linear(in_channels, 128),
nn.ReLU(),
nn.Linear(128, 2),
nn.Dropout(p=dropout_pp),
)
return regression_block
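    # Rectification used in forward(): each head outputs (da, b) and the
    # rectified score is q' = |1 + da| * q + b. Combined rectifiers merge
    # scales by geometric mean and shifts by arithmetic mean, e.g.
    # q_sta = |a_s * a_t * a_a|**(1/3) * q + (b_s + b_t + b_a) / 3.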
    def forward(self, x, tem_feat, spa_feat, t2i_feat, prmt):
        '''
        x: list (one entry per sample) of frame lists; tem_feat / spa_feat /
        t2i_feat: temporal, spatial, and text-to-image features, flattened per
        sample to match the rectifier input sizes; prmt: list of text prompts.
        Returns a (batch, 7, 5) tensor: 7 prediction variants x 5 dimensions.
        '''
        b_s = len(x)
        device = tem_feat.device
pred_y = []
for i in range(b_s):
texts1 = [f"a photo with {s} overall quality, matching {p}"
for s,p in product(qualitys, [prmt[i]])]
texts2 = [f"a photo with {s} spatial quality, matching {p}"
for s,p in product(qualitys, [prmt[i]])]
texts3 = [f"a photo with {s} alignment quality, matching {p}"
for s,p in product(qualitys, [prmt[i]])]
texts4 = [f"a photo with {s} aesthetic quality, matching {p}"
for s,p in product(qualitys, [prmt[i]])]
texts5 = [f"a photo with {s} temporal quality, matching {p}"
for s,p in product(qualitys, [prmt[i]])]
            texts = texts1 + texts2 + texts3 + texts4 + texts5  # 25 prompts; this order fixes the logit column layout below
imgs = x[i]
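            # Tokenize the 25 prompts and patchify the frames (NaFlex caps
            # the visual patch count via max_num_patches).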
inputs = self.siglip_processor(
text=texts,
images=imgs,
max_num_patches=576,
return_tensors="pt",
max_length=64,
padding="max_length",
truncation=True, # Ensures text is truncated if it exceeds max length
).to(device)
outputs = self.siglip_model(**inputs)
logits_per_image = outputs.logits_per_image
overall_logit = logits_per_image[:, 0:5]
spatial_logit = logits_per_image[:, 5:10]
alignment_logit = logits_per_image[:, 10:15]
aesthetic_logit = logits_per_image[:, 15:20]
temporal_logit = logits_per_image[:, 20:25]
            # Per-dimension expected score: E[q] = sum_k k * p_k over the five
            # quality levels, averaged across the processed frames.
            levels = torch.arange(1, 6, device=device, dtype=logits_per_image.dtype)
            dim_scores = []
            for logit in (overall_logit, spatial_logit, alignment_logit,
                          temporal_logit, aesthetic_logit):
                probs = torch.softmax(logit, dim=1)
                dim_scores.append((probs * levels).sum(dim=1).mean().view(1))
            y_pred1, y_pred2, y_pred3, y_pred4, y_pred5 = dim_scores
y_pred = [y_pred1, y_pred2, y_pred3, y_pred4, y_pred5]
y_pred = torch.stack(y_pred).unsqueeze(0)
pred_y.append(y_pred)
pred_y = torch.stack(pred_y).view(b_s,-1)
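        # pred_y: (b_s, 5) zero-shot scores, one column per quality dimension
        # in the order overall, spatial, alignment, temporal, aesthetic.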
y_pred1 = pred_y[:,0].unsqueeze(1)
y_pred2 = pred_y[:,1].unsqueeze(1)
y_pred3 = pred_y[:,2].unsqueeze(1)
y_pred4 = pred_y[:,3].unsqueeze(1)
y_pred5 = pred_y[:,4].unsqueeze(1)
        ones = torch.ones_like(y_pred1)
        preds = (y_pred1, y_pred2, y_pred3, y_pred4, y_pred5)
        # spatial rectifier: per dimension, q' = |1 + da| * q + b
        if self.sr:
            spa_feat = spa_feat.view(spa_feat.shape[0], -1)
            alphaS, betaS = [], []
            for rec in (self.spatialRec1, self.spatialRec2, self.spatialRec3,
                        self.spatialRec4, self.spatialRec5):
                a, b = torch.chunk(rec(spa_feat), 2, dim=1)
                alphaS.append(a + ones)  # scale = 1 + predicted offset
                betaS.append(b)
            qs_y1, qs_y2, qs_y3, qs_y4, qs_y5 = [
                (torch.abs(a) * y + b).squeeze(1)
                for a, b, y in zip(alphaS, betaS, preds)]
        # temporal rectifier
        if self.tr:
            tem_feat = tem_feat.view(tem_feat.shape[0], -1)
            alphaT, betaT = [], []
            for rec in (self.temporalRec1, self.temporalRec2, self.temporalRec3,
                        self.temporalRec4, self.temporalRec5):
                a, b = torch.chunk(rec(tem_feat), 2, dim=1)
                alphaT.append(a + ones)
                betaT.append(b)
            qt_y1, qt_y2, qt_y3, qt_y4, qt_y5 = [
                (torch.abs(a) * y + b).squeeze(1)
                for a, b, y in zip(alphaT, betaT, preds)]
        # alignment rectifier
        if self.ar:
            t2i_feat = t2i_feat.view(t2i_feat.shape[0], -1)
            alphaA, betaA = [], []
            for rec in (self.mixRec1, self.mixRec2, self.mixRec3,
                        self.mixRec4, self.mixRec5):
                a, b = torch.chunk(rec(t2i_feat), 2, dim=1)
                alphaA.append(a + ones)
                betaA.append(b)
            qa_y1, qa_y2, qa_y3, qa_y4, qa_y5 = [
                (torch.abs(a) * y + b).squeeze(1)
                for a, b, y in zip(alphaA, betaA, preds)]
        if self.sr and self.tr and self.ar:
            qsta, qta = [], []
            for aS, bS, aT, bT, aA, bA, y in zip(alphaS, betaS, alphaT, betaT,
                                                 alphaA, betaA, preds):
                # spatial+temporal+alignment: geometric mean of the scales,
                # arithmetic mean of the shifts
                sta_a = torch.pow(torch.abs(aS * aT * aA), 1 / 3)
                sta_b = (bS + bT + bA) / 3
                qsta.append((sta_a * y + sta_b).squeeze(1))
                # temporal+alignment counterpart
                ta_a = torch.pow(torch.abs(aT * aA), 1 / 2)
                ta_b = (bT + bA) / 2
                qta.append((ta_a * y + ta_b).squeeze(1))
            qsta_y1, qsta_y2, qsta_y3, qsta_y4, qsta_y5 = qsta
            qta_y1, qta_y2, qta_y3, qta_y4, qta_y5 = qta
        # Seven prediction variants per dimension: raw, temporal-, spatial-,
        # alignment-, sta-, ta-rectified, and the alignment-rectified score
        # repeated.
        y_final1 = torch.stack((y_pred1.squeeze(1), qt_y1, qs_y1, qa_y1, qsta_y1, qta_y1, qa_y1), dim=1)
        y_final2 = torch.stack((y_pred2.squeeze(1), qt_y2, qs_y2, qa_y2, qsta_y2, qta_y2, qa_y2), dim=1)
        y_final3 = torch.stack((y_pred3.squeeze(1), qt_y3, qs_y3, qa_y3, qsta_y3, qta_y3, qa_y3), dim=1)
        y_final4 = torch.stack((y_pred4.squeeze(1), qt_y4, qs_y4, qa_y4, qsta_y4, qta_y4, qa_y4), dim=1)
        y_final5 = torch.stack((y_pred5.squeeze(1), qt_y5, qs_y5, qa_y5, qsta_y5, qta_y5, qa_y5), dim=1)
        # For batch_size == 1, move the per-dimension vectors to CPU before
        # stacking.
        if b_s == 1:
            y_final1 = y_final1.squeeze(1).to('cpu')
            y_final2 = y_final2.squeeze(1).to('cpu')
            y_final3 = y_final3.squeeze(1).to('cpu')
            y_final4 = y_final4.squeeze(1).to('cpu')
            y_final5 = y_final5.squeeze(1).to('cpu')
        return torch.stack([y_final1, y_final2, y_final3, y_final4, y_final5],
                           dim=2)  # (b_s, 7, 5): 7 variants x 5 dimensions
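

if __name__ == "__main__":
    # Minimal smoke-test sketch. The feature shapes below are assumptions
    # inferred from the rectifier input sizes (5*256*feat_len, 256*feat_len,
    # 768*feat_len with the default spatial='resnet'); the real pipeline
    # extracts these features with upstream backbones, and the frame count
    # and prompt text here are arbitrary placeholders.
    from PIL import Image

    model = ViTbSiglip(feat_len=8).eval()
    b_s = 2
    # One list of frames per sample; two dummy RGB frames each.
    x = [[Image.new("RGB", (224, 224)) for _ in range(2)] for _ in range(b_s)]
    tem_feat = torch.randn(b_s, 8, 256)      # temporal features
    spa_feat = torch.randn(b_s, 8, 5 * 256)  # spatial (ResNet) features
    t2i_feat = torch.randn(b_s, 8, 768)      # text-to-image features
    prmt = ["a corgi surfing a wave"] * b_s
    with torch.no_grad():
        out = model(x, tem_feat, spa_feat, t2i_feat, prmt)
    print(out.shape)  # torch.Size([2, 7, 5])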