import gradio as gr
import lpips
import pandas as pd
import piq
import plotly.express as px
import pyiqa  # piq does not implement NIQE; pyiqa (IQA-PyTorch) provides it
import torch
from PIL import Image
from sd_parsers import ParserManager
from torchvision import transforms
from transformers import (
    BitsAndBytesConfig,
    Blip2ForConditionalGeneration,
    Blip2Processor,
    CLIPModel,
    CLIPProcessor,
)

# --------------------
# Setup Models
# --------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# CLIP for prompt alignment & aesthetics
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# BLIP-2 for caption generation: 8-bit quantized if a GPU is available,
# otherwise full precision (half-precision inference is generally unsupported on CPU)
blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
if torch.cuda.is_available():
    bnb_config = BitsAndBytesConfig(load_in_8bit=True)
    blip_model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-flan-t5-xl",
        quantization_config=bnb_config,
        device_map="auto"
    )
else:
    blip_model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-flan-t5-xl",
        torch_dtype=torch.float32
    ).to(device)

# LPIPS for diversity
lpips_model = lpips.LPIPS(net='alex').to(device)

# NIQE via pyiqa
niqe_metric = pyiqa.create_metric('niqe', device=device)

# Reusable parser for generation metadata embedded in image files
parser_manager = ParserManager()

# --------------------
# Helper Functions
# --------------------
def extract_metadata(path):
    """Extract prompt and model name from the image's embedded SD metadata."""
    info = parser_manager.parse(path)
    if info is None:
        # No parseable generation metadata in this file
        return '', ''
    # info.prompts may be a list or a set; take one representative prompt
    prompt = next(iter(info.prompts)).value if info.prompts else ''
    # info.models may be a set or list of model identifiers
    model_name = ''
    if getattr(info, 'models', None):
        # pick one representative model
        first = next(iter(info.models))
        model_name = first.name if hasattr(first, 'name') else str(first)
    return prompt, model_name

# --------------------
# Metric Computations
# --------------------
@torch.no_grad()
def compute_clip_score(img: Image.Image, text: str) -> float:
    """CLIPScore: clamped image/text cosine similarity, scaled by 100."""
    inputs = clip_processor(
        text=[text], images=img, return_tensors="pt", padding=True, truncation=True
    ).to(device)
    outputs = clip_model(**inputs)
    score = torch.cosine_similarity(outputs.image_embeds, outputs.text_embeds)
    return float((score.clamp(min=0) * 100).mean())

@torch.no_grad()
def compute_caption_similarity(img: Image.Image, prompt: str) -> float:
    """Generate a BLIP-2 caption and compare it to the prompt in CLIP text space."""
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    inputs = blip_processor(images=img, return_tensors="pt").to(device, dtype)
    out = blip_model.generate(**inputs, max_new_tokens=30)
    caption = blip_processor.decode(out[0], skip_special_tokens=True)
    text_inputs = clip_processor(
        text=[caption, prompt], return_tensors="pt", padding=True, truncation=True
    ).to(device)
    embeds = clip_model.get_text_features(**text_inputs)
    sim = torch.cosine_similarity(embeds[0:1], embeds[1:2])
    return float((sim.clamp(min=0) * 100).mean())

@torch.no_grad()
def compute_iqa_metrics(img: Image.Image):
    """No-reference quality metrics; lower is better for both."""
    tensor = transforms.ToTensor()(img).unsqueeze(0).to(device)
    brisque = float(piq.brisque(tensor, data_range=1.0).cpu())
    niqe = float(niqe_metric(tensor).cpu())
    return brisque, niqe

@torch.no_grad()
def compute_lpips_pair(img1: Image.Image, img2: Image.Image) -> float:
    """LPIPS distance between two images, resized to a common resolution."""
    to_tensor = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
    t1 = to_tensor(img1).unsqueeze(0).to(device)
    t2 = to_tensor(img2).unsqueeze(0).to(device)
    # normalize=True lets LPIPS accept inputs in [0, 1] instead of [-1, 1]
    return float(lpips_model(t1, t2, normalize=True).cpu())

# --------------------
# Analysis Pipeline
# --------------------
def analyze_images(files):
    records = []
    imgs_by_model = {}
    for f in files:
        # Gradio may pass plain paths or tempfile-like objects with a .name
        path = f if isinstance(f, str) else f.name
        img = Image.open(path).convert('RGB')
        prompt, model = extract_metadata(path)
        cs = compute_clip_score(img, prompt)
        cap_sim = compute_caption_similarity(img, prompt)
        brisque, niqe = compute_iqa_metrics(img)
        # Proxy aesthetic score: CLIP similarity to a generic quality prompt
        aesthetic = compute_clip_score(img, "a beautiful high quality image")
        records.append({
            'model': model,
            'prompt': prompt,
            'clip_score': cs,
            'caption_sim': cap_sim,
            'brisque': brisque,
            'niqe': niqe,
            'aesthetic': aesthetic
        })
        imgs_by_model.setdefault(model, []).append(img)

    df = pd.DataFrame(records)

    # Mean pairwise LPIPS within each model as a diversity score
    diversity = {}
    for model, imgs in imgs_by_model.items():
        if len(imgs) < 2:
            diversity[model] = 0.0
        else:
            pairs = [compute_lpips_pair(imgs[i], imgs[j])
                     for i in range(len(imgs))
                     for j in range(i + 1, len(imgs))]
            diversity[model] = sum(pairs) / len(pairs)

    agg = df.groupby('model').agg(
        clip_score_mean=('clip_score', 'mean'),
        caption_sim_mean=('caption_sim', 'mean'),
        brisque_mean=('brisque', 'mean'),
        niqe_mean=('niqe', 'mean'),
        aesthetic_mean=('aesthetic', 'mean')
    ).reset_index()
    agg['diversity'] = agg['model'].map(diversity)
    return df, agg

# --------------------
# Visualization
# --------------------
def plot_metrics(agg: pd.DataFrame):
    return px.bar(
        agg,
        x='model',
        y=['aesthetic_mean', 'clip_score_mean', 'caption_sim_mean', 'diversity'],
        barmode='group',
        title='Model comparison across metrics'
    )

# --------------------
# Gradio Interface
# --------------------
def run_analysis(files):
    _, agg = analyze_images(files)
    fig = plot_metrics(agg)
    # The summary table expects the aggregated per-model metrics
    return agg, fig

with gr.Blocks() as demo:
    gr.Markdown("# AI Image Quality Evaluator")
    gr.Markdown("Upload PNG images (with embedded SD generation metadata) to analyze and compare models.")
    with gr.Row():
        input_files = gr.File(file_count="multiple", label="Select PNG files")
        output_table = gr.Dataframe(
            headers=[
                "model", "clip_score_mean", "caption_sim_mean",
                "brisque_mean", "niqe_mean", "aesthetic_mean", "diversity"
            ],
            label="Summary table"
        )
    plot_output = gr.Plot(label="Metrics plot")
    run_btn = gr.Button("Run analysis")
    run_btn.click(run_analysis, inputs=[input_files], outputs=[output_table, plot_output])

if __name__ == "__main__":
    demo.launch(server_name='0.0.0.0', share=False)
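
# Suggested dependencies (a sketch; exact versions are assumptions, and pyiqa
# is only needed for the NIQE metric):
#   pip install torch torchvision transformers accelerate bitsandbytes \
#       gradio pandas plotly pillow lpips piq pyiqa sd-parsers
#
# Optional headless usage with hypothetical file paths, bypassing the UI:
#   per_image, per_model = analyze_images(["out1.png", "out2.png"])
#   print(per_model.to_string(index=False))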