VOIDER commited on
Commit
f89e218
·
verified ·
1 Parent(s): 55b8637

Upload 11 files

Browse files
app.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Main application file for the Image Evaluator tool.
3
+ This module integrates all components and provides a Gradio interface.
4
+ """
5
+
6
+ import os
7
+ import gradio as gr
8
+ import numpy as np
9
+ import pandas as pd
10
+ import torch
11
+ import glob
12
+ from PIL import Image
13
+ import json
14
+ import tempfile
15
+ import shutil
16
+ from datetime import datetime
17
+
18
+ # Import custom modules
19
+ from modules.metadata_extractor import MetadataExtractor
20
+ from modules.technical_metrics import TechnicalMetrics
21
+ from modules.aesthetic_metrics import AestheticMetrics
22
+ from modules.aggregator import ResultsAggregator
23
+ from modules.visualizer import Visualizer
24
+
25
+
26
class ImageEvaluator:
    """Top-level coordinator for the Image Evaluator application.

    Wires together metadata extraction, technical/aesthetic metric
    computation, result aggregation and visualization, and keeps the
    per-session caches the Gradio callbacks rely on.
    """

    def __init__(self):
        """Create the results directory and instantiate all components."""
        self.results_dir = os.path.join(os.getcwd(), "results")
        os.makedirs(self.results_dir, exist_ok=True)

        # Pipeline components
        self.metadata_extractor = MetadataExtractor()
        self.technical_metrics = TechnicalMetrics()
        self.aesthetic_metrics = AestheticMetrics()
        self.aggregator = ResultsAggregator()
        self.visualizer = Visualizer(self.results_dir)

        # Per-session caches
        self.evaluation_results = {}
        self.metadata_cache = {}
        self.current_comparison = None

    def process_images(self, image_files, progress=None):
        """
        Process a list of image files and extract metadata.

        Args:
            image_files: list of image file paths
            progress: optional gradio Progress object

        Returns:
            tuple: (metadata_by_model, metadata_by_prompt)
        """
        metadata_list = []
        file_count = len(image_files)

        for idx, path in enumerate(image_files):
            if progress:
                progress(idx / file_count, f"Processing image {idx+1}/{file_count}")

            meta = self.metadata_extractor.extract_metadata(path)
            metadata_list.append((path, meta))
            # Cache so later evaluation passes can reuse the parsed metadata.
            self.metadata_cache[path] = meta

        # Group by model and by prompt for downstream comparison.
        return (
            self.metadata_extractor.group_images_by_model(metadata_list),
            self.metadata_extractor.group_images_by_prompt(metadata_list),
        )

    def evaluate_images(self, image_files, progress=None):
        """
        Evaluate a list of image files using all metrics.

        Args:
            image_files: list of image file paths
            progress: optional gradio Progress object

        Returns:
            dict: evaluation results by image path
        """
        results = {}
        file_count = len(image_files)

        for idx, path in enumerate(image_files):
            if progress:
                progress(idx / file_count, f"Evaluating image {idx+1}/{file_count}")

            # Prompt (if known from metadata) feeds the CLIP similarity metric.
            prompt = self.metadata_cache.get(path, {}).get('prompt', '')

            technical = self.technical_metrics.calculate_all_metrics(path)
            aesthetic = self.aesthetic_metrics.calculate_all_metrics(path, prompt)

            # Merge both metric families into a single record per image.
            results[path] = {**technical, **aesthetic}

        return results

    def compare_models(self, evaluation_results, metadata_by_model):
        """
        Compare different models based on evaluation results.

        Args:
            evaluation_results: dict mapping image paths to metric dicts
            metadata_by_model: dict mapping model names to lists of image paths

        Returns:
            tuple: (comparison_df, visualizations)
        """
        # Re-key per-image results by the model that produced them.
        results_by_model = {
            model: [evaluation_results[p] for p in paths if p in evaluation_results]
            for model, paths in metadata_by_model.items()
        }

        comparison = self.aggregator.compare_models(results_by_model)
        comparison_df = self.aggregator.create_comparison_dataframe(comparison)
        # Remember the latest comparison so export_results can reuse it.
        self.current_comparison = comparison_df

        visualizations = {
            'Model Comparison Heatmap': self.visualizer.plot_heatmap(comparison_df)
        }

        # Radar chart over whichever key metrics are actually present.
        key_metrics = ['aesthetic_score', 'sharpness', 'noise', 'contrast',
                       'color_harmony', 'prompt_similarity']
        radar_metrics = [m for m in key_metrics if m in comparison_df.columns]
        if radar_metrics:
            visualizations['Model Comparison Radar Chart'] = \
                self.visualizer.plot_radar_chart(comparison_df, radar_metrics)

        # One bar chart per headline metric.
        for metric in ('overall_score', 'aesthetic_score', 'prompt_similarity'):
            if metric in comparison_df.columns:
                visualizations[f'{metric} Comparison'] = \
                    self.visualizer.plot_metric_comparison(comparison_df, metric)

        return comparison_df, visualizations

    def export_results(self, format='csv'):
        """
        Export current comparison results.

        Args:
            format: export format ('csv', 'excel', or 'html')

        Returns:
            str: path to exported file, or None if nothing to export
        """
        if self.current_comparison is None:
            return None
        return self.visualizer.export_comparison_table(self.current_comparison, format)

    def generate_report(self, comparison_df, visualizations):
        """
        Generate a comprehensive HTML report.

        Args:
            comparison_df: pandas DataFrame with comparison data
            visualizations: dictionary of visualization paths

        Returns:
            str: path to HTML report
        """
        return self.visualizer.generate_html_report(
            comparison_df, visualizations, comparison_df.columns.tolist()
        )
188
+
189
+
190
# Create Gradio interface
def create_interface():
    """Create and configure the Gradio interface.

    Builds a three-tab Blocks app (upload/process, compare, export/report)
    around a single ImageEvaluator instance. The nested callback functions
    communicate via the `state` dict they close over.
    """

    # Initialize evaluator
    evaluator = ImageEvaluator()

    # Track state shared by all callbacks for the lifetime of this interface.
    # NOTE(review): this closure dict is shared, not per-session — concurrent
    # users of the same launch would overwrite each other; confirm
    # single-user usage is intended.
    state = {
        'uploaded_images': [],
        'metadata_by_model': {},
        'metadata_by_prompt': {},
        'evaluation_results': {},
        'comparison_df': None,
        'visualizations': {},
        'report_path': None
    }

    def upload_images(files, progress=gr.Progress()):
        """Handle image upload and processing.

        Resets all previous results, extracts metadata from the uploaded
        files and returns a human-readable summary string.
        """
        # Reset state
        state['uploaded_images'] = []
        state['metadata_by_model'] = {}
        state['metadata_by_prompt'] = {}
        state['evaluation_results'] = {}
        state['comparison_df'] = None
        state['visualizations'] = {}
        state['report_path'] = None

        # Process uploaded files
        # NOTE(review): `f.name` is the gradio 3.x temp-file API — confirm
        # against the installed gradio version (4.x passes plain paths).
        image_paths = [f.name for f in files]
        state['uploaded_images'] = image_paths

        # Extract metadata and group images
        progress(0, "Extracting metadata...")
        metadata_by_model, metadata_by_prompt = evaluator.process_images(image_paths, progress)
        state['metadata_by_model'] = metadata_by_model
        state['metadata_by_prompt'] = metadata_by_prompt

        # Create model summary
        model_summary = []
        for model, images in metadata_by_model.items():
            model_summary.append(f"- {model}: {len(images)} images")

        # Create prompt summary
        # NOTE(review): prompt_summary is built but never included in the
        # returned text below — only the prompt count is shown.
        prompt_summary = []
        for prompt, images in metadata_by_prompt.items():
            prompt_summary.append(f"- {prompt}: {len(images)} images")

        return (
            f"Processed {len(image_paths)} images.\n\n"
            f"Found {len(metadata_by_model)} models:\n" + "\n".join(model_summary) + "\n\n"
            f"Found {len(metadata_by_prompt)} unique prompts."
        )

    def evaluate_images(progress=gr.Progress()):
        """Evaluate all uploaded images and store the results in `state`."""
        if not state['uploaded_images']:
            return "No images uploaded. Please upload images first."

        # Evaluate images
        progress(0, "Evaluating images...")
        evaluation_results = evaluator.evaluate_images(state['uploaded_images'], progress)
        state['evaluation_results'] = evaluation_results

        return f"Evaluated {len(evaluation_results)} images with all metrics."

    def compare_models():
        """Compare models based on evaluation results.

        Returns (markdown table, heatmap path, radar chart path), or an
        error message plus two Nones when nothing has been evaluated yet.
        """
        if not state['evaluation_results'] or not state['metadata_by_model']:
            return "No evaluation results available. Please evaluate images first.", None, None

        # Compare models
        comparison_df, visualizations = evaluator.compare_models(
            state['evaluation_results'], state['metadata_by_model']
        )
        state['comparison_df'] = comparison_df
        state['visualizations'] = visualizations

        # Generate report (side effect: cached for the report tab)
        report_path = evaluator.generate_report(comparison_df, visualizations)
        state['report_path'] = report_path

        # Get visualization paths
        heatmap_path = visualizations.get('Model Comparison Heatmap')
        radar_path = visualizations.get('Model Comparison Radar Chart')
        # NOTE(review): overall_score_path is fetched but never returned or
        # displayed anywhere in the UI.
        overall_score_path = visualizations.get('overall_score Comparison')

        # Convert DataFrame to markdown for display
        df_markdown = comparison_df.to_markdown()

        return df_markdown, heatmap_path, radar_path

    def export_results(format):
        """Export results in the specified format ('csv', 'excel' or 'html').

        Note: the parameter name `format` shadows the builtin of the same
        name; it is never used as the builtin inside this function.
        """
        if state['comparison_df'] is None:
            return "No comparison results available. Please compare models first."

        export_path = evaluator.export_results(format)
        if export_path:
            return f"Results exported to {export_path}"
        else:
            return "Failed to export results."

    def view_report():
        """View the generated HTML report.

        NOTE(review): this returns a filesystem *path* (or an error string)
        into a gr.HTML component, which renders raw markup — confirm whether
        the report content itself should be read and returned instead.
        """
        if state['report_path'] and os.path.exists(state['report_path']):
            return state['report_path']
        else:
            return "No report available. Please compare models first."

    # Create interface
    with gr.Blocks(title="Image Model Evaluator") as interface:
        gr.Markdown("# Image Model Evaluator")
        gr.Markdown("Upload images generated by different AI models to compare their quality and performance.")

        # Tab 1: upload files, then run the metric evaluation pass.
        with gr.Tab("Upload & Process"):
            with gr.Row():
                with gr.Column():
                    upload_input = gr.File(
                        label="Upload Images (PNG format)",
                        file_count="multiple",
                        type="file"
                    )
                    upload_button = gr.Button("Process Uploaded Images")

                with gr.Column():
                    upload_output = gr.Textbox(
                        label="Processing Results",
                        lines=10,
                        interactive=False
                    )

            evaluate_button = gr.Button("Evaluate Images")
            evaluate_output = gr.Textbox(
                label="Evaluation Status",
                lines=2,
                interactive=False
            )

        # Tab 2: side-by-side model comparison table and charts.
        with gr.Tab("Compare Models"):
            compare_button = gr.Button("Compare Models")

            with gr.Row():
                comparison_output = gr.Markdown(
                    label="Comparison Results"
                )

            with gr.Row():
                with gr.Column():
                    heatmap_output = gr.Image(
                        label="Model Comparison Heatmap",
                        interactive=False
                    )

                with gr.Column():
                    radar_output = gr.Image(
                        label="Model Comparison Radar Chart",
                        interactive=False
                    )

        # Tab 3: export the comparison table and view the HTML report.
        with gr.Tab("Export & Report"):
            with gr.Row():
                with gr.Column():
                    export_format = gr.Radio(
                        label="Export Format",
                        choices=["csv", "excel", "html"],
                        value="csv"
                    )
                    export_button = gr.Button("Export Results")
                    export_output = gr.Textbox(
                        label="Export Status",
                        lines=2,
                        interactive=False
                    )

                with gr.Column():
                    report_button = gr.Button("View Full Report")
                    report_output = gr.HTML(
                        label="Full Report"
                    )

        # Set up event handlers
        upload_button.click(
            upload_images,
            inputs=[upload_input],
            outputs=[upload_output]
        )

        evaluate_button.click(
            evaluate_images,
            inputs=[],
            outputs=[evaluate_output]
        )

        compare_button.click(
            compare_models,
            inputs=[],
            outputs=[comparison_output, heatmap_output, radar_output]
        )

        export_button.click(
            export_results,
            inputs=[export_format],
            outputs=[export_output]
        )

        report_button.click(
            view_report,
            inputs=[],
            outputs=[report_output]
        )

    return interface
404
+
405
+
406
# Launch the application
if __name__ == "__main__":
    interface = create_interface()
    # NOTE(review): share=True asks gradio for a publicly reachable tunnel
    # URL — confirm the app is actually meant to be exposed outside the host.
    interface.launch(share=True)
modules/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """
2
+ Module initialization file for the Image Evaluator tool.
3
+ """
modules/aesthetic_metrics.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Aesthetic metrics for image quality assessment using AI models.
3
+ These metrics evaluate subjective aspects of images like aesthetic appeal, composition, etc.
4
+ """
5
+
6
+ import torch
7
+ import numpy as np
8
+ from PIL import Image
9
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
10
+ import clip
11
+ from torchvision import transforms
12
+
13
+
14
class AestheticMetrics:
    """Class for computing aesthetic image quality metrics using AI models.

    Combines a learned aesthetic classifier, composition heuristics, a
    colour-theory harmony score and CLIP text-image similarity. All public
    methods return scores on a 0-10 scale and fall back to a neutral 5.0 on
    any failure, so one bad image cannot abort a batch evaluation.
    """

    def __init__(self):
        """Initialize models for aesthetic evaluation."""
        # Prefer GPU when available; models and tensors are moved there.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self._initialize_models()

    def _initialize_models(self):
        """Initialize all required models, degrading gracefully on failure."""
        # Initialize CLIP model for text-image similarity
        try:
            self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=self.device)
            self.clip_loaded = True
        except Exception as e:
            print(f"Warning: Could not load CLIP model: {e}")
            self.clip_loaded = False

        # Initialize aesthetic predictor model
        try:
            self.aesthetic_model_name = "cafeai/cafe_aesthetic"
            self.aesthetic_extractor = AutoFeatureExtractor.from_pretrained(self.aesthetic_model_name)
            self.aesthetic_model = AutoModelForImageClassification.from_pretrained(self.aesthetic_model_name)
            self.aesthetic_model.to(self.device)
            self.aesthetic_loaded = True
        except Exception as e:
            print(f"Warning: Could not load aesthetic model: {e}")
            self.aesthetic_loaded = False

        # Standard ImageNet preprocessing. NOTE(review): currently unused by
        # the methods below (they use the extractor / CLIP preprocess); kept
        # for backward compatibility with any external callers.
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def calculate_aesthetic_score(self, image_path):
        """
        Calculate aesthetic score using a pre-trained model.

        Args:
            image_path: path to the image file

        Returns:
            float: aesthetic score between 0 and 10
        """
        if not self.aesthetic_loaded:
            return 5.0  # Default middle score if model not loaded

        try:
            image = Image.open(image_path).convert('RGB')
            inputs = self.aesthetic_extractor(images=image, return_tensors="pt").to(self.device)

            with torch.no_grad():
                outputs = self.aesthetic_model(**inputs)

            # Get predicted class probabilities (shape: [1, num_classes]).
            probs = torch.nn.functional.softmax(outputs.logits, dim=1)

            # Expected class index under the predicted distribution.
            # Bug fix: the original hard-coded 10 weight entries; that
            # broadcasts against probs only when the classifier head has
            # exactly 10 labels and otherwise raises, silently returning the
            # 5.0 fallback. Derive the weight vector from the logits instead.
            num_classes = outputs.logits.shape[-1]
            score_weights = torch.arange(num_classes, device=self.device, dtype=torch.float32)
            aesthetic_score = torch.sum(probs * score_weights).item()

            return aesthetic_score
        except Exception as e:
            print(f"Error calculating aesthetic score: {e}")
            return 5.0

    def calculate_composition_score(self, image_path):
        """
        Estimate composition quality using rule of thirds and symmetry analysis.

        Args:
            image_path: path to the image file

        Returns:
            float: composition score between 0 and 10
        """
        try:
            # Load image
            image = Image.open(image_path).convert('RGB')
            img_array = np.array(image)

            # Calculate rule of thirds score
            h, w = img_array.shape[:2]
            third_h, third_w = h // 3, w // 3

            # Define rule of thirds points (x, y)
            thirds_points = [
                (third_w, third_h), (2*third_w, third_h),
                (third_w, 2*third_h), (2*third_w, 2*third_h)
            ]

            # Simple gradient-magnitude edge map.
            # Bug fix: keep the grayscale image as float — the original cast
            # to uint8 before np.diff, so negative gradients wrapped around
            # and np.abs could not recover them, inflating edge values.
            gray = np.mean(img_array, axis=2)
            edges = np.abs(np.diff(gray, axis=0, append=0)) + np.abs(np.diff(gray, axis=1, append=0))

            # Score: edge concentration in a 100x100 window around each point.
            thirds_score = 0
            for px, py in thirds_points:
                # Get region around thirds point
                region = edges[max(0, py-50):min(h, py+50), max(0, px-50):min(w, px+50)]
                thirds_score += np.mean(region)

            # Normalize score to the 0-10 range
            thirds_score = min(10, thirds_score / 100)

            # Symmetry: mean absolute difference vs the horizontal mirror.
            flipped = np.fliplr(img_array)
            symmetry_diff = np.mean(np.abs(img_array.astype(float) - flipped.astype(float)))
            symmetry_score = 10 * (1 - symmetry_diff / 255)

            # Combine scores (weighted average, thirds dominates)
            composition_score = 0.7 * thirds_score + 0.3 * symmetry_score

            return min(10, max(0, composition_score))
        except Exception as e:
            print(f"Error calculating composition score: {e}")
            return 5.0

    def calculate_color_harmony(self, image_path):
        """
        Calculate color harmony score based on color theory.

        Args:
            image_path: path to the image file

        Returns:
            float: color harmony score between 0 and 10
        """
        try:
            # Load image
            image = Image.open(image_path).convert('RGB')
            img_array = np.array(image)

            # Convert to HSV for better color analysis
            hsv = np.array(image.convert('HSV'))

            # Extract hue channel and build a 36-bin histogram (10° bins).
            hue = hsv[:,:,0].flatten()
            hist, _ = np.histogram(hue, bins=36, range=(0, 255))
            hist = hist / np.sum(hist)

            # Entropy of the hue distribution (epsilon avoids log2(0)).
            entropy = -np.sum(hist * np.log2(hist + 1e-10))

            # Complementary colours: mass shared with the bin 180° away.
            complementary_score = 0
            for i in range(18):
                complementary_i = (i + 18) % 36
                complementary_score += min(hist[i], hist[complementary_i])

            # Analogous colours: mass shared with the adjacent bins.
            analogous_score = 0
            for i in range(36):
                analogous_i1 = (i + 1) % 36
                analogous_i2 = (i + 35) % 36
                analogous_score += min(hist[i], max(hist[analogous_i1], hist[analogous_i2]))

            # Saturation variance as a measure of colour interest.
            saturation = hsv[:,:,1].flatten()
            saturation_variance = np.var(saturation)

            # Combine metrics into final score (weights sum to 10).
            harmony_score = (
                3 * (1 - min(1, entropy/5)) +       # Lower entropy is better for harmony
                3 * complementary_score +            # Complementary colors
                2 * analogous_score +                # Analogous colors
                2 * min(1, saturation_variance/2000) # Saturation variance
            )

            return min(10, max(0, harmony_score))
        except Exception as e:
            print(f"Error calculating color harmony: {e}")
            return 5.0

    def calculate_prompt_similarity(self, image_path, prompt):
        """
        Calculate similarity between image and text prompt using CLIP.

        Args:
            image_path: path to the image file
            prompt: text prompt used to generate the image

        Returns:
            float: similarity score between 0 and 10
        """
        if not self.clip_loaded or not prompt:
            return 5.0  # Default middle score if model not loaded or no prompt

        try:
            # Load and preprocess image
            image = Image.open(image_path).convert('RGB')
            image_input = self.clip_preprocess(image).unsqueeze(0).to(self.device)

            # Process text
            text_input = clip.tokenize([prompt]).to(self.device)

            # Calculate similarity
            with torch.no_grad():
                image_features = self.clip_model.encode_image(image_input)
                text_features = self.clip_model.encode_text(text_input)

                # Normalize features so the dot product is a cosine similarity
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)

                # Scaled cosine similarity (CLIP convention: x100)
                similarity = (100.0 * image_features @ text_features.T).item()

            # Convert to 0-10 scale
            return min(10, max(0, similarity / 10))
        except Exception as e:
            print(f"Error calculating prompt similarity: {e}")
            return 5.0

    def calculate_all_metrics(self, image_path, prompt=None):
        """
        Calculate all aesthetic metrics for an image.

        Args:
            image_path: path to the image file
            prompt: optional text prompt used to generate the image

        Returns:
            dict: dictionary with all metric scores (prompt_similarity only
            present when a prompt was supplied)
        """
        metrics = {
            'aesthetic_score': self.calculate_aesthetic_score(image_path),
            'composition_score': self.calculate_composition_score(image_path),
            'color_harmony': self.calculate_color_harmony(image_path),
        }

        # Add prompt similarity if prompt is provided
        if prompt:
            metrics['prompt_similarity'] = self.calculate_prompt_similarity(image_path, prompt)

        return metrics
