Spaces:
Sleeping
Sleeping
Upload 11 files
Browse files- app.py +409 -0
- modules/__init__.py +3 -0
- modules/aesthetic_metrics.py +252 -0
- modules/aggregator.py +215 -0
- modules/metadata_extractor.py +168 -0
- modules/technical_metrics.py +189 -0
- modules/visualizer.py +480 -0
- requirements.txt +14 -0
- utils/__init__.py +3 -0
- utils/data_handling.py +155 -0
- utils/image_processing.py +103 -0
app.py
ADDED
|
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Main application file for the Image Evaluator tool.
|
| 3 |
+
This module integrates all components and provides a Gradio interface.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import gradio as gr
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import torch
|
| 11 |
+
import glob
|
| 12 |
+
from PIL import Image
|
| 13 |
+
import json
|
| 14 |
+
import tempfile
|
| 15 |
+
import shutil
|
| 16 |
+
from datetime import datetime
|
| 17 |
+
|
| 18 |
+
# Import custom modules
|
| 19 |
+
from modules.metadata_extractor import MetadataExtractor
|
| 20 |
+
from modules.technical_metrics import TechnicalMetrics
|
| 21 |
+
from modules.aesthetic_metrics import AestheticMetrics
|
| 22 |
+
from modules.aggregator import ResultsAggregator
|
| 23 |
+
from modules.visualizer import Visualizer
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ImageEvaluator:
    """Coordinates metadata extraction, metric computation, and reporting.

    Owns one instance of each pipeline component and caches per-image
    metadata plus the most recent model-comparison table.
    """

    def __init__(self):
        """Create component objects and the on-disk results directory."""
        self.results_dir = os.path.join(os.getcwd(), "results")
        os.makedirs(self.results_dir, exist_ok=True)

        # Pipeline components
        self.metadata_extractor = MetadataExtractor()
        self.technical_metrics = TechnicalMetrics()
        self.aesthetic_metrics = AestheticMetrics()
        self.aggregator = ResultsAggregator()
        self.visualizer = Visualizer(self.results_dir)

        # Result caches
        self.evaluation_results = {}
        self.metadata_cache = {}
        self.current_comparison = None

    def process_images(self, image_files, progress=None):
        """
        Process a list of image files and extract metadata.

        Args:
            image_files: list of image file paths
            progress: optional gradio Progress object

        Returns:
            tuple: (metadata_by_model, metadata_by_prompt)
        """
        metadata_list = []
        count = len(image_files)

        for idx, path in enumerate(image_files):
            if progress:
                progress(idx / count, f"Processing image {idx+1}/{count}")

            info = self.metadata_extractor.extract_metadata(path)
            metadata_list.append((path, info))
            # Cache so evaluate_images can reuse the prompt later.
            self.metadata_cache[path] = info

        by_model = self.metadata_extractor.group_images_by_model(metadata_list)
        by_prompt = self.metadata_extractor.group_images_by_prompt(metadata_list)
        return by_model, by_prompt

    def evaluate_images(self, image_files, progress=None):
        """
        Evaluate a list of image files using all metrics.

        Args:
            image_files: list of image file paths
            progress: optional gradio Progress object

        Returns:
            dict: evaluation results by image path
        """
        scores = {}
        count = len(image_files)

        for idx, path in enumerate(image_files):
            if progress:
                progress(idx / count, f"Evaluating image {idx+1}/{count}")

            # Prompt (when known) feeds the prompt-similarity metric.
            prompt = self.metadata_cache.get(path, {}).get('prompt', '')

            technical = self.technical_metrics.calculate_all_metrics(path)
            aesthetic = self.aesthetic_metrics.calculate_all_metrics(path, prompt)
            scores[path] = {**technical, **aesthetic}

        return scores

    def compare_models(self, evaluation_results, metadata_by_model):
        """
        Compare different models based on evaluation results.

        Args:
            evaluation_results: dictionary with image paths as keys and metrics as values
            metadata_by_model: dictionary with model names as keys and lists of image paths as values

        Returns:
            tuple: (comparison_df, visualizations)
        """
        # Regroup per-image metric dicts by the model that produced them.
        grouped = {
            model: [evaluation_results[p] for p in paths if p in evaluation_results]
            for model, paths in metadata_by_model.items()
        }

        comparison = self.aggregator.compare_models(grouped)
        comparison_df = self.aggregator.create_comparison_dataframe(comparison)
        self.current_comparison = comparison_df

        visualizations = {
            'Model Comparison Heatmap': self.visualizer.plot_heatmap(comparison_df),
        }

        # Radar chart over whichever key metrics are actually present.
        key_metrics = ['aesthetic_score', 'sharpness', 'noise', 'contrast',
                       'color_harmony', 'prompt_similarity']
        present = [m for m in key_metrics if m in comparison_df.columns]
        if present:
            visualizations['Model Comparison Radar Chart'] = (
                self.visualizer.plot_radar_chart(comparison_df, present)
            )

        # Bar charts for the headline metrics.
        for metric in ('overall_score', 'aesthetic_score', 'prompt_similarity'):
            if metric in comparison_df.columns:
                visualizations[f'{metric} Comparison'] = (
                    self.visualizer.plot_metric_comparison(comparison_df, metric)
                )

        return comparison_df, visualizations

    def export_results(self, format='csv'):
        """
        Export current comparison results.

        Args:
            format: export format ('csv', 'excel', or 'html')

        Returns:
            str: path to exported file, or None when there is nothing to export
        """
        if self.current_comparison is None:
            return None
        return self.visualizer.export_comparison_table(self.current_comparison, format)

    def generate_report(self, comparison_df, visualizations):
        """
        Generate a comprehensive HTML report.

        Args:
            comparison_df: pandas DataFrame with comparison data
            visualizations: dictionary of visualization paths

        Returns:
            str: path to HTML report
        """
        metrics = list(comparison_df.columns)
        return self.visualizer.generate_html_report(comparison_df, visualizations, metrics)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
