import torch
import pandas as pd
import gradio as gr
from PIL import Image
from sd_parsers import ParserManager
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel, Blip2Processor, Blip2ForConditionalGeneration, BitsAndBytesConfig
import lpips
import piq
import plotly.express as px
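# Assumed dependency set for this app (a sketch, not pinned): torch, torchvision,
# transformers, bitsandbytes, accelerate, gradio, pandas, plotly, pillow, lpips,
# piq, sd-parsers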
# --------------------
# Setup Models
# --------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# CLIP for prompt alignment & aesthetics
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# BLIP-2 for caption generation: 8-bit quantization on GPU, full precision on CPU
blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
if torch.cuda.is_available():
    bnb_config = BitsAndBytesConfig(load_in_8bit=True)
    blip_model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-flan-t5-xl",
        quantization_config=bnb_config,
        device_map="auto"
    )
else:
    # float16 is poorly supported on CPU, so fall back to float32 there
    blip_model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-flan-t5-xl",
        torch_dtype=torch.float32
    ).to(device)
# LPIPS for diversity
lpips_model = lpips.LPIPS(net='alex').to(device)
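# LPIPS measures perceptual distance between image pairs; averaging it over all
# pairs from one model is used below as a simple intra-model diversity score.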
# --------------------
# Helper Functions
# --------------------
def extract_metadata(path: str):
    """Extract prompt and model name from image metadata using sd-parsers."""
    parser = ParserManager()
    info = parser.parse(path)
    if info is None:  # no recognizable generation metadata in the file
        return '', ''
    # info.prompts may be a set, so take one representative prompt
    prompt = next(iter(info.prompts)).value if info.prompts else ''
    # info.models may be a set or list of model identifiers
    model_name = ''
    if getattr(info, 'models', None):
        # pick one representative model
        first = next(iter(info.models))
        model_name = first.name if hasattr(first, 'name') else str(first)
    return prompt, model_name
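# Example (hypothetical file): an A1111-style PNG might yield something like
#   extract_metadata("castle_0001.png")  ->  ("a castle at sunset, 4k", "sd_xl_base_1.0")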
# CLIP-style preprocessing transform (currently unused below, since CLIPProcessor
# applies its own preprocessing; kept for reference)
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        (0.48145466, 0.4578275, 0.40821073),
        (0.26862954, 0.26130258, 0.27577711)
    )
])
# --------------------
# Metric Computations
# --------------------
@torch.no_grad()
def compute_clip_score(img: Image.Image, text: str) -> float:
    if not text:  # no prompt recovered from metadata
        return 0.0
    # truncation=True keeps long prompts within CLIP's 77-token text limit
    inputs = clip_processor(text=[text], images=img, return_tensors="pt",
                            padding=True, truncation=True).to(device)
    outputs = clip_model(**inputs)
    score = torch.cosine_similarity(outputs.image_embeds, outputs.text_embeds)
    return float((score.clamp(min=0) * 100).mean())
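# The same CLIP scorer doubles as a rough aesthetic proxy below by scoring each
# image against the fixed text "a beautiful high quality image".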
@torch.no_grad()
def compute_caption_similarity(img: Image.Image) -> float:
    inputs = blip_processor(images=img, return_tensors="pt").to(device)
    # Cast pixel values to the model's dtype (fp16 when quantized on GPU)
    inputs["pixel_values"] = inputs["pixel_values"].to(blip_model.dtype)
    out = blip_model.generate(**inputs, max_new_tokens=30)
    caption = blip_processor.decode(out[0], skip_special_tokens=True)
    # Score the generated caption against the image with CLIP
    return compute_clip_score(img, caption)
@torch.no_grad()
def compute_iqa_metrics(img: Image.Image):
    """No-reference quality metrics; lower is better for both BRISQUE and NIQE."""
    tensor = transforms.ToTensor()(img).unsqueeze(0).to(device)
    brisque = float(piq.brisque(tensor).cpu())
    niqe = float(piq.niqe(tensor).cpu())
    return brisque, niqe
@torch.no_grad()
def compute_lpips_pair(img1: Image.Image, img2: Image.Image) -> float:
    # Resize to a common resolution so LPIPS compares same-shaped tensors
    t1 = transforms.ToTensor()(img1.resize((256, 256))).unsqueeze(0).to(device)
    t2 = transforms.ToTensor()(img2.resize((256, 256))).unsqueeze(0).to(device)
    # normalize=True rescales [0, 1] inputs to the [-1, 1] range lpips expects
    return float(lpips_model(t1, t2, normalize=True).cpu())
# --------------------
# Analysis Pipeline
# --------------------
def analyze_images(files):
    records = []
    imgs_by_model = {}
    for f in files:
        path = getattr(f, 'name', f)  # gradio may pass file objects or plain paths
        img = Image.open(path).convert('RGB')
        prompt, model = extract_metadata(path)
        cs = compute_clip_score(img, prompt)
        cap_sim = compute_caption_similarity(img)
        brisque, niqe = compute_iqa_metrics(img)
        aesthetic = compute_clip_score(img, "a beautiful high quality image")
        records.append({
            'model': model,
            'prompt': prompt,
            'clip_score': cs,
            'caption_sim': cap_sim,
            'brisque': brisque,
            'niqe': niqe,
            'aesthetic': aesthetic
        })
        imgs_by_model.setdefault(model, []).append(img)
    df = pd.DataFrame(records)
    # Diversity: mean pairwise LPIPS distance between images from the same model
    diversity = {}
    for model, imgs in imgs_by_model.items():
        if len(imgs) < 2:
            diversity[model] = 0.0
        else:
            pairs = [compute_lpips_pair(imgs[i], imgs[j])
                     for i in range(len(imgs)) for j in range(i + 1, len(imgs))]
            diversity[model] = sum(pairs) / len(pairs)
    agg = df.groupby('model').agg(
        clip_score_mean=('clip_score', 'mean'),
        caption_sim_mean=('caption_sim', 'mean'),
        brisque_mean=('brisque', 'mean'),
        niqe_mean=('niqe', 'mean'),
        aesthetic_mean=('aesthetic', 'mean')
    ).reset_index()
    agg['diversity'] = agg['model'].map(diversity)
    return df, agg
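# Example (hypothetical paths): the pipeline also runs outside Gradio, e.g.
#   df, agg = analyze_images(["out/model_a_001.png", "out/model_a_002.png"])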
# --------------------
# Visualization
# --------------------
def plot_metrics(agg: pd.DataFrame):
    return px.bar(
        agg,
        x='model',
        y=['aesthetic_mean', 'clip_score_mean', 'caption_sim_mean', 'diversity'],
        barmode='group',
        title='Model comparison across metrics'
    )
# --------------------
# Gradio Interface
# --------------------
def run_analysis(files):
    df, agg = analyze_images(files)
    fig = plot_metrics(agg)
    # Return the per-model aggregate table (matches the Dataframe headers below)
    return agg, fig
with gr.Blocks() as demo:
    gr.Markdown("# AI Image Quality Evaluator")
    gr.Markdown("Upload PNG images (with embedded SD generation metadata) to analyze and compare models.")
    with gr.Row():
        input_files = gr.File(file_count="multiple", label="Select PNG files")
    output_table = gr.Dataframe(
        headers=[
            "model", "clip_score_mean", "caption_sim_mean", "brisque_mean",
            "niqe_mean", "aesthetic_mean", "diversity"
        ],
        label="Summary table"
    )
    plot_output = gr.Plot(label="Metrics chart")
    run_btn = gr.Button("Run analysis")
    run_btn.click(run_analysis, inputs=[input_files], outputs=[output_table, plot_output])
if __name__ == "__main__":
    demo.launch(server_name='0.0.0.0', share=False)