modules/aggregator.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Module for aggregating results from different evaluation metrics.
3
+ """
4
+
5
+ import pandas as pd
6
+ import numpy as np
7
+ from collections import defaultdict
8
+
9
+
10
class ResultsAggregator:
    """Aggregate and analyze image evaluation results across models.

    Provides per-metric normalization, per-model aggregation with a
    weighted overall score, cross-model comparison and per-prompt analysis.
    """

    def __init__(self):
        """Set up default metric weights and the lower-is-better metric list."""
        # Relative importance of each metric when computing the overall score.
        self.default_weights = {
            # Technical metrics
            'sharpness': 1.0,
            'noise': 1.0,
            'contrast': 1.0,
            'saturation': 1.0,
            'entropy': 1.0,
            'compression_artifacts': 1.0,
            'dynamic_range': 1.0,

            # Aesthetic metrics
            'aesthetic_score': 1.5,
            'composition_score': 1.2,
            'color_harmony': 1.2,

            # Prompt metrics
            'prompt_similarity': 2.0,
        }

        # Metrics where lower is better
        self.inverse_metrics = ['noise', 'compression_artifacts']

    def normalize_metric(self, values, metric_name):
        """
        Normalize metric values to a 0-10 scale.

        Args:
            values: list of metric values
            metric_name: name of the metric

        Returns:
            list: normalized values (all 5.0 when the inputs are identical)
        """
        if not values:
            return []

        # For lower-is-better metrics, mirror the values around the range
        # so that the best (lowest) raw value becomes the highest.
        if metric_name in self.inverse_metrics:
            hi, lo = max(values), min(values)
            values = [hi - v + lo for v in values]

        lo, hi = min(values), max(values)
        if hi == lo:
            # Degenerate range: no spread to normalize, use the midpoint.
            return [5.0] * len(values)
        span = hi - lo
        return [10 * (v - lo) / span for v in values]

    def aggregate_model_results(self, model_results, custom_weights=None):
        """
        Aggregate results for a single model across multiple images.

        Args:
            model_results: list of metric dictionaries for images from the same model
            custom_weights: optional dictionary of custom weights for metrics

        Returns:
            dict: per-metric summary statistics plus an 'overall_score'
        """
        if not model_results:
            return {}

        weights = custom_weights or self.default_weights

        # Union of every metric name seen across the images.
        metric_names = set()
        for record in model_results:
            metric_names.update(record)

        # Summary statistics per metric, skipping non-numeric entries.
        aggregated = {}
        for name in metric_names:
            numeric = [r[name] for r in model_results
                       if isinstance(r.get(name), (int, float))]
            if not numeric:
                continue
            aggregated[name] = {
                'mean': np.mean(numeric),
                'median': np.median(numeric),
                'std': np.std(numeric),
                'min': np.min(numeric),
                'max': np.max(numeric),
                'count': len(numeric),
            }

        # Weighted overall score over the mean of each weighted metric.
        # NOTE(review): means are used as-is (only inverse metrics are
        # flipped via 10 - mean); metrics on other scales, e.g. raw
        # sharpness, are not rescaled first — confirm this is intended.
        weighted_terms = []
        total_weight = 0
        for name, stats in aggregated.items():
            if name not in weights:
                continue
            value = stats['mean']
            if name in self.inverse_metrics:
                value = 10 - value  # lower raw value should score higher
            weighted_terms.append(value * weights[name])
            total_weight += weights[name]

        aggregated['overall_score'] = (
            sum(weighted_terms) / total_weight if total_weight > 0 else 5.0
        )
        return aggregated

    def compare_models(self, model_results_dict, custom_weights=None):
        """
        Compare results across different models.

        Args:
            model_results_dict: dictionary with model names as keys and lists of results as values
            custom_weights: optional dictionary of custom weights for metrics

        Returns:
            dict: per-model overall score plus the mean of each metric
        """
        comparison = {}
        for model_name, results in model_results_dict.items():
            aggregated = self.aggregate_model_results(results, custom_weights)
            summary = {'overall_score': aggregated.get('overall_score', 5.0)}
            for name, stats in aggregated.items():
                # Skip the scalar overall_score entry; keep metric means only.
                if name != 'overall_score' and isinstance(stats, dict) and 'mean' in stats:
                    summary[name] = stats['mean']
            comparison[model_name] = summary
        return comparison

    def analyze_by_prompt(self, results_by_prompt, custom_weights=None):
        """
        Analyze results grouped by prompt.

        Args:
            results_by_prompt: dictionary with prompts as keys and dictionaries of model results as values
            custom_weights: optional dictionary of custom weights for metrics

        Returns:
            dict: per-prompt model comparison plus the best model and score
        """
        analysis = {}
        for prompt, model_results in results_by_prompt.items():
            prompt_comparison = self.compare_models(model_results, custom_weights)

            # Scan for the winning model; first model wins ties, and a model
            # only wins at all if its score beats the -1 sentinel.
            winner, winning_score = None, -1
            for name, metrics in prompt_comparison.items():
                candidate = metrics.get('overall_score', 0)
                if candidate > winning_score:
                    winner, winning_score = name, candidate

            analysis[prompt] = {
                'model_comparison': prompt_comparison,
                'best_model': winner,
                'best_score': winning_score,
            }
        return analysis

    def create_comparison_dataframe(self, comparison_results):
        """
        Create a pandas DataFrame from comparison results.

        Args:
            comparison_results: dictionary with model names as keys and metric dictionaries as values

        Returns:
            pandas.DataFrame: one row per model, best overall score first
        """
        table = pd.DataFrame.from_dict(comparison_results, orient='index')
        if 'overall_score' in table.columns:
            table = table.sort_values('overall_score', ascending=False)
        return table
modules/metadata_extractor.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Module for extracting metadata from image files, particularly focusing on
3
+ Stable Diffusion metadata from PNG files.
4
+ """
5
+
6
+ import io
7
+ from PIL import Image, PngImagePlugin
8
+ import re
9
+
10
+
11
class MetadataExtractor:
    """Class for extracting and parsing metadata from images.

    Focuses on the Stable Diffusion "parameters" text chunk embedded in
    PNG files, plus helpers to group images by model or prompt.
    """

    @staticmethod
    def extract_metadata(image_path):
        """
        Extract metadata from an image file.

        Args:
            image_path: path to the image file

        Returns:
            dict: extracted metadata; {'error': msg} on failure
        """
        try:
            # Open image with PIL
            image = Image.open(image_path)

            # SD web UIs store the generation settings in the "parameters"
            # PNG text chunk; missing chunk yields an empty string.
            metadata_text = image.info.get("parameters", "")

            # Parse the metadata text into structured fields.
            parsed_metadata = MetadataExtractor.parse_metadata(metadata_text)

            # Add basic image info
            parsed_metadata.update({
                'width': image.width,
                'height': image.height,
                'format': image.format,
                'mode': image.mode,
            })

            return parsed_metadata
        except Exception as e:
            print(f"Error extracting metadata from {image_path}: {e}")
            return {'error': str(e)}

    @staticmethod
    def parse_metadata(metadata_text):
        """
        Parse Stable Diffusion metadata text into structured data.

        Args:
            metadata_text: raw metadata text from image

        Returns:
            dict: structured metadata (always contains 'raw_text'; other
            keys such as 'prompt', 'negative_prompt', 'model', 'steps',
            'seed' are present when found)
        """
        if not metadata_text:
            return {'raw_text': ''}

        result = {'raw_text': metadata_text}

        # Extract prompt and negative prompt
        prompt_end = metadata_text.find("Negative prompt:")
        if prompt_end > 0:
            result['prompt'] = metadata_text[:prompt_end].strip()
            neg_start = prompt_end + len("Negative prompt:")
            neg_end = metadata_text.find("\n", prompt_end)
            # Bug fix: when the metadata ends with the negative prompt (no
            # trailing newline), find() returns -1 and the original code
            # silently dropped the negative prompt; slice to the end instead.
            if neg_end < 0:
                neg_end = len(metadata_text)
            result['negative_prompt'] = metadata_text[neg_start:neg_end].strip()
        else:
            # No negative prompt marker: treat the whole text as the prompt.
            result['prompt'] = metadata_text.strip()

        # Extract model name
        model_match = re.search(r'Model: ([^,\n]+)', metadata_text)
        if model_match:
            result['model'] = model_match.group(1).strip()

        # Extract other generation parameters
        params = {
            'steps': r'Steps: (\d+)',
            'sampler': r'Sampler: ([^,\n]+)',
            'cfg_scale': r'CFG scale: ([^,\n]+)',
            'seed': r'Seed: ([^,\n]+)',
            'size': r'Size: ([^,\n]+)',
            'model_hash': r'Model hash: ([^,\n]+)',
        }

        for key, pattern in params.items():
            match = re.search(pattern, metadata_text)
            if match:
                result[key] = match.group(1).strip()

        return result

    @staticmethod
    def group_images_by_model(metadata_list):
        """
        Group images by model name.

        Args:
            metadata_list: list of (image_path, metadata) tuples

        Returns:
            dict: model name -> list of image paths ('unknown' when absent)
        """
        result = {}

        for image_path, metadata in metadata_list:
            model = metadata.get('model', 'unknown')
            if model not in result:
                result[model] = []
            result[model].append(image_path)

        return result

    @staticmethod
    def group_images_by_prompt(metadata_list):
        """
        Group images by prompt.

        Args:
            metadata_list: list of (image_path, metadata) tuples

        Returns:
            dict: truncated prompt -> list of (image_path, model) tuples
        """
        result = {}

        for image_path, metadata in metadata_list:
            prompt = metadata.get('prompt', 'unknown')
            # Use first 50 chars as key to avoid extremely long keys
            prompt_key = prompt[:50] + ('...' if len(prompt) > 50 else '')
            if prompt_key not in result:
                result[prompt_key] = []
            result[prompt_key].append((image_path, metadata.get('model', 'unknown')))

        return result

    @staticmethod
    def update_metadata(image_path, new_metadata, output_path=None):
        """
        Update metadata in an image file.

        Args:
            image_path: path to the input image file
            new_metadata: new metadata text to write
            output_path: path to save the updated image (if None, overwrites input)

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            # Open image with PIL
            image = Image.open(image_path)

            # Rewrite the "parameters" text chunk with the new metadata.
            pnginfo = PngImagePlugin.PngInfo()
            pnginfo.add_text("parameters", new_metadata)

            # Save the image with the updated metadata
            save_path = output_path if output_path else image_path
            image.save(save_path, format="PNG", pnginfo=pnginfo)

            return True
        except Exception as e:
            print(f"Error updating metadata: {e}")
            return False
modules/technical_metrics.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Technical metrics for image quality assessment without using AI models.
3
+ These metrics evaluate basic technical aspects of images like sharpness, noise, etc.
4
+ """
5
+
6
+ import numpy as np
7
+ import cv2
8
+ from skimage.metrics import structural_similarity as ssim
9
+ from skimage.measure import shannon_entropy
10
+ from PIL import Image, ImageStat
11
+
12
+
13
class TechnicalMetrics:
    """Static helpers computing technical image-quality metrics.

    All metric methods accept a numpy image array — grayscale (H, W),
    RGB (H, W, 3) or RGBA (H, W, 4) — and return a scalar score.
    """

    @staticmethod
    def _to_gray(image_array):
        """Return a single-channel grayscale version of *image_array*.

        cv2.COLOR_RGB2GRAY only accepts 3-channel input, so an alpha
        channel is dropped first; 2-D arrays are returned unchanged.
        """
        if image_array.ndim == 3:
            if image_array.shape[2] == 4:
                image_array = image_array[:, :, :3]
            return cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
        return image_array

    @staticmethod
    def calculate_sharpness(image_array):
        """
        Calculate image sharpness using Laplacian variance.
        Higher values indicate sharper images.

        Args:
            image_array: numpy array of the image

        Returns:
            float: sharpness score
        """
        gray = TechnicalMetrics._to_gray(image_array)
        # Strong edges produce a high-variance Laplacian response.
        return cv2.Laplacian(gray, cv2.CV_64F).var()

    @staticmethod
    def calculate_noise(image_array):
        """
        Estimate image noise level.
        Lower values indicate less noisy images.

        Args:
            image_array: numpy array of the image

        Returns:
            float: noise level
        """
        gray = TechnicalMetrics._to_gray(image_array)
        # A median filter removes salt-and-pepper-like noise; the mean
        # absolute difference from the original approximates noise energy.
        denoised = cv2.medianBlur(gray, 5)
        diff = cv2.absdiff(gray, denoised)
        return np.mean(diff)

    @staticmethod
    def calculate_contrast(image_array):
        """
        Calculate image contrast.
        Higher values indicate higher contrast.

        Args:
            image_array: numpy array of the image

        Returns:
            float: contrast score (std of grayscale intensities)
        """
        gray = TechnicalMetrics._to_gray(image_array)
        return np.std(gray)

    @staticmethod
    def calculate_saturation(image_array):
        """
        Calculate color saturation.
        Higher values indicate more saturated colors.

        Args:
            image_array: numpy array of the image

        Returns:
            float: saturation score (mean HSV S channel); 0.0 for
            grayscale input
        """
        if image_array.ndim != 3:
            return 0.0  # Grayscale images have no saturation
        # Fix: drop a possible alpha channel — RGB2HSV needs 3 channels.
        rgb = image_array[:, :, :3]
        hsv = cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV)
        return np.mean(hsv[:, :, 1])

    @staticmethod
    def calculate_entropy(image_array):
        """
        Calculate image entropy as a measure of detail/complexity.
        Higher values indicate more complex images.

        Args:
            image_array: numpy array of the image

        Returns:
            float: Shannon entropy of the grayscale image
        """
        gray = TechnicalMetrics._to_gray(image_array)
        return shannon_entropy(gray)

    @staticmethod
    def detect_compression_artifacts(image_array):
        """
        Detect JPEG compression artifacts.
        Higher values indicate more artifacts.

        Args:
            image_array: numpy array of the image

        Returns:
            float: artifact score in [0, 1] (mean Canny edge density)
        """
        gray = TechnicalMetrics._to_gray(image_array)
        # NOTE(review): edge density is a rough proxy for blockiness — a
        # genuinely detailed image also scores high; confirm acceptable.
        edges = cv2.Canny(gray, 100, 200)
        return np.mean(edges) / 255.0

    @staticmethod
    def calculate_dynamic_range(image_array):
        """
        Calculate dynamic range of the image.
        Higher values indicate better use of the available intensity range.

        Args:
            image_array: numpy array of the image

        Returns:
            float: (p99 - p1) / 255, robust to isolated outlier pixels
        """
        gray = TechnicalMetrics._to_gray(image_array)
        p1 = np.percentile(gray, 1)
        p99 = np.percentile(gray, 99)
        return (p99 - p1) / 255.0

    @staticmethod
    def calculate_all_metrics(image_path):
        """
        Calculate all technical metrics for an image file.

        Args:
            image_path: path to the image file

        Returns:
            dict: metric name -> score, plus 'resolution' ("WxH" string),
            'aspect_ratio' and 'file_size_kb' entries
        """
        import os  # local import: only needed for the file size

        # Load once, normalize exotic modes (palette, CMYK, LA, ...) to RGB
        # so the cv2 color conversions above cannot fail, and close the
        # file handle promptly via the context manager.
        with Image.open(image_path) as pil_image:
            width, height = pil_image.size
            if pil_image.mode not in ('L', 'RGB'):
                pil_image = pil_image.convert('RGB')
            image_array = np.array(pil_image)

        try:
            # Fix: the previous pil_image.fp.tell() reported the current
            # stream offset (often 0, or crashed when fp was None after a
            # full load), not the file size.
            file_size_kb = os.path.getsize(image_path) / 1024
        except OSError:
            file_size_kb = 0

        return {
            'sharpness': TechnicalMetrics.calculate_sharpness(image_array),
            'noise': TechnicalMetrics.calculate_noise(image_array),
            'contrast': TechnicalMetrics.calculate_contrast(image_array),
            'saturation': TechnicalMetrics.calculate_saturation(image_array),
            'entropy': TechnicalMetrics.calculate_entropy(image_array),
            'compression_artifacts': TechnicalMetrics.detect_compression_artifacts(image_array),
            'dynamic_range': TechnicalMetrics.calculate_dynamic_range(image_array),
            'resolution': f"{width}x{height}",
            'aspect_ratio': width / height if height > 0 else 0,
            'file_size_kb': file_size_kb,
        }