# Create Gradio interface
|
| 191 |
+
def create_interface():
    """Build and return the Gradio Blocks UI wired to an ImageEvaluator."""

    evaluator = ImageEvaluator()

    # Mutable closure state shared by every event handler below.
    state = {
        'uploaded_images': [],
        'metadata_by_model': {},
        'metadata_by_prompt': {},
        'evaluation_results': {},
        'comparison_df': None,
        'visualizations': {},
        'report_path': None,
    }

    def upload_images(files, progress=gr.Progress()):
        """Handle image upload and processing."""
        # Every new batch starts from a clean slate.
        state.update(
            uploaded_images=[],
            metadata_by_model={},
            metadata_by_prompt={},
            evaluation_results={},
            comparison_df=None,
            visualizations={},
            report_path=None,
        )

        paths = [f.name for f in files]
        state['uploaded_images'] = paths

        progress(0, "Extracting metadata...")
        by_model, by_prompt = evaluator.process_images(paths, progress)
        state['metadata_by_model'] = by_model
        state['metadata_by_prompt'] = by_prompt

        model_lines = [f"- {model}: {len(imgs)} images"
                       for model, imgs in by_model.items()]
        # NOTE(review): this summary is built but only the prompt *count*
        # is reported below, matching the current UI text.
        prompt_lines = [f"- {prompt}: {len(imgs)} images"
                        for prompt, imgs in by_prompt.items()]

        return (
            f"Processed {len(paths)} images.\n\n"
            f"Found {len(by_model)} models:\n" + "\n".join(model_lines) + "\n\n"
            f"Found {len(by_prompt)} unique prompts."
        )

    def evaluate_images(progress=gr.Progress()):
        """Evaluate all uploaded images."""
        if not state['uploaded_images']:
            return "No images uploaded. Please upload images first."

        progress(0, "Evaluating images...")
        results = evaluator.evaluate_images(state['uploaded_images'], progress)
        state['evaluation_results'] = results
        return f"Evaluated {len(results)} images with all metrics."

    def compare_models():
        """Compare models based on evaluation results."""
        if not state['evaluation_results'] or not state['metadata_by_model']:
            return "No evaluation results available. Please evaluate images first.", None, None

        df, charts = evaluator.compare_models(
            state['evaluation_results'], state['metadata_by_model']
        )
        state['comparison_df'] = df
        state['visualizations'] = charts

        # Persist the HTML report so the Export tab can serve it later.
        state['report_path'] = evaluator.generate_report(df, charts)

        return (
            df.to_markdown(),
            charts.get('Model Comparison Heatmap'),
            charts.get('Model Comparison Radar Chart'),
        )

    def export_results(format):
        """Export results in the specified format."""
        if state['comparison_df'] is None:
            return "No comparison results available. Please compare models first."

        path = evaluator.export_results(format)
        return f"Results exported to {path}" if path else "Failed to export results."

    def view_report():
        """View the generated HTML report."""
        report = state['report_path']
        if report and os.path.exists(report):
            return report
        return "No report available. Please compare models first."

    # --- layout -----------------------------------------------------------
    with gr.Blocks(title="Image Model Evaluator") as interface:
        gr.Markdown("# Image Model Evaluator")
        gr.Markdown("Upload images generated by different AI models to compare their quality and performance.")

        with gr.Tab("Upload & Process"):
            with gr.Row():
                with gr.Column():
                    upload_input = gr.File(
                        label="Upload Images (PNG format)",
                        file_count="multiple",
                        type="file",
                    )
                    upload_button = gr.Button("Process Uploaded Images")

                with gr.Column():
                    upload_output = gr.Textbox(
                        label="Processing Results",
                        lines=10,
                        interactive=False,
                    )

            evaluate_button = gr.Button("Evaluate Images")
            evaluate_output = gr.Textbox(
                label="Evaluation Status",
                lines=2,
                interactive=False,
            )

        with gr.Tab("Compare Models"):
            compare_button = gr.Button("Compare Models")

            with gr.Row():
                comparison_output = gr.Markdown(label="Comparison Results")

            with gr.Row():
                with gr.Column():
                    heatmap_output = gr.Image(
                        label="Model Comparison Heatmap",
                        interactive=False,
                    )
                with gr.Column():
                    radar_output = gr.Image(
                        label="Model Comparison Radar Chart",
                        interactive=False,
                    )

        with gr.Tab("Export & Report"):
            with gr.Row():
                with gr.Column():
                    export_format = gr.Radio(
                        label="Export Format",
                        choices=["csv", "excel", "html"],
                        value="csv",
                    )
                    export_button = gr.Button("Export Results")
                    export_output = gr.Textbox(
                        label="Export Status",
                        lines=2,
                        interactive=False,
                    )

                with gr.Column():
                    report_button = gr.Button("View Full Report")
                    report_output = gr.HTML(label="Full Report")

        # --- event wiring -------------------------------------------------
        upload_button.click(upload_images, inputs=[upload_input], outputs=[upload_output])
        evaluate_button.click(evaluate_images, inputs=[], outputs=[evaluate_output])
        compare_button.click(
            compare_models,
            inputs=[],
            outputs=[comparison_output, heatmap_output, radar_output],
        )
        export_button.click(export_results, inputs=[export_format], outputs=[export_output])
        report_button.click(view_report, inputs=[], outputs=[report_output])

    return interface
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
# Launch the application
|
| 407 |
+
# Launch the application when run as a script.
if __name__ == "__main__":
    create_interface().launch(share=True)
