orpheus0429 committed on
Commit
383998a
·
verified ·
1 Parent(s): aed57ce

Upload 5 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ FGResQ.png filter=lfs diff=lfs merge=lfs -text
FGResQ.png ADDED

Git LFS Details

  • SHA256: 4da25007e721f0197bec5591497f73a154fc5f8872e618a027f869c447bfb446
  • Pointer size: 132 Bytes
  • Size of remote file: 3.2 MB
model/FGResQ.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import timm
3
+ import torch
4
+ import torchvision
5
+ import torch.nn as nn
6
+ from transformers import CLIPVisionModel
7
+ # import open_clip
8
+ import torchvision.transforms as transforms
9
+ from PIL import Image
10
+ import cv2
11
+ import accelerate
12
+
13
def load_clip_model(clip_model="openai/clip-vit-base-patch16", clip_freeze=True, precision='fp16'):
    """Load a HuggingFace CLIP vision backbone and report its feature sizes.

    The default was changed from "openai/ViT-B-16" (whose tag fell through to
    the ValueError branch below and is not a valid HF repo id for
    CLIPVisionModel) to the tag this function actually supports, matching the
    default used by DualBranch.

    Args:
        clip_model (str): HF model id of the form "<org>/<model_tag>".
        clip_freeze (bool): Freeze all backbone parameters when True.
        precision (str): Kept for API compatibility; unused with the
            transformers loader (it applied to the commented-out open_clip path).

    Returns:
        (model, feature_size): the CLIPVisionModel and a dict with
        'global_feature' (pooled dim) and 'local_feature' ([tokens, dim]).

    Raises:
        ValueError: If the model tag is not one of the known configurations.
    """
    pretrained, model_tag = clip_model.split('/')
    # NOTE(review): `pretrained` is unused below (from_pretrained receives the
    # full id); kept for compatibility with the commented-out open_clip path.
    pretrained = None if pretrained == 'None' else pretrained
    clip_model = CLIPVisionModel.from_pretrained(clip_model)
    if clip_freeze:
        for param in clip_model.parameters():
            param.requires_grad = False

    if model_tag == 'clip-vit-base-patch16':
        feature_size = dict(global_feature=768, local_feature=[196, 768])
    elif model_tag in ('ViT-L-14-quickgelu', 'ViT-L-14'):
        feature_size = dict(global_feature=768, local_feature=[256, 1024])
    else:
        raise ValueError(f"Unknown model_tag: {model_tag}")

    return clip_model, feature_size