modules/visualizer.py ADDED
@@ -0,0 +1,480 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Module for visualizing image evaluation results and creating comparison tables.
3
+ """
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import matplotlib.pyplot as plt
8
+ import seaborn as sns
9
+ from matplotlib.colors import LinearSegmentedColormap
10
+ import os
11
+ import io
12
+ from PIL import Image
13
+ import base64
14
+
15
+
16
+ class Visualizer:
17
+ """Class for visualizing image evaluation results."""
18
+
19
+ def __init__(self, output_dir='./results'):
20
+ """
21
+ Initialize visualizer with output directory.
22
+
23
+ Args:
24
+ output_dir: directory to save visualization results
25
+ """
26
+ self.output_dir = output_dir
27
+ os.makedirs(output_dir, exist_ok=True)
28
+
29
+ # Set up color schemes
30
+ self.setup_colors()
31
+
32
+ def setup_colors(self):
33
+ """Set up color schemes for visualizations."""
34
+ # Custom colormap for heatmaps
35
+ self.cmap = LinearSegmentedColormap.from_list(
36
+ 'custom_cmap', ['#FF5E5B', '#FFED66', '#00CEFF', '#0089BA', '#008F7A'], N=256
37
+ )
38
+
39
+ # Color palette for bar charts
40
+ self.palette = sns.color_palette("viridis", 10)
41
+
42
+ # Set Seaborn style
43
+ sns.set_style("whitegrid")
44
+
45
+ def create_comparison_table(self, results_dict, metrics_list=None):
46
+ """
47
+ Create a comparison table from evaluation results.
48
+
49
+ Args:
50
+ results_dict: dictionary with model names as keys and evaluation results as values
51
+ metrics_list: list of metrics to include in the table (if None, include all)
52
+
53
+ Returns:
54
+ pandas.DataFrame: comparison table
55
+ """
56
+ # Initialize empty dataframe
57
+ df = pd.DataFrame()
58
+
59
+ # Process each model's results
60
+ for model_name, model_results in results_dict.items():
61
+ # Create a row for this model
62
+ model_row = {'Model': model_name}
63
+
64
+ # Add metrics to the row
65
+ for metric_name, metric_value in model_results.items():
66
+ if metrics_list is None or metric_name in metrics_list:
67
+ # Format numeric values to 2 decimal places
68
+ if isinstance(metric_value, (int, float)):
69
+ model_row[metric_name] = round(metric_value, 2)
70
+ else:
71
+ model_row[metric_name] = metric_value
72
+
73
+ # Append to dataframe
74
+ df = pd.concat([df, pd.DataFrame([model_row])], ignore_index=True)
75
+
76
+ # Set Model as index
77
+ if not df.empty:
78
+ df.set_index('Model', inplace=True)
79
+
80
+ return df
81
+
82
+ def plot_metric_comparison(self, df, metric_name, title=None, figsize=(10, 6)):
83
+ """
84
+ Create a bar chart comparing models on a specific metric.
85
+
86
+ Args:
87
+ df: pandas DataFrame with comparison data
88
+ metric_name: name of the metric to plot
89
+ title: optional custom title
90
+ figsize: figure size as (width, height)
91
+
92
+ Returns:
93
+ str: path to saved figure
94
+ """
95
+ if metric_name not in df.columns:
96
+ raise ValueError(f"Metric '{metric_name}' not found in dataframe")
97
+
98
+ # Create figure
99
+ plt.figure(figsize=figsize)
100
+
101
+ # Create bar chart
102
+ ax = sns.barplot(x=df.index, y=df[metric_name], palette=self.palette)
103
+
104
+ # Set title and labels
105
+ if title:
106
+ plt.title(title, fontsize=14)
107
+ else:
108
+ plt.title(f"Model Comparison: {metric_name}", fontsize=14)
109
+
110
+ plt.xlabel("Model", fontsize=12)
111
+ plt.ylabel(metric_name, fontsize=12)
112
+
113
+ # Rotate x-axis labels for better readability
114
+ plt.xticks(rotation=45, ha='right')
115
+
116
+ # Add value labels on top of bars
117
+ for i, v in enumerate(df[metric_name]):
118
+ ax.text(i, v + 0.1, str(round(v, 2)), ha='center')
119
+
120
+ plt.tight_layout()
121
+
122
+ # Save figure
123
+ output_path = os.path.join(self.output_dir, f"{metric_name}_comparison.png")
124
+ plt.savefig(output_path, dpi=300, bbox_inches='tight')
125
+ plt.close()
126
+
127
+ return output_path
128
+
129
+ def plot_radar_chart(self, df, metrics_list, title=None, figsize=(10, 8)):
130
+ """
131
+ Create a radar chart comparing models across multiple metrics.
132
+
133
+ Args:
134
+ df: pandas DataFrame with comparison data
135
+ metrics_list: list of metrics to include in the radar chart
136
+ title: optional custom title
137
+ figsize: figure size as (width, height)
138
+
139
+ Returns:
140
+ str: path to saved figure
141
+ """
142
+ # Filter metrics that exist in the dataframe
143
+ metrics = [m for m in metrics_list if m in df.columns]
144
+
145
+ if not metrics:
146
+ raise ValueError("None of the specified metrics found in dataframe")
147
+
148
+ # Number of metrics
149
+ N = len(metrics)
150
+
151
+ # Create figure
152
+ fig = plt.figure(figsize=figsize)
153
+ ax = fig.add_subplot(111, polar=True)
154
+
155
+ # Compute angle for each metric
156
+ angles = [n / float(N) * 2 * np.pi for n in range(N)]
157
+ angles += angles[:1] # Close the loop
158
+
159
+ # Plot each model
160
+ for i, model in enumerate(df.index):
161
+ values = df.loc[model, metrics].values.flatten().tolist()
162
+ values += values[:1] # Close the loop
163
+
164
+ # Plot values
165
+ ax.plot(angles, values, linewidth=2, linestyle='solid', label=model, color=self.palette[i % len(self.palette)])
166
+ ax.fill(angles, values, alpha=0.1, color=self.palette[i % len(self.palette)])
167
+
168
+ # Set labels
169
+ plt.xticks(angles[:-1], metrics, size=12)
170
+
171
+ # Set y-axis limits
172
+ ax.set_ylim(0, 10)
173
+
174
+ # Add legend
175
+ plt.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1))
176
+
177
+ # Set title
178
+ if title:
179
+ plt.title(title, size=16, y=1.1)
180
+ else:
181
+ plt.title("Model Comparison Across Metrics", size=16, y=1.1)
182
+
183
+ # Save figure
184
+ output_path = os.path.join(self.output_dir, "radar_comparison.png")
185
+ plt.savefig(output_path, dpi=300, bbox_inches='tight')
186
+ plt.close()
187
+
188
+ return output_path
189
+
190
+ def plot_heatmap(self, df, title=None, figsize=(12, 8)):
191
+ """
192
+ Create a heatmap of all metrics across models.
193
+
194
+ Args:
195
+ df: pandas DataFrame with comparison data
196
+ title: optional custom title
197
+ figsize: figure size as (width, height)
198
+
199
+ Returns:
200
+ str: path to saved figure
201
+ """
202
+ # Create figure
203
+ plt.figure(figsize=figsize)
204
+
205
+ # Create heatmap
206
+ ax = sns.heatmap(df, annot=True, cmap=self.cmap, fmt=".2f", linewidths=.5)
207
+
208
+ # Set title
209
+ if title:
210
+ plt.title(title, fontsize=16)
211
+ else:
212
+ plt.title("Model Comparison Heatmap", fontsize=16)
213
+
214
+ plt.tight_layout()
215
+
216
+ # Save figure
217
+ output_path = os.path.join(self.output_dir, "comparison_heatmap.png")
218
+ plt.savefig(output_path, dpi=300, bbox_inches='tight')
219
+ plt.close()
220
+
221
+ return output_path
222
+
223
+ def plot_prompt_performance(self, prompt_results, metric_name, top_n=5, figsize=(12, 8)):
224
+ """
225
+ Create a grouped bar chart showing model performance on different prompts.
226
+
227
+ Args:
228
+ prompt_results: dictionary with prompts as keys and model results as values
229
+ metric_name: name of the metric to plot
230
+ top_n: number of top prompts to include
231
+ figsize: figure size as (width, height)
232
+
233
+ Returns:
234
+ str: path to saved figure
235
+ """
236
+ # Create dataframe from results
237
+ data = []
238
+ for prompt, models_data in prompt_results.items():
239
+ for model, metrics in models_data.items():
240
+ if metric_name in metrics:
241
+ data.append({
242
+ 'Prompt': prompt,
243
+ 'Model': model,
244
+ metric_name: metrics[metric_name]
245
+ })
246
+
247
+ df = pd.DataFrame(data)
248
+
249
+ if df.empty:
250
+ raise ValueError(f"No data found for metric '{metric_name}'")
251
+
252
+ # Get top N prompts by average metric value
253
+ top_prompts = df.groupby('Prompt')[metric_name].mean().nlargest(top_n).index.tolist()
254
+ df_filtered = df[df['Prompt'].isin(top_prompts)]
255
+
256
+ # Create figure
257
+ plt.figure(figsize=figsize)
258
+
259
+ # Create grouped bar chart
260
+ ax = sns.barplot(x='Prompt', y=metric_name, hue='Model', data=df_filtered, palette=self.palette)
261
+
262
+ # Set title and labels
263
+ plt.title(f"Model Performance by Prompt: {metric_name}", fontsize=14)
264
+ plt.xlabel("Prompt", fontsize=12)
265
+ plt.ylabel(metric_name, fontsize=12)
266
+
267
+ # Rotate x-axis labels for better readability
268
+ plt.xticks(rotation=45, ha='right')
269
+
270
+ # Adjust legend
271
+ plt.legend(title="Model", bbox_to_anchor=(1.05, 1), loc='upper left')
272
+
273
+ plt.tight_layout()
274
+
275
+ # Save figure
276
+ output_path = os.path.join(self.output_dir, f"prompt_performance_{metric_name}.png")
277
+ plt.savefig(output_path, dpi=300, bbox_inches='tight')
278
+ plt.close()
279
+
280
+ return output_path
281
+
282
+ def create_image_grid(self, image_paths, titles=None, cols=3, figsize=(15, 15)):
283
+ """
284
+ Create a grid of images for visual comparison.
285
+
286
+ Args:
287
+ image_paths: list of paths to images
288
+ titles: optional list of titles for each image
289
+ cols: number of columns in the grid
290
+ figsize: figure size as (width, height)
291
+
292
+ Returns:
293
+ str: path to saved figure
294
+ """
295
+ # Calculate number of rows needed
296
+ rows = (len(image_paths) + cols - 1) // cols
297
+
298
+ # Create figure
299
+ fig, axes = plt.subplots(rows, cols, figsize=figsize)
300
+ axes = axes.flatten()
301
+
302
+ # Add each image to the grid
303
+ for i, img_path in enumerate(image_paths):
304
+ if i < len(axes):
305
+ try:
306
+ img = Image.open(img_path)
307
+ axes[i].imshow(np.array(img))
308
+
309
+ # Add title if provided
310
+ if titles and i < len(titles):
311
+ axes[i].set_title(titles[i])
312
+
313
+ # Remove axis ticks
314
+ axes[i].set_xticks([])
315
+ axes[i].set_yticks([])
316
+ except Exception as e:
317
+ print(f"Error loading image {img_path}: {e}")
318
+ axes[i].text(0.5, 0.5, f"Error loading image", ha='center', va='center')
319
+ axes[i].set_xticks([])
320
+ axes[i].set_yticks([])
321
+
322
+ # Hide unused subplots
323
+ for j in range(len(image_paths), len(axes)):
324
+ axes[j].axis('off')
325
+
326
+ plt.tight_layout()
327
+
328
+ # Save figure
329
+ output_path = os.path.join(self.output_dir, "image_comparison_grid.png")
330
+ plt.savefig(output_path, dpi=300, bbox_inches='tight')
331
+ plt.close()
332
+
333
+ return output_path
334
+
335
+ def export_comparison_table(self, df, format='csv'):
336
+ """
337
+ Export comparison table to file.
338
+
339
+ Args:
340
+ df: pandas DataFrame with comparison data
341
+ format: export format ('csv', 'excel', or 'html')
342
+
343
+ Returns:
344
+ str: path to saved file
345
+ """
346
+ if format == 'csv':
347
+ output_path = os.path.join(self.output_dir, "comparison_table.csv")
348
+ df.to_csv(output_path)
349
+ elif format == 'excel':
350
+ output_path = os.path.join(self.output_dir, "comparison_table.xlsx")
351
+ df.to_excel(output_path)
352
+ elif format == 'html':
353
+ output_path = os.path.join(self.output_dir, "comparison_table.html")
354
+ df.to_html(output_path)
355
+ else:
356
+ raise ValueError(f"Unsupported format: {format}")
357
+
358
+ return output_path
359
+
360
+ def generate_html_report(self, comparison_table, image_paths, metrics_list):
361
+ """
362
+ Generate a comprehensive HTML report with all visualizations.
363
+
364
+ Args:
365
+ comparison_table: pandas DataFrame with comparison data
366
+ image_paths: dictionary of generated visualization image paths
367
+ metrics_list: list of metrics included in the analysis
368
+
369
+ Returns:
370
+ str: path to saved HTML report
371
+ """
372
+ # Create HTML content
373
+ html_content = f"""
374
+ <!DOCTYPE html>
375
+ <html>
376
+ <head>
377
+ <title>Image Model Evaluation Report</title>
378
+ <style>
379
+ body {{
380
+ font-family: Arial, sans-serif;
381
+ line-height: 1.6;
382
+ margin: 0;
383
+ padding: 20px;
384
+ color: #333;
385
+ }}
386
+ h1, h2, h3 {{
387
+ color: #2c3e50;
388
+ }}
389
+ .container {{
390
+ max-width: 1200px;
391
+ margin: 0 auto;
392
+ }}
393
+ table {{
394
+ border-collapse: collapse;
395
+ width: 100%;
396
+ margin-bottom: 20px;
397
+ }}
398
+ th, td {{
399
+ border: 1px solid #ddd;
400
+ padding: 8px;
401
+ text-align: left;
402
+ }}
403
+ th {{
404
+ background-color: #f2f2f2;
405
+ }}
406
+ tr:nth-child(even) {{
407
+ background-color: #f9f9f9;
408
+ }}
409
+ .visualization {{
410
+ margin: 20px 0;
411
+ text-align: center;
412
+ }}
413
+ .visualization img {{
414
+ max-width: 100%;
415
+ height: auto;
416
+ box-shadow: 0 4px 8px rgba(0,0,0,0.1);
417
+ }}
418
+ .metrics-list {{
419
+ background-color: #f8f9fa;
420
+ padding: 15px;
421
+ border-radius: 5px;
422
+ margin-bottom: 20px;
423
+ }}
424
+ </style>
425
+ </head>
426
+ <body>
427
+ <div class="container">
428
+ <h1>Image Model Evaluation Report</h1>
429
+
430
+ <h2>Metrics Overview</h2>
431
+ <div class="metrics-list">
432
+ <h3>Metrics included in this analysis:</h3>
433
+ <ul>
434
+ """
435
+
436
+ # Add metrics list
437
+ for metric in metrics_list:
438
+ html_content += f" <li><strong>{metric}</strong></li>\n"
439
+
440
+ html_content += """
441
+ </ul>
442
+ </div>
443
+
444
+ <h2>Comparison Table</h2>
445
+ """
446
+
447
+ # Add comparison table
448
+ html_content += comparison_table.to_html(classes="table table-striped")
449
+
450
+ # Add visualizations
451
+ html_content += """
452
+ <h2>Visualizations</h2>
453
+ """
454
+
455
+ for title, img_path in image_paths.items():
456
+ if os.path.exists(img_path):
457
+ # Convert image to base64 for embedding
458
+ with open(img_path, "rb") as img_file:
459
+ img_data = base64.b64encode(img_file.read()).decode('utf-8')
460
+
461
+ html_content += f"""
462
+ <div class="visualization">
463
+ <h3>{title}</h3>
464
+ <img src="data:image/png;base64,{img_data}" alt="{title}">
465
+ </div>
466
+ """
467
+
468
+ # Close HTML
469
+ html_content += """
470
+ </div>
471
+ </body>
472
+ </html>
473
+ """
474
+
475
+ # Save HTML report
476
+ output_path = os.path.join(self.output_dir, "evaluation_report.html")
477
+ with open(output_path, 'w', encoding='utf-8') as f:
478
+ f.write(html_content)
479
+
480
+ return output_path
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ pillow>=9.0.0
3
+ numpy>=1.20.0
4
+ pandas>=1.3.0
5
+ matplotlib>=3.5.0
6
+ seaborn>=0.11.0
7
+ scikit-image>=0.19.0
8
+ opencv-python>=4.5.0
9
+ torch>=2.0.0
10
+ torchvision>=0.15.0
11
+ transformers>=4.30.0
12
+ clip>=0.2.0
13
+ timm>=0.6.0
14
+ openpyxl>=3.0.0
utils/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """
2
+ Utility modules for the Image Evaluator tool.
3
+ """
utils/data_handling.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions for data handling and export.
3
+ """
4
+
5
+ import os
6
+ import json
7
+ import csv
8
+ import pandas as pd
9
+ from datetime import datetime
10
+
11
+
12
def save_json(data, file_path):
    """
    Serialize *data* to a JSON file (UTF-8, 2-space indent).

    Args:
        data: JSON-serializable data to save
        file_path: path to the output file

    Returns:
        bool: True if successful, False otherwise
    """
    try:
        with open(file_path, 'w', encoding='utf-8') as handle:
            json.dump(data, handle, indent=2, ensure_ascii=False)
    except Exception as e:
        print(f"Error saving JSON: {e}")
        return False
    return True
