File size: 11,405 Bytes
383998a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 | import numpy as np
import timm
import torch
import torchvision
import torch.nn as nn
from transformers import CLIPVisionModel
# import open_clip
import torchvision.transforms as transforms
from PIL import Image
import cv2
import accelerate
def load_clip_model(clip_model="openai/ViT-B-16", clip_freeze=True, precision='fp16'):
    """Load a CLIP vision encoder and look up the feature sizes for its tag.

    Args:
        clip_model (str): Identifier handed to
            ``CLIPVisionModel.from_pretrained``. The text after the last
            ``/`` is used as the model tag for the feature-size lookup, so
            both ``"org/name"`` identifiers and longer paths are accepted.
        clip_freeze (bool): When true, ``requires_grad`` is set to ``False``
            on every parameter of the loaded model.
        precision (str): Not used by this function; kept so existing callers
            that pass it keep working.

    Returns:
        tuple: ``(model, feature_size)`` where ``feature_size`` is a dict
        with the keys ``global_feature`` and ``local_feature``.

    Raises:
        ValueError: If the model tag has no entry in the feature-size table.
            Note that the tag of the default argument (``ViT-B-16``) is not
            in the table, so callers must pass a supported identifier.
    """
    # Only the last path component identifies the architecture.
    model_tag = clip_model.rstrip('/').rpartition('/')[2]

    # Validate the tag before loading so an unsupported model fails fast
    # instead of after the weights have been fetched and instantiated.
    if model_tag == 'clip-vit-base-patch16':
        feature_size = dict(global_feature=768, local_feature=[196, 768])
    elif model_tag in ('ViT-L-14-quickgelu', 'ViT-L-14'):
        # NOTE(review): global_feature=768 alongside a 1024-wide local feature
        # is kept from the original table; confirm it matches the width of
        # `pooler_output` for the checkpoint actually loaded.
        feature_size = dict(global_feature=768, local_feature=[256, 1024])
    else:
        raise ValueError(f"Unknown model_tag: {model_tag}")

    model = CLIPVisionModel.from_pretrained(clip_model)
    if clip_freeze:
        for param in model.parameters():
            param.requires_grad = False
    return model, feature_size
