import numpy as np
import timm
import torch
import torchvision
import torch.nn as nn
from transformers import CLIPVisionModel
# import open_clip
import torchvision.transforms as transforms
from PIL import Image
import cv2
import accelerate


def load_clip_model(clip_model="openai/clip-vit-base-patch16", clip_freeze=True, precision='fp16'):
    """Load a CLIP vision backbone and return it together with its feature dimensions."""
    pretrained, model_tag = clip_model.split('/')
    pretrained = None if pretrained == 'None' else pretrained
    # Alternative loaders kept for reference; `precision` is only used by the open_clip path.
    # clip_model = open_clip.create_model(model_tag, precision=precision, pretrained=pretrained)
    # clip_model = timm.create_model('timm/vit_base_patch16_clip_224.openai', pretrained=True, in_chans=3)
    clip_model = CLIPVisionModel.from_pretrained(clip_model)
    if clip_freeze:
        for param in clip_model.parameters():
            param.requires_grad = False

    if model_tag == 'clip-vit-base-patch16':
        feature_size = dict(global_feature=768, local_feature=[196, 768])
    elif model_tag == 'ViT-L-14-quickgelu' or model_tag == 'ViT-L-14':
        feature_size = dict(global_feature=768, local_feature=[256, 1024])
    else:
        raise ValueError(f"Unknown model_tag: {model_tag}")

    return clip_model, feature_size
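
# A minimal usage sketch for load_clip_model (the checkpoint name below is an
# assumption matching the 'clip-vit-base-patch16' branch handled above):
#
#   backbone, feat = load_clip_model("openai/clip-vit-base-patch16", clip_freeze=True)
#   feat["global_feature"]   # 768-d pooled feature
#   feat["local_feature"]    # [196, 768] patch tokens for a 224x224 input
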
class DualBranch(nn.Module):
    def __init__(self, clip_model="openai/clip-vit-base-patch16", clip_freeze=True, precision='fp16'):
        super(DualBranch, self).__init__()
        self.clip_freeze = clip_freeze

        # Load CLIP model
        self.clip_model, feature_size = load_clip_model(clip_model, clip_freeze, precision)

        # Initialize CLIP vision model for task classification
        self.task_cls_clip = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch16")

        self.head = nn.Linear(feature_size['global_feature'] * 3, 1)
        self.compare_head = nn.Linear(feature_size['global_feature'] * 6, 3)
        self.prompt = nn.Parameter(torch.rand(1, feature_size['global_feature']))
        self.task_mlp = nn.Sequential(
            nn.Linear(feature_size['global_feature'], feature_size['global_feature']),
            nn.SiLU(),
            nn.Linear(feature_size['global_feature'], feature_size['global_feature']))
        self.prompt_mlp = nn.Linear(feature_size['global_feature'], feature_size['global_feature'])

        # Zero-initialize the task and prompt MLPs so the task branch starts with no contribution
        with torch.no_grad():
            self.task_mlp[0].weight.fill_(0.0)
            self.task_mlp[0].bias.fill_(0.0)
            self.task_mlp[2].weight.fill_(0.0)
            self.task_mlp[2].bias.fill_(0.0)
            self.prompt_mlp.weight.fill_(0.0)
            self.prompt_mlp.bias.fill_(0.0)

        # Load pre-trained weights
        self._load_pretrained_weights("./weights/Degradation.pth")

        for param in self.task_cls_clip.parameters():
            param.requires_grad = False
        # Unfreeze the last two encoder layers (layers 10 and 11)
        for i in range(10, 12):
            for param in self.task_cls_clip.vision_model.encoder.layers[i].parameters():
                param.requires_grad = True

    def _load_pretrained_weights(self, state_dict_path):
        """
        Load pre-trained weights, including the CLIP model and classification head.
        """
        # Load state dictionary
        state_dict = torch.load(state_dict_path)

        # Separate weights for the CLIP model and the classification head
        clip_state_dict = {}
        for key, value in state_dict.items():
            if key.startswith('clip_model.'):
                # Remove the 'clip_model.' prefix for the CLIP model
                new_key = key.replace('clip_model.', '')
                clip_state_dict[new_key] = value
            # elif key in ['head.weight', 'head.bias']:
            #     # Save weights for the classification head
            #     head_state_dict[key] = value

        # Load weights for the CLIP model
        self.task_cls_clip.load_state_dict(clip_state_dict, strict=False)
        print("Successfully loaded CLIP model weights")

    def forward(self, x0, x1=None):
        if x1 is None:
            # Image features
            features0 = self.clip_model(x0)['pooler_output']
            # Classification (task) features
            task_features0 = self.task_cls_clip(x0)['pooler_output']
            # Learn classification features
            task_embedding = torch.softmax(self.task_mlp(task_features0), dim=1) * self.prompt
            task_embedding = self.prompt_mlp(task_embedding)

            features0 = torch.cat([features0, task_embedding, features0 + task_embedding], dim=1)
            quality = self.head(features0)
            quality = nn.Sigmoid()(quality)
            return quality, None, None
        else:
            # Image features
            features0 = self.clip_model(x0)['pooler_output']
            features1 = self.clip_model(x1)['pooler_output']
            # Classification (task) features
            task_features0 = self.task_cls_clip(x0)['pooler_output']
            task_features1 = self.task_cls_clip(x1)['pooler_output']

            task_embedding0 = torch.softmax(self.task_mlp(task_features0), dim=1) * self.prompt
            task_embedding0 = self.prompt_mlp(task_embedding0)
            task_embedding1 = torch.softmax(self.task_mlp(task_features1), dim=1) * self.prompt
            task_embedding1 = self.prompt_mlp(task_embedding1)

            features0 = torch.cat([features0, task_embedding0, features0 + task_embedding0], dim=1)
            features1 = torch.cat([features1, task_embedding1, features1 + task_embedding1], dim=1)

            # Concatenate both branches for the 3-way comparison head
            features = torch.cat([features0, features1], dim=1)
            compare_quality = self.compare_head(features)

            quality0 = self.head(features0)
            quality1 = self.head(features1)
            quality0 = nn.Sigmoid()(quality0)
            quality1 = nn.Sigmoid()(quality1)
            return quality0, quality1, compare_quality
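
# Usage sketch for DualBranch (assumes 224x224 CLIP-normalized input tensors and that
# the ./weights/Degradation.pth checkpoint referenced above is available locally):
#
#   model = DualBranch(clip_model="openai/clip-vit-base-patch16", clip_freeze=True)
#   q, _, _ = model(x0)           # single image: sigmoid quality score, shape [B, 1]
#   q0, q1, cmp = model(x0, x1)   # image pair: two scores plus 3-way comparison logits [B, 3]
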
class FGResQ:
    def __init__(self, model_path, clip_model="openai/clip-vit-base-patch16", input_size=224, device=None):
        """
        Initializes the inference model.

        Args:
            model_path (str): Path to the pre-trained model checkpoint (.pth or .safetensors).
            clip_model (str): Name of the CLIP model to use.
            input_size (int): Input image size for the model.
            device (str, optional): Device to run inference on ('cuda' or 'cpu'). Auto-detected if None.
        """
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        print(f"Using device: {self.device}")

        # Load the model
        self.model = DualBranch(clip_model=clip_model, clip_freeze=True, precision='fp32')
        # self.model = self.accelerator.unwrap_model(self.model)

        # Load model weights
        try:
            raw = torch.load(model_path, map_location=self.device)
            # Unwrap possible containers
            if isinstance(raw, dict) and any(k in raw for k in ['model', 'state_dict']):
                state_dict = raw.get('model', raw.get('state_dict', raw))
            else:
                state_dict = raw
            # Only strip 'module.' if present; keep other namespaces intact
            if any(k.startswith('module.') for k in state_dict.keys()):
                state_dict = {k.replace('module.', '', 1): v for k, v in state_dict.items()}
            missing, unexpected = self.model.load_state_dict(state_dict, strict=False)
            if missing:
                print(f"[load_state_dict] Missing keys: {missing}")
            if unexpected:
                print(f"[load_state_dict] Unexpected keys: {unexpected}")
            print(f"Model weights loaded from {model_path}")
        except Exception as e:
            print(f"Error loading model weights: {e}")
            raise

        self.model.to(self.device)
        self.model.eval()

        # Define image preprocessing
        # Match training/validation pipeline: first unify to 256x256 (as in cls_model/dataset.py),
        # then CenterCrop to input_size, followed by CLIP normalization.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.CenterCrop(input_size),
            transforms.Normalize(
                mean=[0.48145466, 0.4578275, 0.40821073],
                std=[0.26862954, 0.26130258, 0.27577711]
            )
        ])

    def _preprocess_image(self, image_path):
        """Load and preprocess a single image."""
        try:
            # Match training dataset loader: cv2 read + resize to 256x256 (INTER_LINEAR)
            img = cv2.imread(image_path)
            if img is None:
                raise FileNotFoundError(f"Failed to read image at {image_path}")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_LINEAR)
            image = Image.fromarray(img)
            image_tensor = self.transform(image).unsqueeze(0)
            return image_tensor.to(self.device)
        except FileNotFoundError:
            print(f"Error: Image file not found at {image_path}")
            return None
        except Exception as e:
            print(f"Error processing image {image_path}: {e}")
            return None

    @torch.no_grad()
    def predict_single(self, image_path):
        """
        Predict the quality score of a single image.
        """
        image_tensor = self._preprocess_image(image_path)
        if image_tensor is None:
            return None
        quality_score, _, _ = self.model(image_tensor)
        return quality_score.squeeze().item()

    @torch.no_grad()
    def predict_pair(self, image_path1, image_path2):
        """
        Compare the quality of two images.
        """
        image_tensor1 = self._preprocess_image(image_path1)
        image_tensor2 = self._preprocess_image(image_path2)
        if image_tensor1 is None or image_tensor2 is None:
            return None

        quality1, quality2, compare_result = self.model(image_tensor1, image_tensor2)
        quality1 = quality1.squeeze().item()
        quality2 = quality2.squeeze().item()

        # Interpret the comparison result
        compare_probs = torch.softmax(compare_result, dim=-1).squeeze(dim=0).cpu().numpy()
        prediction = np.argmax(compare_probs)

        # Align with training label semantics:
        # the dataset encodes preferences as A>B -> 1, A<B -> 0, equal -> 2,
        # so class 1 => Image 1 (A) is better and class 0 => Image 2 (B) is better.
        comparison_map = {0: 'Image 2 is better', 1: 'Image 1 is better', 2: 'Images are of similar quality'}
        return {
            'comparison': comparison_map[prediction],
            'comparison_raw': compare_probs.tolist(),
            'quality1': quality1,
            'quality2': quality2
        }
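

if __name__ == "__main__":
    # Minimal inference sketch. The checkpoint and image paths below are placeholders
    # (not shipped with this module); point them at your own files before running.
    scorer = FGResQ(model_path="./weights/FGResQ.pth",
                    clip_model="openai/clip-vit-base-patch16",
                    input_size=224)

    # Single-image quality score in (0, 1).
    score = scorer.predict_single("example_a.png")
    print(f"Quality score: {score}")

    # Pairwise comparison: which of the two images has better quality.
    result = scorer.predict_pair("example_a.png", "example_b.png")
    print(result)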