30
+
31
+
32
def load_json(file_path):
    """
    Load data from a JSON file (UTF-8).

    Args:
        file_path: path to the JSON file

    Returns:
        The parsed data, or None if the file is missing or malformed.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            return json.load(handle)
    except Exception as e:
        print(f"Error loading JSON: {e}")
        return None
48
+
49
+
50
def save_csv(data, file_path, headers=None):
    """
    Save tabular data to a CSV file (UTF-8).

    Args:
        data: list of dictionaries or list of lists/tuples (rows)
        file_path: path to the output file
        headers: optional list of column headers; for dict rows this also
            selects and orders the columns (defaults to the keys of the
            first row)

    Returns:
        bool: True if the file was written, False otherwise. Fix: empty
        or non-list input now returns False explicitly instead of falling
        through without writing anything.
    """
    if not isinstance(data, list) or not data:
        print("Error saving CSV: data must be a non-empty list")
        return False

    try:
        with open(file_path, 'w', newline='', encoding='utf-8') as f:
            if isinstance(data[0], dict):
                # List of dictionaries -> header row from keys (or caller).
                fieldnames = headers if headers is not None else list(data[0].keys())
                writer = csv.DictWriter(f, fieldnames=fieldnames)
                writer.writeheader()
                writer.writerows(data)
            else:
                # List of lists -> optional explicit header row.
                writer = csv.writer(f)
                if headers:
                    writer.writerow(headers)
                writer.writerows(data)
        return True
    except Exception as e:
        print(f"Error saving CSV: {e}")
        return False
84
+
85
+
86
def dataframe_to_formats(df, base_path, formats=None):
    """
    Export a pandas DataFrame to multiple file formats.

    Args:
        df: pandas DataFrame
        base_path: base path for output files (without extension)
        formats: list of formats ('csv', 'excel', 'html', 'json');
            defaults to ['csv', 'excel', 'html']

    Returns:
        dict: format name -> written file path. A format that fails to
        export is omitted. Fix: the single try previously wrapped the
        whole loop, so one failing format (e.g. 'excel' without openpyxl
        installed) silently aborted all remaining formats; each export is
        now attempted independently.
    """
    if formats is None:
        formats = ['csv', 'excel', 'html']

    exporters = {
        'csv': ('.csv', lambda path: df.to_csv(path)),
        'excel': ('.xlsx', lambda path: df.to_excel(path)),
        'html': ('.html', lambda path: df.to_html(path)),
        'json': ('.json', lambda path: df.to_json(path, orient='records', indent=2)),
    }

    result = {}
    for fmt in formats:
        if fmt not in exporters:
            continue  # unknown formats were silently skipped before too
        extension, export = exporters[fmt]
        file_path = f"{base_path}{extension}"
        try:
            export(file_path)
            result[fmt] = file_path
        except Exception as e:
            print(f"Error exporting DataFrame to {fmt}: {e}")

    return result
125
+
126
+
127
def generate_timestamp():
    """
    Generate a timestamp string for file naming.

    Returns:
        str: current local time formatted as YYYYMMDD_HHMMSS
    """
    now = datetime.now()
    return now.strftime("%Y%m%d_%H%M%S")
135
+
136
+
137
def create_results_filename(prefix="evaluation", extension=""):
    """
    Build a timestamped filename for results.

    Args:
        prefix: prefix for the filename
        extension: file extension (with or without a leading dot)

    Returns:
        str: "<prefix>_<YYYYMMDD_HHMMSS>" plus the normalized extension
    """
    # Timestamp inlined from generate_timestamp() — identical format.
    stem = f"{prefix}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    if not extension:
        return stem
    suffix = extension if extension.startswith('.') else f".{extension}"
    return f"{stem}{suffix}"
utils/image_processing.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions for image processing and data handling.
3
+ """
4
+
5
+ import os
6
+ import shutil
7
+ import tempfile
8
+ from PIL import Image
9
+ import numpy as np
10
+
11
+
12
def create_thumbnail(image_path, max_size=(200, 200)):
    """
    Create an in-memory thumbnail of an image.

    Args:
        image_path: path to the image file
        max_size: maximum (width, height) of the thumbnail

    Returns:
        PIL.Image: thumbnail image, or None if the image could not be read
    """
    try:
        thumb = Image.open(image_path)
        # thumbnail() resizes in place, preserving aspect ratio.
        thumb.thumbnail(max_size)
    except Exception as e:
        print(f"Error creating thumbnail for {image_path}: {e}")
        return None
    return thumb