class DualBranch(nn.Module):
    """Quality model built from two CLIP vision encoders.

    ``clip_model`` supplies image features; ``task_cls_clip`` supplies what
    the original comments call classification (task) features. The task
    features are turned into a prompt-weighted embedding, combined with the
    image features, and fed to a scoring head and a 3-way comparison head.
    """

    def __init__(self, clip_model="openai/clip-vit-base-patch16", clip_freeze=True,
                 precision='fp16', task_weights_path="./weights/Degradation.pth"):
        """Build the encoders and heads and load the task-encoder weights.

        Args:
            clip_model (str): Identifier forwarded to ``load_clip_model``.
            clip_freeze (bool): Forwarded to ``load_clip_model``; freezes the
                image encoder when true.
            precision (str): Forwarded to ``load_clip_model``.
            task_weights_path (str | None): Checkpoint whose ``clip_model.*``
                entries are loaded into ``task_cls_clip``. Pass ``None`` to
                skip this step (for example when a full checkpoint is loaded
                afterwards).
        """
        super().__init__()
        self.clip_freeze = clip_freeze

        # Image-feature encoder.
        self.clip_model, feature_size = load_clip_model(clip_model, clip_freeze, precision)
        dim = feature_size['global_feature']

        # Task-classification encoder.
        self.task_cls_clip = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch16")

        # Each image is described by three `dim`-wide vectors concatenated
        # together; the comparison head sees two such descriptors.
        self.head = nn.Linear(dim * 3, 1)
        self.compare_head = nn.Linear(dim * 6, 3)

        self.prompt = nn.Parameter(torch.rand(1, dim))
        self.task_mlp = nn.Sequential(
            nn.Linear(dim, dim),
            nn.SiLU(False),
            nn.Linear(dim, dim),
        )
        self.prompt_mlp = nn.Linear(dim, dim)

        # Zero-initialise the task branch so it contributes nothing at start.
        for layer in (self.task_mlp[0], self.task_mlp[2], self.prompt_mlp):
            nn.init.zeros_(layer.weight)
            nn.init.zeros_(layer.bias)

        if task_weights_path is not None:
            self._load_pretrained_weights(task_weights_path)

        # Freeze the task encoder except for its last two transformer layers.
        for param in self.task_cls_clip.parameters():
            param.requires_grad = False
        for layer in self.task_cls_clip.vision_model.encoder.layers[-2:]:
            for param in layer.parameters():
                param.requires_grad = True

    def _load_pretrained_weights(self, state_dict_path):
        """Load the ``clip_model.*`` entries of a checkpoint into ``task_cls_clip``.

        Args:
            state_dict_path (str): Path to a checkpoint readable by
                ``torch.load``.
        """
        # NOTE: torch.load unpickles its input; only use trusted checkpoints.
        # map_location='cpu' lets a checkpoint saved from a GPU load on a
        # CPU-only host; the module is moved to its device by the caller.
        state_dict = torch.load(state_dict_path, map_location='cpu')

        # Keep only the encoder weights and strip the leading prefix (and
        # only the leading one) so the keys match `task_cls_clip`.
        prefix = 'clip_model.'
        clip_state_dict = {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }

        if not clip_state_dict:
            print(f"Warning: no '{prefix}*' weights found in {state_dict_path}; "
                  "task classifier keeps its current weights")
            return

        self.task_cls_clip.load_state_dict(clip_state_dict, strict=False)
        print("Successfully loaded CLIP model weights")

    def _describe(self, x):
        """Return the concatenated descriptor for one image batch."""
        # Image features.
        features = self.clip_model(x)['pooler_output']
        # Classification features.
        task_features = self.task_cls_clip(x)['pooler_output']
        # Softmax over the feature axis acts as weights on the learned prompt.
        task_embedding = torch.softmax(self.task_mlp(task_features), dim=1) * self.prompt
        task_embedding = self.prompt_mlp(task_embedding)
        return torch.cat([features, task_embedding, features + task_embedding], dim=1)

    def forward(self, x0, x1=None):
        """Score one image batch, or score and compare two.

        Args:
            x0: Input batch accepted by the CLIP vision encoders.
            x1: Optional second batch to compare against ``x0``.

        Returns:
            tuple: ``(quality0, quality1, compare_quality)``. The quality
            values are sigmoid outputs of the scoring head;
            ``compare_quality`` holds the 3 raw outputs of the comparison
            head. When ``x1`` is ``None`` the last two items are ``None``.
        """
        features0 = self._describe(x0)
        quality0 = torch.sigmoid(self.head(features0))
        if x1 is None:
            return quality0, None, None

        features1 = self._describe(x1)
        quality1 = torch.sigmoid(self.head(features1))
        compare_quality = self.compare_head(torch.cat([features0, features1], dim=1))
        return quality0, quality1, compare_quality