|
modules/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Module initialization file for the Image Evaluator tool.
|
| 3 |
+
"""
|
modules/aesthetic_metrics.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Aesthetic metrics for image quality assessment using AI models.
|
| 3 |
+
These metrics evaluate subjective aspects of images like aesthetic appeal, composition, etc.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
|
| 10 |
+
import clip
|
| 11 |
+
from torchvision import transforms
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class AestheticMetrics:
    """Class for computing aesthetic image quality metrics using AI models.

    All public methods are best-effort: any failure (missing model,
    unreadable image) is caught and mapped to a neutral score of 5.0 so a
    single bad image cannot abort a batch evaluation.
    """

    def __init__(self):
        """Initialize models for aesthetic evaluation."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self._initialize_models()

    def _initialize_models(self):
        """Load CLIP and the aesthetic classifier, tolerating load failures.

        Sets ``clip_loaded`` / ``aesthetic_loaded`` flags that the metric
        methods check before using the corresponding model.
        """
        # CLIP model for text-image similarity
        try:
            self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=self.device)
            self.clip_loaded = True
        except Exception as e:
            print(f"Warning: Could not load CLIP model: {e}")
            self.clip_loaded = False

        # Aesthetic predictor classifier
        try:
            self.aesthetic_model_name = "cafeai/cafe_aesthetic"
            self.aesthetic_extractor = AutoFeatureExtractor.from_pretrained(self.aesthetic_model_name)
            self.aesthetic_model = AutoModelForImageClassification.from_pretrained(self.aesthetic_model_name)
            self.aesthetic_model.to(self.device)
            self.aesthetic_loaded = True
        except Exception as e:
            print(f"Warning: Could not load aesthetic model: {e}")
            self.aesthetic_loaded = False

        # Standard ImageNet-style preprocessing pipeline.
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def calculate_aesthetic_score(self, image_path):
        """
        Calculate aesthetic score using a pre-trained model.

        Args:
            image_path: path to the image file

        Returns:
            float: aesthetic score between 0 and 10
        """
        if not self.aesthetic_loaded:
            return 5.0  # Default middle score if model not loaded

        try:
            image = Image.open(image_path).convert('RGB')
            inputs = self.aesthetic_extractor(images=image, return_tensors="pt").to(self.device)

            with torch.no_grad():
                outputs = self.aesthetic_model(**inputs)

            # Class probabilities over the model's output classes.
            probs = torch.nn.functional.softmax(outputs.logits, dim=1)

            # Expected value of a linear 0-10 ramp over the classes.
            # FIX: the weight vector was hard-coded to length 10; any
            # classifier with a different class count raised a broadcast
            # error and silently fell through to the 5.0 fallback.
            # Sizing the ramp from the logits works for any class count.
            num_classes = probs.shape[1]
            if num_classes > 1:
                score_weights = torch.linspace(0.0, 10.0, num_classes, device=self.device)
            else:
                score_weights = torch.full((1,), 5.0, device=self.device)
            aesthetic_score = torch.sum(probs * score_weights).item()

            return aesthetic_score
        except Exception as e:
            print(f"Error calculating aesthetic score: {e}")
            return 5.0

    def calculate_composition_score(self, image_path):
        """
        Estimate composition quality using rule of thirds and symmetry analysis.

        Args:
            image_path: path to the image file

        Returns:
            float: composition score between 0 and 10
        """
        try:
            image = Image.open(image_path).convert('RGB')
            img_array = np.array(image)

            h, w = img_array.shape[:2]
            third_h, third_w = h // 3, w // 3

            # Intersections of the rule-of-thirds grid.
            thirds_points = [
                (third_w, third_h), (2 * third_w, third_h),
                (third_w, 2 * third_h), (2 * third_w, 2 * third_h),
            ]

            # Simple gradient-magnitude edge map (proxy for salient detail).
            gray = np.mean(img_array, axis=2).astype(np.uint8)
            edges = np.abs(np.diff(gray, axis=0, append=0)) + np.abs(np.diff(gray, axis=1, append=0))

            # Reward edge concentration in a 100px window around each point.
            thirds_score = 0
            for px, py in thirds_points:
                region = edges[max(0, py - 50):min(h, py + 50), max(0, px - 50):min(w, px + 50)]
                thirds_score += np.mean(region)

            # Normalize to the 0-10 range (empirical divisor).
            thirds_score = min(10, thirds_score / 100)

            # Horizontal-mirror symmetry: smaller pixel difference -> higher score.
            flipped = np.fliplr(img_array)
            symmetry_diff = np.mean(np.abs(img_array.astype(float) - flipped.astype(float)))
            symmetry_score = 10 * (1 - symmetry_diff / 255)

            # Weighted blend, clamped to [0, 10].
            composition_score = 0.7 * thirds_score + 0.3 * symmetry_score
            return min(10, max(0, composition_score))
        except Exception as e:
            print(f"Error calculating composition score: {e}")
            return 5.0

    def calculate_color_harmony(self, image_path):
        """
        Calculate color harmony score based on color theory.

        Args:
            image_path: path to the image file

        Returns:
            float: color harmony score between 0 and 10
        """
        try:
            image = Image.open(image_path).convert('RGB')
            img_array = np.array(image)

            # HSV separates hue from intensity, which suits harmony analysis.
            hsv = np.array(image.convert('HSV'))

            # 36-bin hue histogram (PIL hue spans 0-255), normalized.
            hue = hsv[:, :, 0].flatten()
            hist, _ = np.histogram(hue, bins=36, range=(0, 255))
            hist = hist / np.sum(hist)

            # Entropy of the hue distribution (lower = more cohesive palette).
            entropy = -np.sum(hist * np.log2(hist + 1e-10))

            # Complementary pairs sit 180 degrees (18 bins) apart.
            complementary_score = 0
            for i in range(18):
                complementary_score += min(hist[i], hist[(i + 18) % 36])

            # Analogous colors occupy adjacent hue bins.
            analogous_score = 0
            for i in range(36):
                left, right = (i + 1) % 36, (i + 35) % 36
                analogous_score += min(hist[i], max(hist[left], hist[right]))

            # Saturation variance as a measure of color interest.
            saturation = hsv[:, :, 1].flatten()
            saturation_variance = np.var(saturation)

            # Weighted blend of the four signals (empirical weights).
            harmony_score = (
                3 * (1 - min(1, entropy / 5)) +          # lower entropy is better for harmony
                3 * complementary_score +                 # complementary colors
                2 * analogous_score +                     # analogous colors
                2 * min(1, saturation_variance / 2000)    # saturation variance
            )

            return min(10, max(0, harmony_score))
        except Exception as e:
            print(f"Error calculating color harmony: {e}")
            return 5.0

    def calculate_prompt_similarity(self, image_path, prompt):
        """
        Calculate similarity between image and text prompt using CLIP.

        Args:
            image_path: path to the image file
            prompt: text prompt used to generate the image

        Returns:
            float: similarity score between 0 and 10
        """
        if not self.clip_loaded or not prompt:
            return 5.0  # Default middle score if model not loaded or no prompt

        try:
            image = Image.open(image_path).convert('RGB')
            image_input = self.clip_preprocess(image).unsqueeze(0).to(self.device)
            text_input = clip.tokenize([prompt]).to(self.device)

            with torch.no_grad():
                image_features = self.clip_model.encode_image(image_input)
                text_features = self.clip_model.encode_text(text_input)

            # Cosine similarity via normalized dot product, scaled by 100.
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            similarity = (100.0 * image_features @ text_features.T).item()

            # Map onto the 0-10 scale used by all other metrics.
            return min(10, max(0, similarity / 10))
        except Exception as e:
            print(f"Error calculating prompt similarity: {e}")
            return 5.0

    def calculate_all_metrics(self, image_path, prompt=None):
        """
        Calculate all aesthetic metrics for an image.

        Args:
            image_path: path to the image file
            prompt: optional text prompt used to generate the image

        Returns:
            dict: dictionary with all metric scores
        """
        metrics = {
            'aesthetic_score': self.calculate_aesthetic_score(image_path),
            'composition_score': self.calculate_composition_score(image_path),
            'color_harmony': self.calculate_color_harmony(image_path),
        }

        # Prompt similarity only makes sense when a prompt is available.
        if prompt:
            metrics['prompt_similarity'] = self.calculate_prompt_similarity(image_path, prompt)

        return metrics
|
modules/aggregator.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Module for aggregating results from different evaluation metrics.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import numpy as np
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class ResultsAggregator:
    """Aggregate and analyze image-evaluation results across images, models and prompts."""

    def __init__(self):
        """Initialize default metric weights and the lower-is-better metric list."""
        # Relative importance of each metric when computing the weighted
        # overall score; prompt similarity dominates, aesthetics rank above
        # purely technical metrics.
        self.default_weights = {
            # Technical metrics
            'sharpness': 1.0,
            'noise': 1.0,
            'contrast': 1.0,
            'saturation': 1.0,
            'entropy': 1.0,
            'compression_artifacts': 1.0,
            'dynamic_range': 1.0,

            # Aesthetic metrics
            'aesthetic_score': 1.5,
            'composition_score': 1.2,
            'color_harmony': 1.2,

            # Prompt metrics
            'prompt_similarity': 2.0,
        }

        # Metrics where a lower raw value indicates better quality.
        self.inverse_metrics = ['noise', 'compression_artifacts']

    def normalize_metric(self, values, metric_name):
        """
        Normalize metric values to a 0-10 scale via min-max over the given list.

        Args:
            values: list of metric values
            metric_name: name of the metric (used to detect lower-is-better metrics)

        Returns:
            list: normalized values in [0, 10]
        """
        if not values:
            return []

        # For lower-is-better metrics, reflect the values so the best (lowest)
        # raw value becomes the highest normalized one.
        if metric_name in self.inverse_metrics:
            values = [max(values) - v + min(values) for v in values]

        min_val = min(values)
        max_val = max(values)

        if max_val == min_val:
            # No spread to normalize over: report the neutral midpoint.
            return [5.0] * len(values)

        return [10 * (v - min_val) / (max_val - min_val) for v in values]

    def aggregate_model_results(self, model_results, custom_weights=None):
        """
        Aggregate per-image metric dictionaries for a single model.

        Args:
            model_results: list of metric dictionaries, one per image
            custom_weights: optional dict overriding self.default_weights

        Returns:
            dict: per-metric stats ('mean', 'median', 'std', 'min', 'max',
                  'count') plus a weighted 'overall_score' float.
        """
        if not model_results:
            return {}

        weights = custom_weights if custom_weights else self.default_weights
        aggregated = {}

        # Union of all metric names seen across the images.
        all_metrics = set()
        for result in model_results:
            all_metrics.update(result.keys())

        for metric in all_metrics:
            # Keep numeric values only. bool is excluded explicitly because it
            # is a subclass of int and would otherwise be averaged as 0/1.
            values = [result[metric] for result in model_results
                      if metric in result
                      and isinstance(result[metric], (int, float))
                      and not isinstance(result[metric], bool)]

            if values:
                aggregated[metric] = {
                    'mean': np.mean(values),
                    'median': np.median(values),
                    'std': np.std(values),
                    'min': np.min(values),
                    'max': np.max(values),
                    'count': len(values),
                }

        # Weighted overall score across the metrics that carry a weight.
        score_components = []
        weight_sum = 0

        for metric, stats in aggregated.items():
            if metric in weights:
                # NOTE(review): the raw mean is used directly as a 0-10 value;
                # upstream metrics (e.g. sharpness) may fall outside that range
                # — confirm upstream scaling if exact 0-10 semantics matter.
                normalized_value = stats['mean']
                if metric in self.inverse_metrics:
                    # Lower-is-better: flip around the 0-10 scale.
                    normalized_value = 10 - normalized_value

                # Fix: clamp the contribution to [0, 10] so inverse metrics
                # with raw means above 10 cannot drive the overall score
                # negative (and out-of-range metrics cannot inflate it).
                normalized_value = max(0.0, min(10.0, normalized_value))

                weight = weights[metric]
                score_components.append(normalized_value * weight)
                weight_sum += weight

        if weight_sum > 0:
            aggregated['overall_score'] = sum(score_components) / weight_sum
        else:
            aggregated['overall_score'] = 5.0  # Neutral default: nothing weighted

        return aggregated

    def compare_models(self, model_results_dict, custom_weights=None):
        """
        Compare results across different models.

        Args:
            model_results_dict: dict mapping model name -> list of per-image results
            custom_weights: optional dict of custom weights for metrics

        Returns:
            dict: per-model {'overall_score': float, <metric>: mean, ...}
        """
        # Aggregate results for each model independently.
        aggregated_results = {}
        for model_name, results in model_results_dict.items():
            aggregated_results[model_name] = self.aggregate_model_results(results, custom_weights)

        # Flatten to mean values per metric for easy tabulation.
        comparison = {}
        for model_name, agg_results in aggregated_results.items():
            model_comparison = {
                'overall_score': agg_results.get('overall_score', 5.0)
            }

            for metric, stats in agg_results.items():
                if metric != 'overall_score' and isinstance(stats, dict) and 'mean' in stats:
                    model_comparison[f"{metric}"] = stats['mean']

            comparison[model_name] = model_comparison

        return comparison

    def analyze_by_prompt(self, results_by_prompt, custom_weights=None):
        """
        Analyze results grouped by prompt.

        Args:
            results_by_prompt: dict mapping prompt -> {model name: list of results}
            custom_weights: optional dict of custom weights for metrics

        Returns:
            dict: per-prompt {'model_comparison', 'best_model', 'best_score'}
        """
        prompt_analysis = {}

        for prompt, model_results in results_by_prompt.items():
            prompt_comparison = self.compare_models(model_results, custom_weights)

            # Pick the model with the highest overall score for this prompt.
            best_model = None
            best_score = -1

            for model, metrics in prompt_comparison.items():
                score = metrics.get('overall_score', 0)
                if score > best_score:
                    best_score = score
                    best_model = model

            prompt_analysis[prompt] = {
                'model_comparison': prompt_comparison,
                'best_model': best_model,
                'best_score': best_score
            }

        return prompt_analysis

    def create_comparison_dataframe(self, comparison_results):
        """
        Create a pandas DataFrame from comparison results.

        Args:
            comparison_results: dict mapping model name -> metric dict

        Returns:
            pandas.DataFrame: comparison table, best model first
        """
        df = pd.DataFrame.from_dict(comparison_results, orient='index')

        # Rank models by overall score when available.
        if 'overall_score' in df.columns:
            df = df.sort_values('overall_score', ascending=False)

        return df
