|
|
import os |
|
|
import io |
|
|
import torch |
|
|
import pandas as pd |
|
|
import gradio as gr |
|
|
from PIL import Image |
|
|
from sd_parsers import ParserManager |
|
|
from torchvision import transforms |
|
|
from transformers import CLIPProcessor, CLIPModel, Blip2Processor, Blip2ForConditionalGeneration, BitsAndBytesConfig |
|
|
import lpips |
|
|
import piq |
|
|
import plotly.express as px |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run all models on the GPU when one is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
|
|
|
|
# OpenAI CLIP (ViT-B/32): used for prompt alignment, aesthetic proxy scoring,
# and caption-vs-prompt text similarity. Downloaded from the HF hub on first run.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
|
|
|
|
|
|
|
# BLIP-2 (Flan-T5-XL) captioner, used to generate a caption per image for the
# caption-similarity metric.
blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
if torch.cuda.is_available():
    # 8-bit quantisation (bitsandbytes) to fit the ~4B-parameter model on a GPU;
    # device_map="auto" lets accelerate place the shards.
    bnb_config = BitsAndBytesConfig(load_in_8bit=True)
    blip_model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-flan-t5-xl",
        quantization_config=bnb_config,
        device_map="auto"
    )
else:
    # NOTE(review): float16 on CPU is unusual — many CPU ops lack half-precision
    # kernels, so generation may fail or be very slow on this path; confirm it
    # was actually intended before relying on CPU inference.
    blip_model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-flan-t5-xl",
        torch_dtype=torch.float16
    ).to(device)
|
|
|
|
|
|
|
|
# LPIPS perceptual distance (AlexNet backbone), used for the per-model diversity metric.
lpips_model = lpips.LPIPS(net='alex').to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_metadata(file):
    """Extract the generation prompt and model name from an image's SD metadata.

    Args:
        file: uploaded file object exposing a filesystem path via ``.name``.

    Returns:
        ``(prompt, model_name)`` strings; either may be ``''`` when the
        corresponding metadata is absent.
    """
    parser = ParserManager()
    info = parser.parse(file.name)
    # Robustness fix: parse() returns None for images without SD metadata;
    # the original then crashed on attribute access.
    if info is None:
        return '', ''

    prompt = ''
    prompts = getattr(info, 'prompts', None)
    if prompts:
        # next(iter(...)) instead of [0]: sd-parsers may expose prompts as a
        # non-indexable collection (mirrors how `models` is handled below).
        first_prompt = next(iter(prompts))
        prompt = first_prompt.value if hasattr(first_prompt, 'value') else str(first_prompt)

    model_name = ''
    if hasattr(info, 'models') and info.models:
        # `models` may be any iterable; take the first entry and prefer its
        # `.name` attribute when present.
        first = next(iter(info.models))
        model_name = first.name if hasattr(first, 'name') else str(first)
    return prompt, model_name
|
|
|
|
|
|
|
|
# CLIP-style preprocessing pipeline: resize to 224x224 and normalise with the
# published OpenAI CLIP mean/std constants.
# NOTE(review): this pipeline is defined but never referenced in this file —
# possibly dead code or used by another module; confirm before removing.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        (0.48145466, 0.4578275, 0.40821073),
        (0.26862954, 0.26130258, 0.27577711)
    )
])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()  # consistency fix: every other metric fn disables autograd; this one leaked gradient tracking
def compute_clip_score(img: Image.Image, text: str) -> float:
    """Return the CLIP image-text alignment score, scaled to [0, 100].

    Args:
        img: RGB PIL image.
        text: prompt/caption to score the image against.

    Returns:
        ``max(0, cosine_sim) * 100`` between CLIP image and text embeddings.
    """
    # truncation=True: prompts longer than CLIP's 77-token context would
    # otherwise raise in the tokenizer.
    inputs = clip_processor(
        text=[text], images=img, return_tensors="pt", padding=True, truncation=True
    ).to(device)
    outputs = clip_model(**inputs)
    # image_embeds / text_embeds are already L2-normalised by the model.
    score = torch.cosine_similarity(outputs.image_embeds, outputs.text_embeds)
    # Clamp negatives to 0, matching the common CLIPScore definition.
    return float((score.clamp(min=0) * 100).mean())
|
|
|
|
|
@torch.no_grad()
def compute_caption_similarity(img: Image.Image, prompt: str) -> float:
    """Caption the image with BLIP-2 and score the caption against the prompt.

    Bug fix: the original ignored ``prompt`` entirely and returned the CLIP
    score of the image against its *own* caption — an almost-always-high value
    that measures nothing about prompt fidelity. We now compare the generated
    caption with the prompt in CLIP text-embedding space.

    Args:
        img: RGB PIL image.
        prompt: the prompt the image was supposedly generated from.

    Returns:
        ``max(0, cosine_sim) * 100`` between caption and prompt text embeddings.
    """
    inputs = blip_processor(images=img, return_tensors="pt").to(device)
    out = blip_model.generate(**inputs)
    caption = blip_processor.decode(out[0], skip_special_tokens=True)
    # Embed both texts in one batch; truncation guards against >77-token prompts.
    text_inputs = clip_processor(
        text=[prompt, caption], return_tensors="pt", padding=True, truncation=True
    ).to(device)
    text_embeds = clip_model.get_text_features(**text_inputs)
    sim = torch.cosine_similarity(text_embeds[0:1], text_embeds[1:2])
    return float((sim.clamp(min=0) * 100).mean())
|
|
|
|
|
@torch.no_grad()
def compute_iqa_metrics(img: Image.Image):
    """Compute no-reference image-quality metrics for a PIL image.

    Returns:
        (brisque, niqe) as plain floats; lower is better for both.
    """
    batch = transforms.ToTensor()(img).unsqueeze(0).to(device)
    brisque_score = float(piq.brisque(batch).cpu())
    niqe_score = float(piq.niqe(batch).cpu())
    return brisque_score, niqe_score