class FGResQ:
    """Inference wrapper around ``DualBranch``.

    Loads a checkpoint, preprocesses images from disk, and exposes
    single-image scoring and pairwise comparison.
    """

    def __init__(self, model_path, clip_model="openai/clip-vit-base-patch16", input_size=224, device=None):
        """
        Initializes the inference model.

        Args:
            model_path (str): Path to a model checkpoint readable by
                ``torch.load``. (Other formats such as ``.safetensors`` are
                not handled by this loader.)
            clip_model (str): Name of the CLIP model to use.
            input_size (int): Side length of the center crop fed to the model.
            device (str, optional): Device to run inference on ('cuda' or
                'cpu'). Auto-detected if None.

        Raises:
            Exception: Any error raised while reading or applying the
                checkpoint is printed and re-raised unchanged.
        """
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        print(f"Using device: {self.device}")

        # Load the model
        self.model = DualBranch(clip_model=clip_model, clip_freeze=True, precision='fp32')

        # Load model weights
        try:
            # NOTE: torch.load unpickles its input; only use trusted checkpoints.
            raw = torch.load(model_path, map_location=self.device)
            state_dict = self._extract_state_dict(raw)
            missing, unexpected = self.model.load_state_dict(state_dict, strict=False)
            if missing:
                print(f"[load_state_dict] Missing keys: {missing}")
            if unexpected:
                print(f"[load_state_dict] Unexpected keys: {unexpected}")
            print(f"Model weights loaded from {model_path}")
        except Exception as e:
            print(f"Error loading model weights: {e}")
            raise

        self.model.to(self.device)
        self.model.eval()

        # Image preprocessing applied after the 256x256 resize done in
        # `_preprocess_image`: center crop to `input_size`, then normalize
        # with the mean/std constants below. Per the original author's note
        # this mirrors the training/validation pipeline.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.CenterCrop(input_size),
            transforms.Normalize(
                mean=[0.48145466, 0.4578275, 0.40821073],
                std=[0.26862954, 0.26130258, 0.27577711]
            )
        ])

    @staticmethod
    def _extract_state_dict(raw):
        """Unwrap a checkpoint container and drop a leading ``module.`` prefix.

        Args:
            raw: Object returned by ``torch.load``: either a state dict or a
                dict holding one under ``'model'`` or ``'state_dict'``
                (``'model'`` takes precedence).

        Returns:
            dict: State dict whose keys no longer start with ``module.``.
        """
        state_dict = raw
        if isinstance(raw, dict):
            for container_key in ('model', 'state_dict'):
                if container_key in raw:
                    state_dict = raw[container_key]
                    break

        # Strip only a *leading* 'module.'; keys that merely contain the
        # substring elsewhere keep their namespace intact.
        prefix = 'module.'
        if any(key.startswith(prefix) for key in state_dict):
            state_dict = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()
            }
        return state_dict

    def _preprocess_image(self, image_path):
        """Load and preprocess a single image.

        Args:
            image_path (str): Path of the image file to read.

        Returns:
            The preprocessed tensor with a leading batch axis, moved to
            ``self.device``, or ``None`` if the image could not be read or
            processed (the error is printed).
        """
        try:
            # cv2.imread signals failure by returning None, so check before
            # converting; otherwise cvtColor raises and hides the real cause.
            bgr = cv2.imread(image_path)
            if bgr is None:
                raise FileNotFoundError(f"Failed to read image at {image_path}")
            img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            # Resize to 256x256 (INTER_LINEAR) before the crop in `self.transform`.
            img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_LINEAR)
            image = Image.fromarray(img)
            image_tensor = self.transform(image).unsqueeze(0)
            return image_tensor.to(self.device)
        except FileNotFoundError:
            print(f"Error: Image file not found at {image_path}")
            return None
        except Exception as e:
            print(f"Error processing image {image_path}: {e}")
            return None

    @torch.no_grad()
    def predict_single(self, image_path):
        """
        Predict the quality score of a single image.

        Returns:
            float | None: The model's score, or ``None`` if the image could
            not be preprocessed.
        """
        image_tensor = self._preprocess_image(image_path)
        if image_tensor is None:
            return None
        quality_score, _, _ = self.model(image_tensor)
        return quality_score.squeeze().item()

    @torch.no_grad()
    def predict_pair(self, image_path1, image_path2):
        """
        Compare the quality of two images.

        Returns:
            dict | None: ``{'comparison': <verdict string>,
            'comparison_raw': <list of 3 softmax probabilities>}``, or
            ``None`` if either image could not be preprocessed.
        """
        image_tensor1 = self._preprocess_image(image_path1)
        image_tensor2 = self._preprocess_image(image_path2)
        if image_tensor1 is None or image_tensor2 is None:
            return None

        _, _, compare_result = self.model(image_tensor1, image_tensor2)

        # Turn the comparison outputs into probabilities and pick the top class.
        compare_probs = torch.softmax(compare_result, dim=-1).squeeze(dim=0).cpu().numpy()
        prediction = int(np.argmax(compare_probs))

        # Label semantics as stated by the original author:
        # dataset encodes prefs: A>B -> 1, A<B -> 0, equal -> 2
        # So class 1 => Image 1 (A) is better, class 0 => Image 2 (B) is better
        comparison_map = {0: 'Image 2 is better', 1: 'Image 1 is better', 2: 'Images are of similar quality'}
        return {
            'comparison': comparison_map[prediction],
            'comparison_raw': compare_probs.tolist()}
|