|
modules/metadata_extractor.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Module for extracting metadata from image files, particularly focusing on
|
| 3 |
+
Stable Diffusion metadata from PNG files.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import io
|
| 7 |
+
from PIL import Image, PngImagePlugin
|
| 8 |
+
import re
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class MetadataExtractor:
    """Extract and parse Stable Diffusion generation metadata from image files."""

    @staticmethod
    def extract_metadata(image_path):
        """
        Extract metadata from an image file.

        Reads the "parameters" PNG text chunk (the field AUTOMATIC1111-style
        tools write generation settings into), parses it, and adds basic
        image properties.

        Args:
            image_path: path to the image file

        Returns:
            dict: parsed metadata plus width/height/format/mode, or
                  {'error': str} if the file could not be read.
        """
        try:
            image = Image.open(image_path)

            # Generation settings live in the "parameters" info key; missing
            # for images not produced by an SD web UI.
            metadata_text = image.info.get("parameters", "")

            parsed_metadata = MetadataExtractor.parse_metadata(metadata_text)

            # Basic image properties are always available.
            parsed_metadata.update({
                'width': image.width,
                'height': image.height,
                'format': image.format,
                'mode': image.mode,
            })

            return parsed_metadata
        except Exception as e:
            # Best-effort: report the failure instead of raising so a single
            # unreadable file does not abort a batch run.
            print(f"Error extracting metadata from {image_path}: {e}")
            return {'error': str(e)}

    @staticmethod
    def parse_metadata(metadata_text):
        """
        Parse Stable Diffusion metadata text into structured data.

        Args:
            metadata_text: raw metadata text from the image

        Returns:
            dict: structured metadata; always contains 'raw_text', plus
                  'prompt', 'negative_prompt' and generation parameters
                  when present in the text.
        """
        if not metadata_text:
            return {'raw_text': ''}

        result = {'raw_text': metadata_text}

        # The prompt is everything before "Negative prompt:"; the negative
        # prompt runs from there to the next newline.
        # NOTE(review): a prompt_end of 0 (text starting with "Negative
        # prompt:") is treated as no positive prompt at all — confirm this is
        # the intended behavior for prompt-less metadata.
        prompt_end = metadata_text.find("Negative prompt:")
        if prompt_end > 0:
            result['prompt'] = metadata_text[:prompt_end].strip()
            neg_start = prompt_end + len("Negative prompt:")
            negative_prompt_end = metadata_text.find("\n", prompt_end)
            # Fix: when the metadata ends right after the negative prompt
            # (no trailing newline), find() returns -1 and the negative
            # prompt used to be silently dropped; fall back to end of text.
            if negative_prompt_end < 0:
                negative_prompt_end = len(metadata_text)
            result['negative_prompt'] = metadata_text[neg_start:negative_prompt_end].strip()
        else:
            result['prompt'] = metadata_text.strip()

        # Extract model name ("Model hash:" does not match because it is not
        # followed by ": " directly after "Model").
        model_match = re.search(r'Model: ([^,\n]+)', metadata_text)
        if model_match:
            result['model'] = model_match.group(1).strip()

        # Remaining generation parameters, captured as strings.
        params = {
            'steps': r'Steps: (\d+)',
            'sampler': r'Sampler: ([^,\n]+)',
            'cfg_scale': r'CFG scale: ([^,\n]+)',
            'seed': r'Seed: ([^,\n]+)',
            'size': r'Size: ([^,\n]+)',
            'model_hash': r'Model hash: ([^,\n]+)',
        }

        for key, pattern in params.items():
            match = re.search(pattern, metadata_text)
            if match:
                result[key] = match.group(1).strip()

        return result

    @staticmethod
    def group_images_by_model(metadata_list):
        """
        Group images by model name.

        Args:
            metadata_list: list of (image_path, metadata) tuples

        Returns:
            dict: model name -> list of image paths ('unknown' when absent)
        """
        result = {}

        for image_path, metadata in metadata_list:
            model = metadata.get('model', 'unknown')
            result.setdefault(model, []).append(image_path)

        return result

    @staticmethod
    def group_images_by_prompt(metadata_list):
        """
        Group images by prompt.

        Args:
            metadata_list: list of (image_path, metadata) tuples

        Returns:
            dict: truncated prompt -> list of (image_path, model) tuples
        """
        result = {}

        for image_path, metadata in metadata_list:
            prompt = metadata.get('prompt', 'unknown')
            # Truncate to the first 50 characters to avoid unwieldy keys;
            # an ellipsis marks truncation.
            prompt_key = prompt[:50] + ('...' if len(prompt) > 50 else '')
            result.setdefault(prompt_key, []).append((image_path, metadata.get('model', 'unknown')))

        return result

    @staticmethod
    def update_metadata(image_path, new_metadata, output_path=None):
        """
        Update the "parameters" metadata in an image file.

        Args:
            image_path: path to the input image file
            new_metadata: new metadata text to write
            output_path: path to save the updated image (if None, overwrites input)

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            image = Image.open(image_path)

            # Rebuild the PNG text chunk with the new parameters string.
            pnginfo = PngImagePlugin.PngInfo()
            pnginfo.add_text("parameters", new_metadata)

            # Always saved as PNG so the text chunk survives, regardless of
            # the input format.
            save_path = output_path if output_path else image_path
            image.save(save_path, format="PNG", pnginfo=pnginfo)

            return True
        except Exception as e:
            print(f"Error updating metadata: {e}")
            return False
|
modules/technical_metrics.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Technical metrics for image quality assessment without using AI models.
|
| 3 |
+
These metrics evaluate basic technical aspects of images like sharpness, noise, etc.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import cv2
|
| 8 |
+
from skimage.metrics import structural_similarity as ssim
|
| 9 |
+
from skimage.measure import shannon_entropy
|
| 10 |
+
from PIL import Image, ImageStat
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class TechnicalMetrics:
    """Compute technical (non-AI) image quality metrics: sharpness, noise, etc."""

    @staticmethod
    def calculate_sharpness(image_array):
        """
        Calculate image sharpness using Laplacian variance.
        Higher values indicate sharper images.

        Args:
            image_array: numpy array of the image (grayscale or 3-channel RGB)

        Returns:
            float: sharpness score (variance of the Laplacian response)
        """
        if len(image_array.shape) == 3:
            gray = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
        else:
            gray = image_array

        # Variance of the Laplacian: blur flattens edges, lowering variance.
        return cv2.Laplacian(gray, cv2.CV_64F).var()

    @staticmethod
    def calculate_noise(image_array):
        """
        Estimate image noise level.
        Lower values indicate less noisy images.

        Args:
            image_array: numpy array of the image (grayscale or 3-channel RGB)

        Returns:
            float: mean absolute deviation from a median-filtered copy
        """
        if len(image_array.shape) == 3:
            gray = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
        else:
            gray = image_array

        # Median filtering removes impulse noise; the residual approximates
        # the noise component.
        # NOTE(review): cv2.medianBlur with ksize=5 requires 8-bit input —
        # confirm callers never pass 16-bit arrays.
        denoised = cv2.medianBlur(gray, 5)
        diff = cv2.absdiff(gray, denoised)
        return np.mean(diff)

    @staticmethod
    def calculate_contrast(image_array):
        """
        Calculate image contrast.
        Higher values indicate higher contrast.

        Args:
            image_array: numpy array of the image (grayscale or 3-channel RGB)

        Returns:
            float: standard deviation of grayscale intensities
        """
        if len(image_array.shape) == 3:
            gray = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
        else:
            gray = image_array

        # Intensity spread serves as a simple global contrast measure.
        return np.std(gray)

    @staticmethod
    def calculate_saturation(image_array):
        """
        Calculate color saturation.
        Higher values indicate more saturated colors.

        Args:
            image_array: numpy array of the image

        Returns:
            float: mean of the HSV saturation channel (0.0 for grayscale input)
        """
        if len(image_array.shape) != 3:
            return 0.0  # Grayscale images carry no color saturation

        hsv = cv2.cvtColor(image_array, cv2.COLOR_RGB2HSV)
        return np.mean(hsv[:, :, 1])

    @staticmethod
    def calculate_entropy(image_array):
        """
        Calculate image entropy as a measure of detail/complexity.
        Higher values indicate more complex images.

        Args:
            image_array: numpy array of the image (grayscale or 3-channel RGB)

        Returns:
            float: Shannon entropy of the grayscale image
        """
        if len(image_array.shape) == 3:
            gray = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
        else:
            gray = image_array

        return shannon_entropy(gray)

    @staticmethod
    def detect_compression_artifacts(image_array):
        """
        Detect JPEG compression artifacts.
        Higher values indicate more artifacts.

        Args:
            image_array: numpy array of the image (grayscale or 3-channel RGB)

        Returns:
            float: edge density in [0, 1]
        """
        if len(image_array.shape) == 3:
            gray = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
        else:
            gray = image_array

        # NOTE(review): Canny edge density measures overall edge content, not
        # specifically blocky compression artifacts — highly detailed clean
        # images also score high; treat this as a rough proxy.
        edges = cv2.Canny(gray, 100, 200)
        return np.mean(edges) / 255.0

    @staticmethod
    def calculate_dynamic_range(image_array):
        """
        Calculate dynamic range of the image.
        Higher values indicate better use of the available intensity range.

        Args:
            image_array: numpy array of the image (grayscale or 3-channel RGB)

        Returns:
            float: robust (1st-99th percentile) intensity spread scaled by 255
        """
        if len(image_array.shape) == 3:
            gray = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
        else:
            gray = image_array

        # Percentiles ignore the extreme 1% at each end so a few hot/dead
        # pixels do not dominate the measurement.
        p1 = np.percentile(gray, 1)
        p99 = np.percentile(gray, 99)
        return (p99 - p1) / 255.0

    @staticmethod
    def calculate_all_metrics(image_path):
        """
        Calculate all technical metrics for an image.

        Args:
            image_path: path to the image file

        Returns:
            dict: all metric scores plus resolution, aspect_ratio and file_size_kb
        """
        import os  # local import: os is not part of this module's top-level imports

        pil_image = Image.open(image_path)

        # Fix: the cv2 conversions above expect 1- or 3-channel input;
        # palette ('P') and RGBA images would raise. Normalize to RGB first.
        if pil_image.mode not in ("RGB", "L"):
            pil_image = pil_image.convert("RGB")

        image_array = np.array(pil_image)

        # Fix: the previous implementation used pil_image.fp.tell(), which
        # reports the current read offset (and fp may already be closed),
        # not the file size. Query the filesystem instead.
        try:
            file_size_kb = os.path.getsize(image_path) / 1024
        except OSError:
            file_size_kb = 0

        metrics = {
            'sharpness': TechnicalMetrics.calculate_sharpness(image_array),
            'noise': TechnicalMetrics.calculate_noise(image_array),
            'contrast': TechnicalMetrics.calculate_contrast(image_array),
            'saturation': TechnicalMetrics.calculate_saturation(image_array),
            'entropy': TechnicalMetrics.calculate_entropy(image_array),
            'compression_artifacts': TechnicalMetrics.detect_compression_artifacts(image_array),
            'dynamic_range': TechnicalMetrics.calculate_dynamic_range(image_array),
            'resolution': f"{pil_image.width}x{pil_image.height}",
            'aspect_ratio': pil_image.width / pil_image.height if pil_image.height > 0 else 0,
            'file_size_kb': file_size_kb,
        }

        return metrics
|
modules/visualizer.py
ADDED
|
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Module for visualizing image evaluation results and creating comparison tables.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import seaborn as sns
|
| 9 |
+
from matplotlib.colors import LinearSegmentedColormap
|
| 10 |
+
import os
|
| 11 |
+
import io
|
| 12 |
+
from PIL import Image
|
| 13 |
+
import base64
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Visualizer:
|
| 17 |
+
"""Class for visualizing image evaluation results."""
|
| 18 |
+
|
| 19 |
+
def __init__(self, output_dir='./results'):
    """
    Create a visualizer that writes its figures into *output_dir*.

    Args:
        output_dir: directory where visualization images are saved
                    (created if it does not exist yet)
    """
    # Remember the target directory and make sure it exists on disk.
    self.output_dir = output_dir
    os.makedirs(output_dir, exist_ok=True)

    # Prepare the shared colormap, palette and plotting style.
    self.setup_colors()
|
| 31 |
+
|
| 32 |
+
def setup_colors(self):
    """Configure the colormap, bar-chart palette and Seaborn style shared by all plots."""
    # Global plot style first, so subsequent figures pick it up.
    sns.set_style("whitegrid")

    # Discrete palette used for bar charts and per-model line colors.
    self.palette = sns.color_palette("viridis", 10)

    # Smooth red->yellow->cyan->blue->teal gradient used by heatmaps.
    self.cmap = LinearSegmentedColormap.from_list(
        'custom_cmap', ['#FF5E5B', '#FFED66', '#00CEFF', '#0089BA', '#008F7A'], N=256
    )
|
| 44 |
+
|
| 45 |
+
def create_comparison_table(self, results_dict, metrics_list=None):
|
| 46 |
+
"""
|
| 47 |
+
Create a comparison table from evaluation results.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
results_dict: dictionary with model names as keys and evaluation results as values
|
| 51 |
+
metrics_list: list of metrics to include in the table (if None, include all)
|
| 52 |
+
|
| 53 |
+
Returns:
|
| 54 |
+
pandas.DataFrame: comparison table
|
| 55 |
+
"""
|
| 56 |
+
# Initialize empty dataframe
|
| 57 |
+
df = pd.DataFrame()
|
| 58 |
+
|
| 59 |
+
# Process each model's results
|
| 60 |
+
for model_name, model_results in results_dict.items():
|
| 61 |
+
# Create a row for this model
|
| 62 |
+
model_row = {'Model': model_name}
|
| 63 |
+
|
| 64 |
+
# Add metrics to the row
|
| 65 |
+
for metric_name, metric_value in model_results.items():
|
| 66 |
+
if metrics_list is None or metric_name in metrics_list:
|
| 67 |
+
# Format numeric values to 2 decimal places
|
| 68 |
+
if isinstance(metric_value, (int, float)):
|
| 69 |
+
model_row[metric_name] = round(metric_value, 2)
|
| 70 |
+
else:
|
| 71 |
+
model_row[metric_name] = metric_value
|
| 72 |
+
|
| 73 |
+
# Append to dataframe
|
| 74 |
+
df = pd.concat([df, pd.DataFrame([model_row])], ignore_index=True)
|
| 75 |
+
|
| 76 |
+
# Set Model as index
|
| 77 |
+
if not df.empty:
|
| 78 |
+
df.set_index('Model', inplace=True)
|
| 79 |
+
|
| 80 |
+
return df
|
| 81 |
+
|
| 82 |
+
def plot_metric_comparison(self, df, metric_name, title=None, figsize=(10, 6)):
    """
    Draw a bar chart comparing models on one metric and save it as a PNG.

    Args:
        df: pandas DataFrame with models as the index
        metric_name: column of *df* to plot
        title: optional custom title (defaults to "Model Comparison: <metric>")
        figsize: figure size as (width, height)

    Returns:
        str: path of the saved figure

    Raises:
        ValueError: if *metric_name* is not a column of *df*
    """
    # Guard: the metric must exist before any figure is created.
    if metric_name not in df.columns:
        raise ValueError(f"Metric '{metric_name}' not found in dataframe")

    plt.figure(figsize=figsize)
    axis = sns.barplot(x=df.index, y=df[metric_name], palette=self.palette)

    plt.title(title if title else f"Model Comparison: {metric_name}", fontsize=14)
    plt.xlabel("Model", fontsize=12)
    plt.ylabel(metric_name, fontsize=12)

    # Slanted model names stay readable even when they are long.
    plt.xticks(rotation=45, ha='right')

    # Annotate each bar with its rounded value just above the top.
    for position, value in enumerate(df[metric_name]):
        axis.text(position, value + 0.1, str(round(value, 2)), ha='center')

    plt.tight_layout()

    output_path = os.path.join(self.output_dir, f"{metric_name}_comparison.png")
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()

    return output_path
|
| 128 |
+
|
| 129 |
+
def plot_radar_chart(self, df, metrics_list, title=None, figsize=(10, 8)):
    """
    Create a radar (spider) chart comparing models across multiple metrics.

    Each model becomes one closed polygon; each requested metric becomes one
    axis of the polar plot. The figure is saved to
    <output_dir>/radar_comparison.png.

    Args:
        df: pandas DataFrame with models as the index and metrics as columns
        metrics_list: list of metric names to include (silently drops names
                      not present as columns of *df*)
        title: optional custom title
        figsize: figure size as (width, height)

    Returns:
        str: path to the saved figure

    Raises:
        ValueError: if none of the requested metrics exist in *df*
    """
    # Keep only metrics that actually exist in the dataframe.
    metrics = [m for m in metrics_list if m in df.columns]

    if not metrics:
        raise ValueError("None of the specified metrics found in dataframe")

    # Number of axes on the radar.
    N = len(metrics)

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, polar=True)

    # One evenly spaced angle per metric around the circle.
    angles = [n / float(N) * 2 * np.pi for n in range(N)]
    angles += angles[:1]  # Repeat the first angle so the polygon closes

    # Plot one closed polygon per model, cycling through the palette.
    for i, model in enumerate(df.index):
        values = df.loc[model, metrics].values.flatten().tolist()
        values += values[:1]  # Close the loop to match the repeated angle

        ax.plot(angles, values, linewidth=2, linestyle='solid', label=model, color=self.palette[i % len(self.palette)])
        ax.fill(angles, values, alpha=0.1, color=self.palette[i % len(self.palette)])

    # Label each axis with its metric name (drop the duplicated closing angle).
    plt.xticks(angles[:-1], metrics, size=12)

    # NOTE(review): radial axis is fixed to 0-10 — assumes metric values are
    # pre-normalized to that scale; values outside it are clipped visually.
    ax.set_ylim(0, 10)

    # NOTE(review): bbox_to_anchor=(0.1, 0.1) with loc='upper right' places
    # the legend near the lower-left of the figure — confirm this placement
    # is intentional.
    plt.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1))

    if title:
        plt.title(title, size=16, y=1.1)
    else:
        plt.title("Model Comparison Across Metrics", size=16, y=1.1)

    # Persist the figure; callers get the path back for embedding/serving.
    output_path = os.path.join(self.output_dir, "radar_comparison.png")
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()

    return output_path
|
| 189 |
+
|
| 190 |
+
def plot_heatmap(self, df, title=None, figsize=(12, 8)):
    """
    Render every metric for every model as a single annotated heatmap.

    Args:
        df: pandas DataFrame with comparison data
        title: optional custom title
        figsize: figure size as (width, height)

    Returns:
        str: path to saved figure
    """
    plt.figure(figsize=figsize)

    # One annotated cell per (model, metric) pair, two-decimal values.
    sns.heatmap(df, annot=True, cmap=self.cmap, fmt=".2f", linewidths=.5)

    # Prefer the caller-supplied title, otherwise fall back to a generic one.
    plt.title(title if title else "Model Comparison Heatmap", fontsize=16)
    plt.tight_layout()

    # Persist into the configured output directory.
    output_path = os.path.join(self.output_dir, "comparison_heatmap.png")
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()

    return output_path
|
| 222 |
+
|
| 223 |
+
def plot_prompt_performance(self, prompt_results, metric_name, top_n=5, figsize=(12, 8)):
    """
    Create a grouped bar chart showing model performance on different prompts.

    Args:
        prompt_results: dictionary with prompts as keys and model results as values
        metric_name: name of the metric to plot
        top_n: number of top prompts to include
        figsize: figure size as (width, height)

    Returns:
        str: path to saved figure

    Raises:
        ValueError: if no (prompt, model) entry contains *metric_name*
    """
    # Flatten the nested {prompt: {model: {metric: value}}} mapping into rows.
    rows = [
        {'Prompt': prompt, 'Model': model, metric_name: model_metrics[metric_name]}
        for prompt, per_model in prompt_results.items()
        for model, model_metrics in per_model.items()
        if metric_name in model_metrics
    ]
    df = pd.DataFrame(rows)

    if df.empty:
        raise ValueError(f"No data found for metric '{metric_name}'")

    # Keep only the top_n prompts ranked by mean metric value across models.
    top_prompts = df.groupby('Prompt')[metric_name].mean().nlargest(top_n).index.tolist()
    subset = df[df['Prompt'].isin(top_prompts)]

    plt.figure(figsize=figsize)
    sns.barplot(x='Prompt', y=metric_name, hue='Model', data=subset, palette=self.palette)

    plt.title(f"Model Performance by Prompt: {metric_name}", fontsize=14)
    plt.xlabel("Prompt", fontsize=12)
    plt.ylabel(metric_name, fontsize=12)

    # Long prompt strings overlap unless rotated.
    plt.xticks(rotation=45, ha='right')

    # Place the legend outside the axes so it never covers the bars.
    plt.legend(title="Model", bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()

    output_path = os.path.join(self.output_dir, f"prompt_performance_{metric_name}.png")
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()

    return output_path
|
| 281 |
+
|
| 282 |
+
def create_image_grid(self, image_paths, titles=None, cols=3, figsize=(15, 15)):
    """
    Create a grid of images for visual comparison.

    Args:
        image_paths: non-empty list of paths to images
        titles: optional list of titles for each image
        cols: number of columns in the grid
        figsize: figure size as (width, height)

    Returns:
        str: path to saved figure

    Raises:
        ValueError: if image_paths is empty.
    """
    if not image_paths:
        # plt.subplots(0, cols) would raise an opaque matplotlib error;
        # fail with a clear message instead.
        raise ValueError("image_paths must contain at least one image")

    # Ceiling division: enough rows to hold every image.
    rows = (len(image_paths) + cols - 1) // cols

    # squeeze=False guarantees a 2-D array of Axes even for a 1x1 grid,
    # so .flatten() below is always safe (without it, subplots() returns
    # a bare Axes object when rows == cols == 1 and flatten() crashes).
    fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    # Add each image to the grid; a failed load gets a placeholder cell.
    for i, img_path in enumerate(image_paths):
        try:
            img = Image.open(img_path)
            axes[i].imshow(np.array(img))

            if titles and i < len(titles):
                axes[i].set_title(titles[i])

            # Remove axis ticks — they are meaningless for photos.
            axes[i].set_xticks([])
            axes[i].set_yticks([])
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            axes[i].text(0.5, 0.5, "Error loading image", ha='center', va='center')
            axes[i].set_xticks([])
            axes[i].set_yticks([])

    # Hide trailing subplots the image list did not fill.
    for j in range(len(image_paths), len(axes)):
        axes[j].axis('off')

    plt.tight_layout()

    output_path = os.path.join(self.output_dir, "image_comparison_grid.png")
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()

    return output_path
|
| 334 |
+
|
| 335 |
+
def export_comparison_table(self, df, format='csv'):
    """
    Export comparison table to file.

    Args:
        df: pandas DataFrame with comparison data
        format: export format ('csv', 'excel', or 'html')

    Returns:
        str: path to saved file

    Raises:
        ValueError: if *format* is not one of the supported formats.
    """
    # Map each supported format to its output filename and writer method.
    exporters = {
        'csv': ("comparison_table.csv", df.to_csv),
        'excel': ("comparison_table.xlsx", df.to_excel),
        'html': ("comparison_table.html", df.to_html),
    }

    if format not in exporters:
        raise ValueError(f"Unsupported format: {format}")

    filename, write = exporters[format]
    output_path = os.path.join(self.output_dir, filename)
    write(output_path)

    return output_path
|
| 359 |
+
|
| 360 |
+
def generate_html_report(self, comparison_table, image_paths, metrics_list):
    """
    Generate a comprehensive HTML report with all visualizations.

    The report is a single self-contained file: every visualization image
    is base64-encoded and embedded inline, so the HTML can be shared
    without its companion PNG files.

    Args:
        comparison_table: pandas DataFrame with comparison data
        image_paths: dictionary of generated visualization image paths
            (display title -> file path); missing files are skipped
        metrics_list: list of metrics included in the analysis

    Returns:
        str: path to saved HTML report
    """
    # Create HTML content.
    # NOTE: this is an f-string, so literal CSS braces are doubled ({{ }}).
    html_content = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <title>Image Model Evaluation Report</title>
        <style>
            body {{
                font-family: Arial, sans-serif;
                line-height: 1.6;
                margin: 0;
                padding: 20px;
                color: #333;
            }}
            h1, h2, h3 {{
                color: #2c3e50;
            }}
            .container {{
                max-width: 1200px;
                margin: 0 auto;
            }}
            table {{
                border-collapse: collapse;
                width: 100%;
                margin-bottom: 20px;
            }}
            th, td {{
                border: 1px solid #ddd;
                padding: 8px;
                text-align: left;
            }}
            th {{
                background-color: #f2f2f2;
            }}
            tr:nth-child(even) {{
                background-color: #f9f9f9;
            }}
            .visualization {{
                margin: 20px 0;
                text-align: center;
            }}
            .visualization img {{
                max-width: 100%;
                height: auto;
                box-shadow: 0 4px 8px rgba(0,0,0,0.1);
            }}
            .metrics-list {{
                background-color: #f8f9fa;
                padding: 15px;
                border-radius: 5px;
                margin-bottom: 20px;
            }}
        </style>
    </head>
    <body>
        <div class="container">
            <h1>Image Model Evaluation Report</h1>

            <h2>Metrics Overview</h2>
            <div class="metrics-list">
                <h3>Metrics included in this analysis:</h3>
                <ul>
    """

    # Add metrics list (one <li> per metric).
    for metric in metrics_list:
        html_content += f"                    <li><strong>{metric}</strong></li>\n"

    html_content += """
                </ul>
            </div>

            <h2>Comparison Table</h2>
    """

    # Add comparison table rendered by pandas itself.
    html_content += comparison_table.to_html(classes="table table-striped")

    # Add visualizations section header.
    html_content += """
            <h2>Visualizations</h2>
    """

    for title, img_path in image_paths.items():
        if os.path.exists(img_path):
            # Convert image to base64 for embedding, making the report
            # self-contained (no external image references).
            with open(img_path, "rb") as img_file:
                img_data = base64.b64encode(img_file.read()).decode('utf-8')

            html_content += f"""
            <div class="visualization">
                <h3>{title}</h3>
                <img src="data:image/png;base64,{img_data}" alt="{title}">
            </div>
    """

    # Close HTML document.
    html_content += """
        </div>
    </body>
    </html>
    """

    # Save HTML report to the configured output directory.
    output_path = os.path.join(self.output_dir, "evaluation_report.html")
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(html_content)

    return output_path