31
+
32
class DualBranch(nn.Module):
    """Dual-branch image quality model.

    Branch 1 (image branch) is a CLIP vision encoder producing a pooled
    global feature. Branch 2 (task branch) is a second CLIP vision encoder,
    initialized from degradation-classification weights, whose pooled feature
    gates a learned prompt vector. The fused per-image feature is
    [f, e, f + e] where f is the image feature and e the prompt embedding.

    Outputs:
        forward(x0)     -> (quality in [0, 1], None, None)
        forward(x0, x1) -> (quality0, quality1, 3-class comparison logits)
    """

    def __init__(self, clip_model="openai/clip-vit-base-patch16", clip_freeze=True,
                 precision='fp16', degradation_weights="./weights/Degradation.pth"):
        """
        Args:
            clip_model (str): HF id "<org>/<tag>" for the image branch.
            clip_freeze (bool): Freeze all image-branch parameters when True.
            precision (str): Forwarded to load_clip_model (API compatibility).
            degradation_weights (str): Checkpoint path for the task branch.
                The default preserves the previously hard-coded location.
        """
        super(DualBranch, self).__init__()
        self.clip_freeze = clip_freeze

        # Image branch (frozen when clip_freeze is True).
        self.clip_model, feature_size = load_clip_model(clip_model, clip_freeze, precision)

        # Task branch used for degradation/task classification features.
        self.task_cls_clip = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch16")

        dim = feature_size['global_feature']
        # Fused feature is 3*dim per image; the pair head sees both images.
        self.head = nn.Linear(dim * 3, 1)
        self.compare_head = nn.Linear(dim * 6, 3)

        # Learnable prompt, gated by a softmax over the mapped task feature.
        self.prompt = nn.Parameter(torch.rand(1, dim))
        self.task_mlp = nn.Sequential(
            nn.Linear(dim, dim),
            nn.SiLU(inplace=False),
            nn.Linear(dim, dim))
        self.prompt_mlp = nn.Linear(dim, dim)

        # Zero-init: softmax of a zero vector is uniform and prompt_mlp
        # outputs zeros, so the prompt pathway starts neutral before training.
        with torch.no_grad():
            self.task_mlp[0].weight.fill_(0.0)
            self.task_mlp[0].bias.fill_(0.0)
            self.task_mlp[2].weight.fill_(0.0)
            self.task_mlp[2].bias.fill_(0.0)
            self.prompt_mlp.weight.fill_(0.0)
            self.prompt_mlp.bias.fill_(0.0)

        # Initialize the task branch from degradation-classification weights.
        self._load_pretrained_weights(degradation_weights)

        # Freeze the task branch except its last two transformer layers.
        for param in self.task_cls_clip.parameters():
            param.requires_grad = False
        for i in range(10, 12):  # layers 10 and 11
            for param in self.task_cls_clip.vision_model.encoder.layers[i].parameters():
                param.requires_grad = True

    def _load_pretrained_weights(self, state_dict_path):
        """Load task-branch weights from a checkpoint whose keys carry a
        'clip_model.' prefix.

        Args:
            state_dict_path (str): Path to the .pth state dict.
        """
        # BUGFIX: map_location='cpu' keeps loading robust on CPU-only hosts
        # when the checkpoint was saved from CUDA tensors; load_state_dict
        # copies tensors onto the module's own devices afterwards.
        state_dict = torch.load(state_dict_path, map_location='cpu')

        # Keep only 'clip_model.*' entries, stripping the prefix.
        clip_state_dict = {
            key.replace('clip_model.', '', 1): value
            for key, value in state_dict.items()
            if key.startswith('clip_model.')
        }

        # strict=False: the checkpoint may omit heads or carry extra keys.
        self.task_cls_clip.load_state_dict(clip_state_dict, strict=False)
        print("Successfully loaded CLIP model weights")

    def _fused_features(self, x):
        """Return the fused feature [f, e, f + e] for one image batch, where
        f is the image-branch pooled feature and e the prompt embedding
        derived from the task branch."""
        feats = self.clip_model(x)['pooler_output']
        task_feats = self.task_cls_clip(x)['pooler_output']
        # Softmax over the mapped task feature gates the learned prompt.
        emb = torch.softmax(self.task_mlp(task_feats), dim=1) * self.prompt
        emb = self.prompt_mlp(emb)
        return torch.cat([feats, emb, feats + emb], dim=1)

    def forward(self, x0, x1=None):
        """Score one image, or score and compare two.

        Args:
            x0: Preprocessed image batch.
            x1: Optional second image batch for pairwise comparison.

        Returns:
            (quality, None, None) for a single image, or
            (quality0, quality1, compare_logits) for a pair. Quality scores
            are sigmoid-squashed to [0, 1]; compare_logits are raw 3-class
            logits (class 1: x0 better, class 0: x1 better, class 2: similar).
        """
        if x1 is None:
            fused0 = self._fused_features(x0)
            quality = torch.sigmoid(self.head(fused0))
            return quality, None, None

        fused0 = self._fused_features(x0)
        fused1 = self._fused_features(x1)

        compare_quality = self.compare_head(torch.cat([fused0, fused1], dim=1))

        quality0 = torch.sigmoid(self.head(fused0))
        quality1 = torch.sigmoid(self.head(fused1))
        return quality0, quality1, compare_quality
