orpheus0429 committed on
Commit
0ed1cf6
·
verified ·
1 Parent(s): f392037

Upload 5 files

README.md CHANGED
@@ -1,14 +1,136 @@
- ---
- title: FGResQ
- emoji: 🐢
- colorFrom: purple
- colorTo: indigo
- sdk: gradio
- sdk_version: 5.49.1
- app_file: app.py
- pinned: false
- license: apache-2.0
- short_description: FGResQ office demo
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ <div align="center">
+ <a href="https://arxiv.org/abs/2508.14475"><img src="https://img.shields.io/badge/Arxiv-preprint-red"></a>
+ <a href="https://pxf0429.github.io/FGResQ/"><img src="https://img.shields.io/badge/Homepage-green"></a>
+ <a href='https://github.com/sxfly99/FGRestore/stargazers'><img src='https://img.shields.io/github/stars/sxfly99/FGRestore.svg?style=social'></a>
+ </div>
+
+ <h1 align="center">Fine-grained Image Quality Assessment for Perceptual Image Restoration</h1>
+
+ <div align="center">
+ <a href="https://github.com/sxfly99">Xiangfei Sheng</a><sup>1*</sup>,
+ <a href="https://github.com/pxf0429">Xiaofeng Pan</a><sup>1*</sup>,
+ <a href="https://github.com/yzc-ippl">Zhichao Yang</a><sup>1</sup>,
+ <a href="https://faculty.xidian.edu.cn/cpf/">Pengfei Chen</a><sup>1</sup>,
+ <a href="https://web.xidian.edu.cn/ldli/">Leida Li</a><sup>1#</sup>
+ </div>
+
+ <div align="center">
+ <sup>1</sup>School of Artificial Intelligence, Xidian University
+ </div>
+
+ <div align="center">
+ <sup>*</sup>Equal contribution. <sup>#</sup>Corresponding author.
+ </div>
+
+ <div align="center">
+ <img src="FGResQ.png" width="800"/>
+ </div>
+
+ <div style="font-family: sans-serif; margin-bottom: 2em;">
+ <h2 style="border-bottom: 1px solid #eaecef; padding-bottom: 0.3em; margin-bottom: 1em;">📰 News</h2>
+ <ul style="list-style-type: none; padding-left: 0;">
+ <li style="margin-bottom: 0.8em;">
+ <strong>[2025-11-08]</strong> 🎉🎉🎉 Our paper, "Fine-grained Image Quality Assessment for Perceptual Image Restoration", has been accepted to appear at AAAI 2026!
+ </li>
+ <li style="margin-bottom: 0.8em;">
+ <strong>[2025-08-20]</strong> Code and pre-trained models for FGResQ released.
+ </li>
+ </ul>
+ </div>
+
+ ## Quick Start
+
+ This guide will help you get started with the FGResQ inference code.
+
+ ### 1. Installation
+
+ First, clone the repository and install the required dependencies.
+
+ ```bash
+ git clone https://github.com/sxfly99/FGResQ.git
+ cd FGResQ
+ pip install -r requirements.txt
+ ```
+
+ ### 2. Download Pre-trained Weights
+
+ You can download the pre-trained model weights from the following link:
+ [**Download Weights (Google Drive)**](https://drive.google.com/drive/folders/10MVnAoEIDZ08Rek4qkStGDY0qLiWUahJ?usp=drive_link) or [**(Baidu Netdisk)**](https://pan.baidu.com/s/1a2IZbr_PrgZYCbUbjKLykA?pwd=9ivu)
+
+ Place the downloaded files in the `weights` directory.
+
+ - `FGResQ.pth`: The main model for quality scoring and ranking.
+ - `Degradation.pth`: The weights for the degradation-aware task branch.
+
+ Create the `weights` directory if it doesn't exist and place the files inside.
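A minimal shell sketch of this step, run from the repository root (the destination path for the downloaded checkpoints is an assumption; adjust the source paths to wherever your browser saved them):

```shell
# Create the weights directory expected by the inference code
mkdir -p weights
# Then move the downloaded checkpoints into it, e.g.:
# mv ~/Downloads/FGResQ.pth ~/Downloads/Degradation.pth weights/
```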
+
+ ```
+ FGResQ/
+ |-- weights/
+ |   |-- FGResQ.pth
+ |   |-- Degradation.pth
+ |-- model/
+ |   |-- FGResQ.py
+ |-- requirements.txt
+ |-- README.md
+ ```
+
+ ## Usage
+
+ The `FGResQ` class provides two main functionalities: scoring a single image and comparing a pair of images.
+
+ ### Initialize the Scorer
+
+ First, import and initialize `FGResQ`.
+
+ ```python
+ from model.FGResQ import FGResQ
+
+ # Path to the main model weights
+ model_path = "weights/FGResQ.pth"
+
+ # Initialize the inference engine
+ model = FGResQ(model_path=model_path)
+ ```
+
+ ### 1. Single Image Input Mode: Quality Scoring
+
+ You can get a quality score for a single image. The score lies in the range 0 to 1 (a sigmoid output), where a higher score indicates better quality.
+
+ ```python
+ image_path = "path/to/your/image.jpg"
+ quality_score = model.predict_single(image_path)
+ print(f"The quality score for the image is: {quality_score:.4f}")
+ ```
+
+ ### 2. Pairwise Image Input Mode: Quality Ranking
+
+ You can also compare two images to determine which one has better quality.
+
+ ```python
+ image_path1 = "path/to/image1.jpg"
+ image_path2 = "path/to/image2.jpg"
+
+ comparison_result = model.predict_pair(image_path1, image_path2)
+
+ # The result includes a human-readable comparison and raw probabilities
+ print(f"Comparison: {comparison_result['comparison']}")
+ # Example output: "Comparison: Image 1 is better"
+
+ print(f"Raw output probabilities: {comparison_result['comparison_raw']}")
+ # Example output: "[0.1, 0.8, 0.1]"
+ # (Probabilities for: Image 2 better, Image 1 better, Image 1 ≈ Image 2,
+ #  matching comparison_map in model/FGResQ.py)
+ ```
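For clarity, the index-to-label convention used by `predict_pair` (the `comparison_map` in `model/FGResQ.py`: class 0 means Image 2 is better, class 1 means Image 1 is better, class 2 means similar quality) can be sketched as a small helper. `interpret_comparison` is an illustrative name, not part of the released API:

```python
def interpret_comparison(probs):
    """Map a 3-way probability triple to a label, following the
    index semantics of comparison_map in model/FGResQ.py."""
    labels = ["Image 2 is better", "Image 1 is better",
              "Images are of similar quality"]
    # argmax over the three class probabilities
    return labels[max(range(len(probs)), key=lambda i: probs[i])]

print(interpret_comparison([0.1, 0.8, 0.1]))  # -> Image 1 is better
```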
+ ## Citation
+
+ If you find this work useful, please cite our paper!
+
+ ```bibtex
+ @article{sheng2025fgresq,
+   title={Fine-grained Image Quality Assessment for Perceptual Image Restoration},
+   author={Sheng, Xiangfei and Pan, Xiaofeng and Yang, Zhichao and Chen, Pengfei and Li, Leida},
+   journal={arXiv preprint arXiv:2508.14475},
+   year={2025}
+ }
+ ```
model/FGResQ.py ADDED
@@ -0,0 +1,271 @@
+ import numpy as np
+ import timm
+ import torch
+ import torchvision
+ import torch.nn as nn
+ from transformers import CLIPVisionModel
+ import torchvision.transforms as transforms
+ from PIL import Image
+ import cv2
+ import accelerate
+
+ def load_clip_model(clip_model="openai/clip-vit-base-patch16", clip_freeze=True, precision='fp16'):
+     pretrained, model_tag = clip_model.split('/')
+     pretrained = None if pretrained == 'None' else pretrained
+     clip_model = CLIPVisionModel.from_pretrained(clip_model)
+     if clip_freeze:
+         for param in clip_model.parameters():
+             param.requires_grad = False
+
+     # Feature dimensions depend on the CLIP backbone
+     if model_tag == 'clip-vit-base-patch16':
+         feature_size = dict(global_feature=768, local_feature=[196, 768])
+     elif model_tag == 'ViT-L-14-quickgelu' or model_tag == 'ViT-L-14':
+         feature_size = dict(global_feature=768, local_feature=[256, 1024])
+     else:
+         raise ValueError(f"Unknown model_tag: {model_tag}")
+
+     return clip_model, feature_size
+
+ class DualBranch(nn.Module):
+
+     def __init__(self, clip_model="openai/clip-vit-base-patch16", clip_freeze=True, precision='fp16'):
+         super(DualBranch, self).__init__()
+         self.clip_freeze = clip_freeze
+
+         # Load CLIP model
+         self.clip_model, feature_size = load_clip_model(clip_model, clip_freeze, precision)
+
+         # Initialize CLIP vision model for task classification
+         self.task_cls_clip = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch16")
+
+         # Quality head (single image) and comparison head (image pair)
+         self.head = nn.Linear(feature_size['global_feature'] * 3, 1)
+         self.compare_head = nn.Linear(feature_size['global_feature'] * 6, 3)
+
+         # Learnable task prompt and its projection MLPs
+         self.prompt = nn.Parameter(torch.rand(1, feature_size['global_feature']))
+         self.task_mlp = nn.Sequential(
+             nn.Linear(feature_size['global_feature'], feature_size['global_feature']),
+             nn.SiLU(False),
+             nn.Linear(feature_size['global_feature'], feature_size['global_feature']))
+         self.prompt_mlp = nn.Linear(feature_size['global_feature'], feature_size['global_feature'])
+
+         # Zero-initialize the task/prompt MLPs
+         with torch.no_grad():
+             self.task_mlp[0].weight.fill_(0.0)
+             self.task_mlp[0].bias.fill_(0.0)
+             self.task_mlp[2].weight.fill_(0.0)
+             self.task_mlp[2].bias.fill_(0.0)
+             self.prompt_mlp.weight.fill_(0.0)
+             self.prompt_mlp.bias.fill_(0.0)
+
+         # Load pre-trained weights for the degradation-aware branch
+         self._load_pretrained_weights("./weights/Degradation.pth")
+
+         # Freeze the task-classification CLIP...
+         for param in self.task_cls_clip.parameters():
+             param.requires_grad = False
+
+         # ...then unfreeze its last two encoder layers (layers 10 and 11)
+         for i in range(10, 12):
+             for param in self.task_cls_clip.vision_model.encoder.layers[i].parameters():
+                 param.requires_grad = True
+
+     def _load_pretrained_weights(self, state_dict_path):
+         """
+         Load pre-trained weights, including the CLIP model and classification head.
+         """
+         # Load state dictionary on CPU; the module is moved to the target device later
+         state_dict = torch.load(state_dict_path, map_location='cpu')
+
+         # Keep only the weights belonging to the CLIP model
+         clip_state_dict = {}
+         for key, value in state_dict.items():
+             if key.startswith('clip_model.'):
+                 # Remove 'clip_model.' prefix for the CLIP model
+                 new_key = key.replace('clip_model.', '')
+                 clip_state_dict[new_key] = value
+
+         # Load weights for the CLIP model
+         self.task_cls_clip.load_state_dict(clip_state_dict, strict=False)
+         print("Successfully loaded CLIP model weights")
+
+     def forward(self, x0, x1=None):
+         if x1 is None:
+             # Image features
+             features0 = self.clip_model(x0)['pooler_output']
+             # Classification features
+             task_features0 = self.task_cls_clip(x0)['pooler_output']
+
+             # Task embedding: softmax-gated prompt, then a linear projection
+             task_embedding = torch.softmax(self.task_mlp(task_features0), dim=1) * self.prompt
+             task_embedding = self.prompt_mlp(task_embedding)
+
+             features0 = torch.cat([features0, task_embedding, features0 + task_embedding], dim=1)
+             quality = self.head(features0)
+             quality = nn.Sigmoid()(quality)
+
+             return quality, None, None
+         else:
+             # Image features
+             features0 = self.clip_model(x0)['pooler_output']
+             features1 = self.clip_model(x1)['pooler_output']
+             # Classification features
+             task_features0 = self.task_cls_clip(x0)['pooler_output']
+             task_features1 = self.task_cls_clip(x1)['pooler_output']
+
+             task_embedding0 = torch.softmax(self.task_mlp(task_features0), dim=1) * self.prompt
+             task_embedding0 = self.prompt_mlp(task_embedding0)
+             task_embedding1 = torch.softmax(self.task_mlp(task_features1), dim=1) * self.prompt
+             task_embedding1 = self.prompt_mlp(task_embedding1)
+
+             features0 = torch.cat([features0, task_embedding0, features0 + task_embedding0], dim=1)
+             features1 = torch.cat([features1, task_embedding1, features1 + task_embedding1], dim=1)
+
+             # Pairwise comparison logits from the concatenated pair features
+             features = torch.cat([features0, features1], dim=1)
+             compare_quality = self.compare_head(features)
+
+             # Per-image quality scores
+             quality0 = nn.Sigmoid()(self.head(features0))
+             quality1 = nn.Sigmoid()(self.head(features1))
+
+             return quality0, quality1, compare_quality
+
+ class FGResQ:
+     def __init__(self, model_path, clip_model="openai/clip-vit-base-patch16", input_size=224, device=None):
+         """
+         Initializes the inference model.
+
+         Args:
+             model_path (str): Path to the pre-trained model checkpoint (.pth or .safetensors).
+             clip_model (str): Name of the CLIP model to use.
+             input_size (int): Input image size for the model.
+             device (str, optional): Device to run inference on ('cuda' or 'cpu'). Auto-detected if None.
+         """
+         if device is None:
+             self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         else:
+             self.device = device
+
+         print(f"Using device: {self.device}")
+
+         # Load the model
+         self.model = DualBranch(clip_model=clip_model, clip_freeze=True, precision='fp32')
+
+         # Load model weights
+         try:
+             raw = torch.load(model_path, map_location=self.device)
+             # Unwrap possible checkpoint containers
+             if isinstance(raw, dict) and any(k in raw for k in ['model', 'state_dict']):
+                 state_dict = raw.get('model', raw.get('state_dict', raw))
+             else:
+                 state_dict = raw
+
+             # Only strip 'module.' if present; keep other namespaces intact
+             if any(k.startswith('module.') for k in state_dict.keys()):
+                 state_dict = {k.replace('module.', '', 1): v for k, v in state_dict.items()}
+
+             missing, unexpected = self.model.load_state_dict(state_dict, strict=False)
+             if missing:
+                 print(f"[load_state_dict] Missing keys: {missing}")
+             if unexpected:
+                 print(f"[load_state_dict] Unexpected keys: {unexpected}")
+             print(f"Model weights loaded from {model_path}")
+         except Exception as e:
+             print(f"Error loading model weights: {e}")
+             raise
+
+         self.model.to(self.device)
+         self.model.eval()
+
+         # Define image preprocessing.
+         # Match the training/validation pipeline: images are first resized to 256x256
+         # (in _preprocess_image), then center-cropped to input_size and CLIP-normalized.
+         self.transform = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.CenterCrop(input_size),
+             transforms.Normalize(
+                 mean=[0.48145466, 0.4578275, 0.40821073],
+                 std=[0.26862954, 0.26130258, 0.27577711]
+             )
+         ])
+
+     def _preprocess_image(self, image_path):
+         """Load and preprocess a single image."""
+         try:
+             # Match the training dataset loader: cv2 read + resize to 256x256 (INTER_LINEAR).
+             # Check the read result before converting; cv2.imread returns None on failure.
+             img = cv2.imread(image_path)
+             if img is None:
+                 raise FileNotFoundError(f"Failed to read image at {image_path}")
+             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+             img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_LINEAR)
+             image = Image.fromarray(img)
+             image_tensor = self.transform(image).unsqueeze(0)
+             return image_tensor.to(self.device)
+         except FileNotFoundError:
+             print(f"Error: Image file not found at {image_path}")
+             return None
+         except Exception as e:
+             print(f"Error processing image {image_path}: {e}")
+             return None
+
+     @torch.no_grad()
+     def predict_single(self, image_path):
+         """
+         Predict the quality score of a single image.
+         """
+         image_tensor = self._preprocess_image(image_path)
+         if image_tensor is None:
+             return None
+
+         quality_score, _, _ = self.model(image_tensor)
+         return quality_score.squeeze().item()
+
+     @torch.no_grad()
+     def predict_pair(self, image_path1, image_path2):
+         """
+         Compare the quality of two images.
+         """
+         image_tensor1 = self._preprocess_image(image_path1)
+         image_tensor2 = self._preprocess_image(image_path2)
+
+         if image_tensor1 is None or image_tensor2 is None:
+             return None
+
+         quality1, quality2, compare_result = self.model(image_tensor1, image_tensor2)
+
+         quality1 = quality1.squeeze().item()
+         quality2 = quality2.squeeze().item()
+
+         # Interpret the comparison result
+         compare_probs = torch.softmax(compare_result, dim=-1).squeeze(dim=0).cpu().numpy()
+         prediction = np.argmax(compare_probs)
+
+         # Align with training label semantics:
+         # the dataset encodes preferences as A>B -> 1, A<B -> 0, equal -> 2,
+         # so class 1 => Image 1 (A) is better, class 0 => Image 2 (B) is better.
+         comparison_map = {0: 'Image 2 is better', 1: 'Image 1 is better', 2: 'Images are of similar quality'}
+
+         return {
+             'comparison': comparison_map[prediction],
+             'comparison_raw': compare_probs.tolist()}
model/__pycache__/FGResQ.cpython-38.pyc ADDED
Binary file (7.23 kB)
 
requirements.txt ADDED
@@ -0,0 +1,46 @@
+ accelerate==0.28.0
+ certifi==2025.7.14
+ charset-normalizer==3.4.2
+ contourpy==1.1.1
+ cycler==0.12.1
+ einops==0.8.1
+ filelock==3.16.1
+ fonttools==4.57.0
+ fsspec==2025.3.0
+ hf-xet==1.1.5
+ huggingface-hub==0.33.4
+ idna==3.10
+ importlib_resources==6.4.5
+ kiwisolver==1.4.7
+ matplotlib==3.7.5
+ mpmath==1.3.0
+ numpy==1.24.4
+ opencv-python==4.8.1.78
+ packaging==25.0
+ pandas==2.0.3
+ pillow==10.4.0
+ prefetch-generator==1.0.3
+ psutil==7.0.0
+ pyparsing==3.1.4
+ python-dateutil==2.9.0.post0
+ pytz==2025.2
+ PyYAML==6.0.2
+ regex==2024.11.6
+ requests==2.32.4
+ safetensors==0.5.3
+ scipy==1.10.1
+ seaborn==0.13.2
+ six==1.17.0
+ sympy==1.12
+ timm==1.0.17
+ tokenizers==0.15.2
+ torch==1.13.0+cu117
+ torch-ema==0.3
+ torchaudio==0.13.0+cu117
+ torchvision==0.14.0+cu117
+ tqdm==4.67.1
+ transformers==4.36.1
+ typing_extensions==4.13.2
+ tzdata==2025.2
+ urllib3==2.2.3
+ zipp==3.20.2
weights/.gitkeep ADDED
File without changes