|
|
|
|
|
@torch.no_grad()
def compute_lpips_pair(img1: Image.Image, img2: Image.Image) -> float:
    """Return the LPIPS perceptual distance between two PIL images.

    Args:
        img1, img2: RGB PIL images (resized to match if their sizes differ).

    Returns:
        LPIPS distance; higher means more perceptually different.
    """
    # LPIPS requires both inputs to share the same spatial size.
    if img1.size != img2.size:
        img2 = img2.resize(img1.size)
    t1 = transforms.ToTensor()(img1).unsqueeze(0).to(device)
    t2 = transforms.ToTensor()(img2).unsqueeze(0).to(device)
    # Bug fix: ToTensor() produces [0, 1] tensors, but lpips expects [-1, 1];
    # normalize=True makes the model rescale its inputs internally.
    return float(lpips_model(t1, t2, normalize=True).cpu())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def analyze_images(files):
    """Score every uploaded image and aggregate the metrics per model.

    Args:
        files: iterable of uploaded file objects (each with a ``.name`` path),
            or None when nothing was uploaded.

    Returns:
        ``(df, agg)`` where ``df`` holds one row of raw metrics per image and
        ``agg`` holds per-model means plus an LPIPS diversity score.
    """
    agg_columns = ['model', 'clip_score_mean', 'caption_sim_mean',
                   'brisque_mean', 'niqe_mean', 'aesthetic_mean', 'diversity']
    records = []
    imgs_by_model = {}

    # `files or []`: Gradio passes None when the user clicks with no uploads.
    for f in files or []:
        img = Image.open(f.name).convert('RGB')
        prompt, model = extract_metadata(f)

        cs = compute_clip_score(img, prompt)
        cap_sim = compute_caption_similarity(img, prompt)
        brisque, niqe = compute_iqa_metrics(img)
        # Aesthetic proxy: CLIP alignment with a generic "quality" prompt.
        aesthetic = compute_clip_score(img, "a beautiful high quality image")

        records.append({
            'model': model,
            'prompt': prompt,
            'clip_score': cs,
            'caption_sim': cap_sim,
            'brisque': brisque,
            'niqe': niqe,
            'aesthetic': aesthetic
        })
        imgs_by_model.setdefault(model, []).append(img)

    df = pd.DataFrame(records)
    # Robustness fix: with no records the groupby below would raise KeyError
    # ('model' column missing on an empty DataFrame). Return correctly-shaped
    # empty frames so downstream UI code still works.
    if df.empty:
        return df, pd.DataFrame(columns=agg_columns)

    # Intra-model diversity: mean pairwise LPIPS distance (0.0 for one image).
    diversity = {}
    for model, imgs in imgs_by_model.items():
        if len(imgs) < 2:
            diversity[model] = 0.0
        else:
            pairs = [compute_lpips_pair(imgs[i], imgs[j])
                     for i in range(len(imgs)) for j in range(i + 1, len(imgs))]
            diversity[model] = sum(pairs) / len(pairs)

    agg = df.groupby('model').agg(
        clip_score_mean=('clip_score', 'mean'),
        caption_sim_mean=('caption_sim', 'mean'),
        brisque_mean=('brisque', 'mean'),
        niqe_mean=('niqe', 'mean'),
        aesthetic_mean=('aesthetic', 'mean')
    ).reset_index()
    agg['diversity'] = agg['model'].map(diversity)

    return df, agg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_metrics(agg: pd.DataFrame):
    """Build a grouped bar chart comparing models across the aggregate metrics."""
    metric_columns = ['aesthetic_mean', 'clip_score_mean', 'caption_sim_mean', 'diversity']
    figure = px.bar(
        agg,
        x='model',
        y=metric_columns,
        barmode='group',
        title='Сравнение моделей по метрикам'
    )
    return figure
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_analysis(files):
    """Gradio callback: analyse uploads and return (summary table, metrics figure).

    Bug fix: the output Dataframe component is configured with the per-model
    aggregate headers ("clip_score_mean", ..., "diversity") and labelled as a
    summary table, but the original returned the per-image dataframe. Return
    the aggregate table so the data matches the component's declared columns.
    """
    _df, agg = analyze_images(files)
    fig = plot_metrics(agg)
    return agg, fig
|
|
|
|
|
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# AI Image Quality Evaluator")
    gr.Markdown("Загрузите PNG-изображения (с EXIF-метаданными SD) для анализа и сравнения моделей.")

    with gr.Row():
        # Multiple PNG uploads; SD metadata is parsed per file by sd-parsers.
        input_files = gr.File(file_count="multiple", label="Выберите PNG файлы")
        # NOTE(review): these headers are the per-model AGGREGATE columns, but
        # run_analysis wires the per-image dataframe into this component —
        # confirm which table was intended here.
        output_table = gr.Dataframe(
            headers=[
                "model", "clip_score_mean", "caption_sim_mean", "brisque_mean",
                "niqe_mean", "aesthetic_mean", "diversity"
            ],
            label="Сводная таблица"
        )

    # Grouped bar chart comparing models across metrics.
    plot_output = gr.Plot(label="График метрик")

    run_btn = gr.Button("Запустить анализ")
    # Wiring: uploaded files -> (metrics table, bar chart).
    run_btn.click(run_analysis, inputs=[input_files], outputs=[output_table, plot_output])
|
|
|
|
|
if __name__ == "__main__":
    # Bind on all interfaces (container/LAN access); no public Gradio share link.
    demo.launch(server_name='0.0.0.0', share=False)
|
|
|