149
+
150
class FGResQ:
    """Inference wrapper around DualBranch: single-image quality scoring and
    pairwise quality comparison from image file paths."""

    def __init__(self, model_path, clip_model="openai/clip-vit-base-patch16", input_size=224, device=None):
        """
        Initializes the inference model.

        Args:
            model_path (str): Path to the pre-trained model checkpoint (.pth or .safetensors).
            clip_model (str): Name of the CLIP model to use.
            input_size (int): Input image size for the model.
            device (str, optional): Device to run inference on ('cuda' or 'cpu'). Auto-detected if None.
        """
        self.device = device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        self.model = DualBranch(clip_model=clip_model, clip_freeze=True, precision='fp32')

        try:
            raw = torch.load(model_path, map_location=self.device)
            # Unwrap common checkpoint containers ({'model': ...} / {'state_dict': ...}).
            if isinstance(raw, dict) and any(k in raw for k in ['model', 'state_dict']):
                state_dict = raw.get('model', raw.get('state_dict', raw))
            else:
                state_dict = raw

            # Strip a DataParallel/DDP 'module.' prefix only when present;
            # keep other namespaces intact.
            if any(k.startswith('module.') for k in state_dict.keys()):
                state_dict = {k.replace('module.', '', 1): v for k, v in state_dict.items()}

            missing, unexpected = self.model.load_state_dict(state_dict, strict=False)
            if missing:
                print(f"[load_state_dict] Missing keys: {missing}")
            if unexpected:
                print(f"[load_state_dict] Unexpected keys: {unexpected}")
            print(f"Model weights loaded from {model_path}")
        except Exception as e:
            print(f"Error loading model weights: {e}")
            raise

        self.model.to(self.device)
        self.model.eval()

        # Match the training/validation pipeline: images are first unified to
        # 256x256 in _preprocess_image (mirroring the dataset loader), then
        # center-cropped to input_size and CLIP-normalized here.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.CenterCrop(input_size),
            transforms.Normalize(
                mean=[0.48145466, 0.4578275, 0.40821073],
                std=[0.26862954, 0.26130258, 0.27577711]
            )
        ])

    def _preprocess_image(self, image_path):
        """Load and preprocess a single image.

        Returns:
            A (1, C, H, W) float tensor on self.device, or None on failure.
        """
        try:
            # BUGFIX: check imread's result BEFORE converting color spaces.
            # cv2.imread returns None for a missing/unreadable file, and the
            # original code passed that None to cvtColor first, raising
            # cv2.error instead of the intended FileNotFoundError.
            img = cv2.imread(image_path)
            if img is None:
                raise FileNotFoundError(f"Failed to read image at {image_path}")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # Match the training dataset loader: resize to 256x256 (INTER_LINEAR).
            img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_LINEAR)
            image_tensor = self.transform(Image.fromarray(img)).unsqueeze(0)
            return image_tensor.to(self.device)
        except FileNotFoundError:
            print(f"Error: Image file not found at {image_path}")
            return None
        except Exception as e:
            print(f"Error processing image {image_path}: {e}")
            return None

    @torch.no_grad()
    def predict_single(self, image_path):
        """Predict the quality score of a single image.

        Returns:
            float in [0, 1], or None if the image could not be loaded.
        """
        image_tensor = self._preprocess_image(image_path)
        if image_tensor is None:
            return None

        quality_score, _, _ = self.model(image_tensor)
        return quality_score.squeeze().item()

    @torch.no_grad()
    def predict_pair(self, image_path1, image_path2):
        """Compare the quality of two images.

        Returns:
            dict with 'comparison' (human-readable verdict) and
            'comparison_raw' (softmax probabilities over the 3 classes),
            or None if either image could not be loaded.
        """
        image_tensor1 = self._preprocess_image(image_path1)
        image_tensor2 = self._preprocess_image(image_path2)
        if image_tensor1 is None or image_tensor2 is None:
            return None

        # Per-image quality heads are unused here; only the comparison
        # logits feed the returned verdict.
        _quality1, _quality2, compare_result = self.model(image_tensor1, image_tensor2)

        compare_probs = torch.softmax(compare_result, dim=-1).squeeze(dim=0).cpu().numpy()
        prediction = np.argmax(compare_probs)

        # Align with training label semantics:
        # dataset encodes prefs: A>B -> 1, A<B -> 0, equal -> 2,
        # so class 1 => Image 1 (A) is better, class 0 => Image 2 (B) is better.
        comparison_map = {0: 'Image 2 is better', 1: 'Image 1 is better', 2: 'Images are of similar quality'}

        return {
            'comparison': comparison_map[prediction],
            'comparison_raw': compare_probs.tolist()}
269
+
270
+
271
+
model/__pycache__/FGResQ.cpython-38.pyc ADDED
Binary file (7.23 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.28.0
2
+ certifi==2025.7.14
3
+ charset-normalizer==3.4.2
4
+ contourpy==1.1.1
5
+ cycler==0.12.1
6
+ einops==0.8.1
7
+ filelock==3.16.1
8
+ fonttools==4.57.0
9
+ fsspec==2025.3.0
10
+ hf-xet==1.1.5
11
+ huggingface-hub==0.33.4
12
+ idna==3.10
13
+ importlib_resources==6.4.5
14
+ kiwisolver==1.4.7
15
+ matplotlib==3.7.5
16
+ mpmath==1.3.0
17
+ numpy==1.24.4
18
+ opencv-python==4.8.1.78
19
+ packaging==25.0
20
+ pandas==2.0.3
21
+ pillow==10.4.0
22
+ prefetch-generator==1.0.3
23
+ psutil==7.0.0
24
+ pyparsing==3.1.4
25
+ python-dateutil==2.9.0.post0
26
+ pytz==2025.2
27
+ PyYAML==6.0.2
28
+ regex==2024.11.6
29
+ requests==2.32.4
30
+ safetensors==0.5.3
31
+ scipy==1.10.1
32
+ seaborn==0.13.2
33
+ six==1.17.0
34
+ sympy==1.12
35
+ timm==1.0.17
36
+ tokenizers==0.15.2
37
+ torch==1.13.0+cu117
38
+ torch-ema==0.3
39
+ torchaudio==0.13.0+cu117
40
+ torchvision==0.14.0+cu117
41
+ tqdm==4.67.1
42
+ transformers==4.36.1
43
+ typing_extensions==4.13.2
44
+ tzdata==2025.2
45
+ urllib3==2.2.3
46
+ zipp==3.20.2
weights/.gitkeep ADDED
File without changes