|
requirements.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=4.0.0
|
| 2 |
+
pillow>=9.0.0
|
| 3 |
+
numpy>=1.20.0
|
| 4 |
+
pandas>=1.3.0
|
| 5 |
+
matplotlib>=3.5.0
|
| 6 |
+
seaborn>=0.11.0
|
| 7 |
+
scikit-image>=0.19.0
|
| 8 |
+
opencv-python>=4.5.0
|
| 9 |
+
torch>=2.0.0
|
| 10 |
+
torchvision>=0.15.0
|
| 11 |
+
transformers>=4.30.0
|
| 12 |
+
clip>=0.2.0
|
| 13 |
+
timm>=0.6.0
|
| 14 |
+
openpyxl>=3.0.0
|
utils/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility modules for the Image Evaluator tool.
|
| 3 |
+
"""
|
utils/data_handling.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility functions for data handling and export.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
import csv
|
| 8 |
+
import pandas as pd
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def save_json(data, file_path):
    """
    Write *data* to *file_path* as pretty-printed UTF-8 JSON.

    Args:
        data: data to save (must be JSON-serializable)
        file_path: path to the output file

    Returns:
        bool: True if successful, False otherwise
    """
    try:
        with open(file_path, 'w', encoding='utf-8') as handle:
            json.dump(data, handle, indent=2, ensure_ascii=False)
    except Exception as exc:
        # Best-effort API: report the failure but never raise to the caller.
        print(f"Error saving JSON: {exc}")
        return False
    return True
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def load_json(file_path):
    """
    Read and parse a JSON file.

    Args:
        file_path: path to the JSON file

    Returns:
        dict: loaded data, or None when the file is missing or unparseable
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            parsed = json.load(handle)
    except Exception as exc:
        # Best-effort API: never raise, report and return None instead.
        print(f"Error loading JSON: {exc}")
        return None
    return parsed
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def save_csv(data, file_path, headers=None):
    """
    Save tabular data to a CSV file.

    Args:
        data: non-empty list of dictionaries or list of lists/sequences
        file_path: path to the output file
        headers: optional list of column headers; for dict rows it
            defaults to the keys of the first row

    Returns:
        bool: True if the file was written, False otherwise. Empty or
        non-list input returns False (the previous version fell through
        the try block and implicitly returned None, breaking the
        documented bool contract).
    """
    # Nothing to write: fail fast with an explicit bool instead of
    # silently returning None.
    if not isinstance(data, list) or not data:
        return False

    try:
        if isinstance(data[0], dict):
            # Rows are dictionaries -> DictWriter with explicit fieldnames.
            fieldnames = headers if headers is not None else list(data[0].keys())
            with open(file_path, 'w', newline='', encoding='utf-8') as f:
                writer = csv.DictWriter(f, fieldnames=fieldnames)
                writer.writeheader()
                writer.writerows(data)
        else:
            # Rows are plain sequences -> optional header row, then the data.
            with open(file_path, 'w', newline='', encoding='utf-8') as f:
                writer = csv.writer(f)
                if headers:
                    writer.writerow(headers)
                writer.writerows(data)
        return True
    except Exception as e:
        # Best-effort API: report the failure but never raise to the caller.
        print(f"Error saving CSV: {e}")
        return False
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def dataframe_to_formats(df, base_path, formats=None):
    """
    Export a pandas DataFrame to multiple file formats.

    Args:
        df: pandas DataFrame
        base_path: base path for output files (without extension)
        formats: list of formats to export to ('csv', 'excel', 'html',
            'json'); defaults to ['csv', 'excel', 'html']

    Returns:
        dict: mapping of format name -> written file path; unknown
        formats are skipped and a failed export aborts the remainder.
    """
    if formats is None:
        formats = ['csv', 'excel', 'html']

    written = {}

    try:
        for fmt in formats:
            if fmt == 'csv':
                path = f"{base_path}.csv"
                df.to_csv(path)
            elif fmt == 'excel':
                path = f"{base_path}.xlsx"
                df.to_excel(path)
            elif fmt == 'html':
                path = f"{base_path}.html"
                df.to_html(path)
            elif fmt == 'json':
                path = f"{base_path}.json"
                df.to_json(path, orient='records', indent=2)
            else:
                # Unrecognized format names are silently ignored.
                continue
            written[fmt] = path
    except Exception as exc:
        # Best-effort API: report and return whatever was exported so far.
        print(f"Error exporting DataFrame: {exc}")

    return written
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def generate_timestamp():
    """
    Generate a timestamp string for file naming.

    Returns:
        str: current local time formatted as YYYYMMDD_HHMMSS
    """
    # format() with an explicit spec is equivalent to strftime here.
    return format(datetime.now(), "%Y%m%d_%H%M%S")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def create_results_filename(prefix="evaluation", extension=""):
    """
    Build a timestamped filename for results.

    Args:
        prefix: prefix for the filename
        extension: file extension (with or without leading dot)

    Returns:
        str: "<prefix>_<timestamp>" plus the normalized extension,
        or just the stem when no extension is given
    """
    stem = f"{prefix}_{generate_timestamp()}"

    if not extension:
        return stem

    # Normalize a bare extension ("csv") to its dotted form (".csv").
    suffix = extension if extension.startswith('.') else f".{extension}"
    return f"{stem}{suffix}"
