VOIDER committed
Commit 9eb4c90 · verified · 1 Parent(s): 4bf8141

Create app.py

Files changed (1)
  1. app.py +175 -0
app.py ADDED
@@ -0,0 +1,175 @@
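+ """AI Image Quality Evaluator.
+
+ Score AI-generated PNGs on prompt alignment (CLIP), caption similarity
+ (BLIP-2 + CLIP), no-reference quality (BRISQUE/NIQE), an aesthetic proxy,
+ and per-model diversity (LPIPS), then compare the generating models."""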
+ import torch
+ import pandas as pd
+ import gradio as gr
+ from PIL import Image
+ from sd_parsers import ParserManager
+ from torchvision import transforms
+ from transformers import CLIPProcessor, CLIPModel, Blip2Processor, Blip2ForConditionalGeneration
+ import lpips
+ import piq
+ import plotly.express as px
+
+ # --------------------
+ # Setup Models
+ # --------------------
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ # CLIP for prompt alignment & aesthetics
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ # BLIP-2 for caption generation (fp16 on GPU, fp32 on CPU)
+ blip_dtype = torch.float16 if device.type == 'cuda' else torch.float32
+ blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
+ blip_model = Blip2ForConditionalGeneration.from_pretrained(
+     "Salesforce/blip2-flan-t5-xl", torch_dtype=blip_dtype
+ ).to(device)
+
+ # LPIPS for diversity
+ lpips_model = lpips.LPIPS(net='alex').to(device)
+
+ # IQA metrics (BRISQUE, NIQE): piq functions are stateless, call them directly
+ # Aesthetic proxy: CLIP similarity to a generic "beautiful photograph" prompt
+
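+ # sd_parsers reads the generation metadata (prompt, model, sampler, ...) that
+ # Stable Diffusion front ends embed in PNG text chunks.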
+ def extract_metadata(pil_img):
+     pm = ParserManager()
+     info = pm.parse(pil_img)  # None when the PNG carries no generation metadata
+     if info is None:
+         return '', ''
+     prompt = next(iter(info.prompts)).value if info.prompts else ''
+     model_name = info.model_name or ''
+     return prompt, model_name
+
+ # CLIP-style preprocessing (224 px, CLIP normalization constants)
+ preprocess = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.ToTensor(),
+     transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
+                          (0.26862954, 0.26130258, 0.27577711))
+ ])
+
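+ # Prompt alignment in the spirit of CLIPScore (Hessel et al., 2021):
+ # 100 * max(0, cosine similarity of the CLIP image and text embeddings).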
+ @torch.no_grad()
+ def compute_clip_score(img, text):
+     inputs = clip_processor(text=[text], images=img, return_tensors="pt", padding=True).to(device)
+     outputs = clip_model(**inputs)
+     score = torch.cosine_similarity(outputs.image_embeds, outputs.text_embeds)
+     return float((score.clamp(min=0) * 100).mean())
+
+ @torch.no_grad()
+ def compute_caption_similarity(img):
+     inputs = blip_processor(images=img, return_tensors="pt").to(device, blip_dtype)
+     out = blip_model.generate(**inputs)
+     caption = blip_processor.decode(out[0], skip_special_tokens=True)
+     # score the generated caption against the image via CLIP
+     return compute_clip_score(img, caption)
+
+ @torch.no_grad()
+ def compute_iqa_metrics(img):
+     # piq expects a batched float tensor in [0, 1]; lower BRISQUE/NIQE is better
+     img_t = transforms.ToTensor()(img).unsqueeze(0).to(device)
+     brisque = float(piq.brisque(img_t).cpu())
+     # NIQE is not shipped in every piq release; report NaN when unavailable
+     niqe = float(piq.niqe(img_t).cpu()) if hasattr(piq, 'niqe') else float('nan')
+     return brisque, niqe
+
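+ # LPIPS gives a perceptual distance between two images; averaged over all
+ # pairs from one model, it serves as that model's diversity score.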
+ @torch.no_grad()
+ def compute_lpips_pair(img1, img2):
+     # resize so differently-sized uploads can be paired
+     to_t = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
+     t1 = to_t(img1).unsqueeze(0).to(device)
+     t2 = to_t(img2).unsqueeze(0).to(device)
+     # normalize=True rescales [0, 1] inputs to the [-1, 1] range LPIPS expects
+     return float(lpips_model(t1, t2, normalize=True).cpu())
+
+ # --------------------
+ # Analysis Pipeline
+ # --------------------
+
+ def analyze_images(images):
+     records = []
+     imgs_by_model = {}
+
+     # extract metadata and compute per-image metrics
+     for img in images:
+         prompt, model = extract_metadata(img)
+         clip_score = compute_clip_score(img, prompt)
+         cap_sim = compute_caption_similarity(img)
+         brisque, niqe = compute_iqa_metrics(img)
+         # aesthetic proxy: CLIP with a generic prompt
+         aest = compute_clip_score(img, "a beautiful high quality image")
+
+         records.append({
+             'model': model,
+             'prompt': prompt,
+             'clip_score': clip_score,
+             'caption_sim': cap_sim,
+             'brisque': brisque,
+             'niqe': niqe,
+             'aesthetic': aest,
+             'image': img
+         })
+         imgs_by_model.setdefault(model, []).append(img)
+
+     df = pd.DataFrame(records)
+
+     # diversity = mean pairwise LPIPS over a model's images (higher = more varied)
+     diversity = {}
+     for model, imgs in imgs_by_model.items():
+         if len(imgs) < 2:
+             diversity[model] = 0.0
+         else:
+             pairs = []
+             for i in range(len(imgs)):
+                 for j in range(i + 1, len(imgs)):
+                     pairs.append(compute_lpips_pair(imgs[i], imgs[j]))
+             diversity[model] = sum(pairs) / len(pairs)
+
+     # aggregate per model
+     agg = df.groupby('model').agg({
+         'clip_score': ['mean'],
+         'caption_sim': ['mean'],
+         'brisque': ['mean'],
+         'niqe': ['mean'],
+         'aesthetic': ['mean']
+     })
+     # flatten the MultiIndex columns to e.g. 'clip_score_mean'
+     agg.columns = ['_'.join(col) for col in agg.columns]
+     agg['diversity'] = pd.Series(diversity)
+     agg = agg.reset_index()
+
+     return df, agg
+
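+ # `agg` holds one row per model with columns: model, clip_score_mean,
+ # caption_sim_mean, brisque_mean, niqe_mean, aesthetic_mean, diversity.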
+ # --------------------
+ # Visualization Helpers
+ # --------------------
+
+ def plot_metrics(agg):
+     fig = px.bar(agg, x='model', y=['aesthetic_mean', 'clip_score_mean', 'caption_sim_mean', 'diversity'],
+                  barmode='group', title='Model comparison across metrics')
+     return fig
+
+ # --------------------
+ # Gradio Interface
+ # --------------------
+
+ def run_analysis(files):
+     # gr.File(type="filepath") hands the callback a list of paths
+     images = [Image.open(path).convert('RGB') for path in files]
+     df, agg = analyze_images(images)
+
+     # plot and return the per-model summary (matches the table headers below)
+     fig = plot_metrics(agg)
+
+     return agg, fig
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# AI Image Quality Evaluator")
+     gr.Markdown("Upload PNG images generated by AI models to analyze and compare the models.")
+
+     with gr.Row():
+         input_files = gr.File(file_count="multiple", file_types=[".png"], type="filepath", label="Select PNG files")
+         output_table = gr.DataFrame(headers=["model", "clip_score_mean", "caption_sim_mean", "brisque_mean", "niqe_mean", "aesthetic_mean", "diversity"], label="Summary table")
+
+     plot_output = gr.Plot(label="Metrics chart")
+
+     run_btn = gr.Button("Run analysis")
+     run_btn.click(run_analysis, inputs=[input_files], outputs=[output_table, plot_output])
+
+ if __name__ == "__main__":
+     demo.launch(server_name='0.0.0.0', share=False)
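
Note: building this Space presumably also requires a companion requirements.txt covering torch, torchvision, transformers, gradio, sd-parsers, lpips, piq, pandas, plotly and Pillow; that file is not part of this commit.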