# Hugging Face Spaces: Running on Zero (ZeroGPU)
import gradio as gr
import torch
import spaces
import pandas as pd
from PIL import Image
from transformers import AutoProcessor, AutoModel

# --- INITIALIZATION ---
# PickScore v1: a CLIP-H based model fine-tuned to score how well an image
# matches a text prompt. The processor comes from the base LAION CLIP-H
# checkpoint, which is the pairing the PickScore authors recommend.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "yuvalkirstain/PickScore_v1"
processor_id = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"

print(f"Loading models to {device}...")
processor = AutoProcessor.from_pretrained(processor_id)
model = AutoModel.from_pretrained(model_id).eval().to(device)
def ensure_tensor(output):
    """Coerce a model output into a plain tensor.

    Accepts either a raw ``torch.Tensor`` or a HF ``ModelOutput``-like
    object; returns the first non-None embedding field found
    (``image_embeds``, ``text_embeds``, ``pooler_output``), falling back
    to the first positional element.
    """
    if isinstance(output, torch.Tensor):
        return output
    for attr in ("image_embeds", "text_embeds", "pooler_output"):
        candidate = getattr(output, attr, None)
        if candidate is not None:
            return candidate
    return output[0]
def predict_absolute(prompt, image1, image2):
    """Score two images against a single prompt on a 1-100 scale.

    Returns ``[score1, score2]`` (floats rounded to one decimal),
    ``None`` when any input is missing/blank, or ``[0.0, 0.0]`` if
    scoring fails internally.
    """
    # Guard order matters: image checks run before prompt.strip() so a
    # missing image never triggers an attribute error on the prompt.
    if image1 is None or image2 is None or not prompt.strip():
        return None
    try:
        inputs = processor(
            text=[prompt],
            images=[image1, image2],
            padding=True,
            return_tensors="pt",
        ).to(device)
        # NOTE(review): on ZeroGPU Spaces this inference would normally be
        # wrapped with @spaces.GPU — confirm against the Space config.
        with torch.no_grad():
            image_out = model.get_image_features(pixel_values=inputs['pixel_values'])
            text_out = model.get_text_features(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
            )
        image_emb = ensure_tensor(image_out)
        text_emb = ensure_tensor(text_out)
        # Cosine similarity scaled by the model's learned temperature.
        image_emb = image_emb / image_emb.norm(dim=-1, keepdim=True)
        text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)
        raw = (text_emb @ image_emb.T)[0] * model.logit_scale.exp()
        # Map raw logits onto a display scale, clamped into [1, 100].
        return [
            round(max(1, min(100, (value - 15) * 5)), 1)
            for value in raw.cpu().tolist()
        ]
    except Exception as exc:
        print(f"Internal Error: {exc}")
        return [0.0, 0.0]
def multi_predict(p1, p2, p3, p4, p5, p6, p7, p8, img1, img2):
    """Score both images against up to eight prompts.

    Empty prompt boxes are skipped. Returns a DataFrame with one row
    per prompt, or a single error row when inputs are missing.
    """
    columns = ["Промпт", "Вариант 1", "Вариант 2"]
    # Keep only the non-empty prompts out of the eight text boxes.
    prompts = [p for p in (p1, p2, p3, p4, p5, p6, p7, p8) if p and p.strip()]
    if img1 is None or img2 is None or not prompts:
        return pd.DataFrame(
            [["Ошибка", "Загрузите фото и промпты", ""]], columns=columns
        )
    rows = []
    for prompt in prompts:
        pair = predict_absolute(prompt, img1, img2)
        if pair:
            rows.append([prompt, f"{pair[0]} / 100", f"{pair[1]} / 100"])
    return pd.DataFrame(rows, columns=columns)
# --- GRADIO INTERFACE ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🐒 Punch the Monkey: Global UI Scorer")
    gr.Markdown("Оценивает качество UI/дизайна независимо для каждой картинки по шкале от 1 до 100.")
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### Текстовые промпты для анализа")
                # Eight prompt boxes built from their default values; the
                # f-string labels reproduce "Промпт 1" .. "Промпт 8".
                prompt_boxes = [
                    gr.Textbox(label=f"Промпт {idx}", value=default)
                    for idx, default in enumerate(
                        [
                            "aesthetic mobile app ui",
                            "clean minimalist ui",
                            "professional mobile app interface",
                            "Apple Human Interface Guidelines",
                            "cozy simple watercolor icons",
                            "premium software design",
                            "high conversion screenshot",
                            "modern user experience",
                        ],
                        start=1,
                    )
                ]
            with gr.Row():
                i1 = gr.Image(type="pil", label="Вариант 1")
                i2 = gr.Image(type="pil", label="Вариант 2")
            btn = gr.Button("🚀 Сделать глобальный расчет", variant="primary")
        with gr.Column(scale=1):
            out_table = gr.Dataframe(
                headers=["Промпт", "Вариант 1", "Вариант 2"],
                interactive=False,
            )
    btn.click(
        fn=multi_predict,
        inputs=[*prompt_boxes, i1, i2],
        outputs=out_table,
    )

if __name__ == "__main__":
    demo.launch()