|
utils/image_processing.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility functions for image processing and data handling.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import shutil
|
| 7 |
+
import tempfile
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def create_thumbnail(image_path, max_size=(200, 200)):
    """
    Create an in-memory thumbnail of an image.

    Args:
        image_path: path to the image file
        max_size: maximum size of the thumbnail as (width, height)

    Returns:
        PIL.Image: thumbnail image, or None when the file cannot be read
    """
    try:
        # thumbnail() resizes in place, preserving aspect ratio.
        thumb = Image.open(image_path)
        thumb.thumbnail(max_size)
        return thumb
    except Exception as exc:
        # Best-effort API: report the failure and signal it with None.
        print(f"Error creating thumbnail for {image_path}: {exc}")
        return None
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def create_temp_directory():
    """
    Create a temporary working directory for intermediate files.

    Returns:
        str: path to the newly created directory
    """
    # mkdtemp both creates the directory and returns its path; the
    # caller owns cleanup (see cleanup_temp_directory in this module).
    return tempfile.mkdtemp(prefix="image_evaluator_")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def cleanup_temp_directory(temp_dir):
    """
    Remove a temporary directory and everything inside it.

    Args:
        temp_dir: path to the temporary directory; a path that no longer
            exists is silently ignored (the call is idempotent)
    """
    if not os.path.exists(temp_dir):
        return
    shutil.rmtree(temp_dir)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def ensure_directory(directory):
    """
    Create *directory* (and any missing parents) if it does not exist.

    Args:
        directory: path to the directory

    Returns:
        str: the same path, handy for call chaining
    """
    # exist_ok makes this idempotent: no error when it already exists.
    os.makedirs(directory, exist_ok=True)
    return directory
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def is_valid_image(file_path):
    """
    Check if a file is a valid image.

    Args:
        file_path: path to the file

    Returns:
        bool: True if the file is a valid image, False otherwise
    """
    try:
        with Image.open(file_path) as img:
            # verify() checks the file's structure without decoding full
            # pixel data; it raises on corrupt or truncated files.
            img.verify()
        return True
    except Exception:
        # Narrowed from a bare `except:`, which also swallowed
        # KeyboardInterrupt and SystemExit and made Ctrl-C unreliable.
        return False
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def convert_to_rgb(image_path):
    """
    Load an image file and return it as an RGB pixel array.

    Args:
        image_path: path to the image file

    Returns:
        numpy.ndarray: RGB image array, or None when loading fails
    """
    try:
        img = Image.open(image_path)
        # Only convert when needed; e.g. RGBA/greyscale/palette inputs.
        rgb = img if img.mode == 'RGB' else img.convert('RGB')
        return np.array(rgb)
    except Exception as e:
        # Best-effort API: report the failure and signal it with None.
        print(f"Error converting image to RGB: {e}")
        return None