30
+
31
+
32
def create_temp_directory():
    """
    Create a temporary working directory for intermediate files.

    Returns:
        str: path to the newly created directory (the caller is
        responsible for cleaning it up)
    """
    return tempfile.mkdtemp(prefix="image_evaluator_")
41
+
42
+
43
def cleanup_temp_directory(temp_dir):
    """
    Remove a temporary directory and all of its contents.

    A missing directory is silently ignored, so the call is safe to
    repeat.

    Args:
        temp_dir: path to the temporary directory
    """
    if not os.path.exists(temp_dir):
        return
    shutil.rmtree(temp_dir)
52
+
53
+
54
def ensure_directory(directory):
    """
    Create *directory* (including parents) if it does not already exist.

    Args:
        directory: path to the directory

    Returns:
        str: the same path, now guaranteed to exist
    """
    if not os.path.isdir(directory):
        os.makedirs(directory, exist_ok=True)
    return directory
66
+
67
+
68
def is_valid_image(file_path):
    """
    Check whether a file can be parsed as an image.

    Uses PIL's verify(), which checks file integrity without decoding
    the full pixel data.

    Args:
        file_path: path to the file

    Returns:
        bool: True if the file is a valid image, False otherwise
    """
    try:
        with Image.open(file_path) as img:
            img.verify()
        return True
    # Fix: a bare `except:` also swallowed KeyboardInterrupt/SystemExit;
    # any ordinary failure to open or verify means "not a valid image".
    except Exception:
        return False
84
+
85
+
86
def convert_to_rgb(image_path):
    """
    Load an image and return its pixel data as an RGB numpy array.

    Args:
        image_path: path to the image file

    Returns:
        numpy.ndarray: RGB image array, or None if loading/conversion failed
    """
    try:
        loaded = Image.open(image_path)
        rgb = loaded if loaded.mode == 'RGB' else loaded.convert('RGB')
        return np.array(rgb)
    except Exception as e:
        print(f"Error converting image to RGB: {e}")
        return None