S-Rajesh commited on
Commit
d9a1fb2
·
verified ·
1 Parent(s): 2ed357d

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +217 -0
app.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import pickle
5
+ from PIL import Image
6
+ import os
7
+ from convnext_original import ConvNeXt as ConvNeXtOriginal
8
+ from convnext_finetune import ConvNeXt
9
+
10
+ # Global variables for models
11
+ content_model = None
12
+ quality_model = None
13
+ scaler = None
14
+ regression_model = None
15
+ device = None
16
+
17
+ def get_activation(name, activations):
18
+ """Hook function to capture activations."""
19
+ def hook(model, input, output):
20
+ activations[name] = output.detach()
21
+ return hook
22
+
23
+ def register_hooks(model):
24
+ """Register hooks for each layer in the model."""
25
+ activations = {}
26
+ for name, module in model.named_modules():
27
+ module.register_forward_hook(get_activation(name, activations))
28
+ return activations
29
+
30
+ def preprocess_image(image):
31
+ """Preprocess image for model input."""
32
+ # ImageNet normalization parameters
33
+ mean = np.array([0.485, 0.456, 0.406])
34
+ std = np.array([0.229, 0.224, 0.225])
35
+
36
+ img_array = np.array(image, dtype=np.float32) / 255.0
37
+ img_array = (img_array - mean) / std
38
+ return torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).float()
39
+
40
+ def load_models():
41
+ """Load all required models."""
42
+ global content_model, quality_model, scaler, regression_model, device
43
+
44
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
45
+
46
+ # Check if model files exist
47
+ required_files = [
48
+ 'feature_models/convnext_tiny_22k_224.pth',
49
+ 'feature_models/triqa_quality_aware.pth',
50
+ 'Regression_Models/KonIQ_scaler.save',
51
+ 'Regression_Models/KonIQ_TRIQA.save'
52
+ ]
53
+
54
+ missing_files = [f for f in required_files if not os.path.exists(f)]
55
+ if missing_files:
56
+ print(f"Missing model files: {missing_files}")
57
+ print("Please download model files from the Box link and place them in the correct directories.")
58
+ return None, None
59
+
60
+ try:
61
+ # Load content-aware model (using original ConvNeXt)
62
+ content_model = ConvNeXtOriginal(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])
63
+ content_state_dict = torch.load('feature_models/convnext_tiny_22k_224.pth', map_location=device)['model']
64
+ content_state_dict = {k: v for k, v in content_state_dict.items() if not k.startswith('head.')}
65
+ content_model.load_state_dict(content_state_dict, strict=False)
66
+ content_model.to(device).eval()
67
+
68
+ # Load quality-aware model
69
+ quality_model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])
70
+ quality_state_dict = torch.load('feature_models/triqa_quality_aware.pth', map_location=device)
71
+ quality_model.load_state_dict(quality_state_dict, strict=True)
72
+ quality_model.to(device).eval()
73
+
74
+ # Register hooks for feature extraction
75
+ content_activations = register_hooks(content_model)
76
+ quality_activations = register_hooks(quality_model)
77
+
78
+ # Load scaler and regression model
79
+ with open('Regression_Models/KonIQ_scaler.save', 'rb') as f:
80
+ scaler = pickle.load(f)
81
+ with open('Regression_Models/KonIQ_TRIQA.save', 'rb') as f:
82
+ regression_model = pickle.load(f)
83
+
84
+ return content_activations, quality_activations
85
+ except Exception as e:
86
+ print(f"Error loading models: {e}")
87
+ return None, None
88
+
89
+ def predict_quality(image):
90
+ """Predict image quality score on 1-5 scale."""
91
+ global content_model, quality_model, scaler, regression_model, device
92
+
93
+ if content_model is None or quality_model is None:
94
+ return "Models not loaded. Please wait..."
95
+
96
+ # Load and preprocess image
97
+ image_half = image.resize((image.size[0]//2, image.size[1]//2), Image.LANCZOS)
98
+
99
+ img_full = preprocess_image(image).to(device)
100
+ img_half = preprocess_image(image_half).to(device)
101
+
102
+ with torch.no_grad():
103
+ # Extract content features using hooks
104
+ _ = content_model(img_full)
105
+ content_full = content_model.activations['norm'].cpu().numpy().flatten()
106
+
107
+ _ = content_model(img_half)
108
+ content_half = content_model.activations['norm'].cpu().numpy().flatten()
109
+
110
+ content_features = np.concatenate([content_full, content_half])
111
+
112
+ # Extract quality features using hooks
113
+ _ = quality_model(img_full)
114
+ quality_full = quality_model.activations['norm'].cpu().numpy().flatten()
115
+
116
+ _ = quality_model(img_half)
117
+ quality_half = quality_model.activations['norm'].cpu().numpy().flatten()
118
+
119
+ quality_features = np.concatenate([quality_full, quality_half])
120
+
121
+ # Combine features and predict
122
+ combined_features = np.concatenate([content_features, quality_features])
123
+ normalized_features = scaler.transform(combined_features.reshape(1, -1))
124
+ quality_score = regression_model.predict(normalized_features)[0]
125
+
126
+ return f"Quality Score: {quality_score:.2f}/5.0"
127
+
128
+ def create_demo():
129
+ """Create the Gradio demo interface."""
130
+
131
+ # Load models
132
+ try:
133
+ content_activations, quality_activations = load_models()
134
+ content_model.activations = content_activations
135
+ quality_model.activations = quality_activations
136
+ print("Models loaded successfully!")
137
+ except Exception as e:
138
+ print(f"Error loading models: {e}")
139
+ return None
140
+
141
+ # Create Gradio interface
142
+ with gr.Blocks(title="TRIQA: Image Quality Assessment", theme=gr.themes.Soft()) as demo:
143
+ gr.Markdown("""
144
+ # TRIQA: Image Quality Assessment
145
+
146
+ **TRIQA** combines content-aware and quality-aware features from ConvNeXt models to predict image quality scores on a 1-5 scale.
147
+
148
+ ### How to use:
149
+ 1. Upload an image using the file uploader below
150
+ 2. Click "Assess Quality" to get the quality score
151
+ 3. The score ranges from 1-5, where 5 represents the highest quality
152
+
153
+ ### Paper Links:
154
+ - **arXiv**: [https://arxiv.org/pdf/2507.12687](https://arxiv.org/pdf/2507.12687)
155
+ - **IEEE Xplore**: [https://ieeexplore.ieee.org/abstract/document/11084443](https://ieeexplore.ieee.org/abstract/document/11084443)
156
+ """)
157
+
158
+ with gr.Row():
159
+ with gr.Column():
160
+ input_image = gr.Image(
161
+ label="Upload Image",
162
+ type="pil",
163
+ height=400
164
+ )
165
+ submit_btn = gr.Button("Assess Quality", variant="primary")
166
+
167
+ with gr.Column():
168
+ output_text = gr.Textbox(
169
+ label="Quality Assessment Result",
170
+ value="Upload an image and click 'Assess Quality' to get the quality score.",
171
+ interactive=False
172
+ )
173
+
174
+ gr.Examples(
175
+ examples=[
176
+ ["sample_image/233045618.jpg"],
177
+ ["sample_image/25239707.jpg"],
178
+ ["sample_image/44009500.jpg"],
179
+ ["sample_image/5129172.jpg"],
180
+ ["sample_image/85119046.jpg"]
181
+ ],
182
+ inputs=input_image,
183
+ label="Sample Images"
184
+ )
185
+
186
+ submit_btn.click(
187
+ fn=predict_quality,
188
+ inputs=input_image,
189
+ outputs=output_text
190
+ )
191
+
192
+ gr.Markdown("""
193
+ ### Citation:
194
+ If you use this code in your research, please cite our paper:
195
+
196
+ ```bibtex
197
+ @INPROCEEDINGS{11084443,
198
+ author={Sureddi, Rajesh and Zadtootaghaj, Saman and Barman, Nabajeet and Bovik, Alan C.},
199
+ booktitle={2025 IEEE International Conference on Image Processing (ICIP)},
200
+ title={Triqa: Image Quality Assessment by Contrastive Pretraining on Ordered Distortion Triplets},
201
+ year={2025},
202
+ volume={},
203
+ number={},
204
+ pages={1744-1749},
205
+ keywords={Image quality;Training;Deep learning;Contrastive learning;Predictive models;Feature extraction;Distortion;Data models;Synthetic data;Image Quality Assessment;Contrastive Learning},
206
+ doi={10.1109/ICIP55913.2025.11084443}}
207
+ ```
208
+ """)
209
+
210
+ return demo
211
+
212
+ if __name__ == "__main__":
213
+ demo = create_demo()
214
+ if demo:
215
+ demo.launch(server_name="0.0.0.0", server_port=7860)
216
+ else:
217
+ print("Failed to create demo. Please check model files.")