vision / app.py
rafael1994s's picture
Update app.py
bfcf340 verified
import gradio as gr
import torch
import spaces
import pandas as pd
from PIL import Image
from transformers import AutoProcessor, AutoModel
# --- INITIALIZATION ---
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hub identifiers: the model weights are loaded from model_id, while the
# input processor is loaded from a separate checkpoint (processor_id).
model_id = "yuvalkirstain/PickScore_v1"
processor_id = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"

print(f"Loading models to {device}...")
processor = AutoProcessor.from_pretrained(processor_id)

# Load the model, switch it to evaluation mode, then move it to the device.
model = AutoModel.from_pretrained(model_id)
model.eval()
model = model.to(device)
def ensure_tensor(output):
    """Return a tensor from either a raw tensor or a model-output object.

    A ``torch.Tensor`` is returned unchanged. Otherwise the attributes
    ``image_embeds``, ``text_embeds`` and ``pooler_output`` are probed in
    that order and the first one that exists and is not ``None`` is
    returned. If none qualifies, ``output[0]`` is returned.
    """
    if isinstance(output, torch.Tensor):
        return output

    for field_name in ("image_embeds", "text_embeds", "pooler_output"):
        candidate = getattr(output, field_name, None)
        if candidate is not None:
            return candidate

    # Last resort: treat the output as an indexable container.
    return output[0]
def _to_display_score(raw_score, offset=15, gain=5):
    """Map a raw scaled similarity onto the 1-100 display scale, rounded to 0.1."""
    # NOTE(review): offset and gain look hand-tuned for this model's usual
    # score range -- confirm before changing them.
    shifted = (raw_score - offset) * gain
    return round(max(1, min(100, shifted)), 1)


@spaces.GPU
def predict_absolute(prompt, image1, image2):
    """Score two images against a single text prompt.

    Both images and the prompt are embedded with the module-level ``model``
    and ``processor``; the embeddings are L2-normalised and the text/image
    dot products are multiplied by ``model.logit_scale.exp()``. Each raw
    score is then mapped onto a 1-100 scale by ``_to_display_score``.

    Args:
        prompt: Text the images are compared with.
        image1: First image (the UI passes PIL images).
        image2: Second image (the UI passes PIL images).

    Returns:
        ``None`` when an image is missing or the prompt is empty/blank,
        ``[score1, score2]`` on success, or ``[0.0, 0.0]`` when scoring
        raised an exception (the error is printed, not re-raised).
    """
    # Guard against a missing prompt as well as a blank one, so a None
    # value yields "no result" instead of an AttributeError.
    if image1 is None or image2 is None or not prompt or not prompt.strip():
        return None

    try:
        # Truncate to CLIP's 77-token context; without it an over-long
        # prompt fails inside the text encoder and falls into the except
        # branch below.
        inputs = processor(
            text=[prompt],
            images=[image1, image2],
            padding=True,
            truncation=True,
            max_length=77,
            return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            img_output = model.get_image_features(pixel_values=inputs['pixel_values'])
            txt_output = model.get_text_features(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
            )

            # The feature calls may return a tensor or an output object.
            img_feats = ensure_tensor(img_output)
            txt_feats = ensure_tensor(txt_output)

            # Unit-normalise so the matrix product is a cosine similarity.
            img_feats = img_feats / img_feats.norm(dim=-1, keepdim=True)
            txt_feats = txt_feats / txt_feats.norm(dim=-1, keepdim=True)

            # Row 0 belongs to the single prompt; one column per image.
            scores = (txt_feats @ img_feats.T)[0] * model.logit_scale.exp()

        return [_to_display_score(s) for s in scores.cpu().tolist()]
    except Exception as e:
        # Deliberate best-effort at the UI boundary: report the problem and
        # return a sentinel pair so the caller can still render a row.
        print(f"Internal Error: {e}")
        return [0.0, 0.0]
def multi_predict(p1, p2, p3, p4, p5, p6, p7, p8, img1, img2):
    """Score both images against every non-blank prompt and tabulate the result.

    Args:
        p1..p8: Up to eight prompt strings; empty, blank or ``None`` entries
            are ignored.
        img1: First image to score.
        img2: Second image to score.

    Returns:
        A ``pandas.DataFrame`` with one row per scored prompt (prompt text
        and the two scores formatted as ``"<score> / 100"``). If no prompt
        is usable or an image is missing, a single-row frame carrying an
        error message is returned instead.
    """
    columns = ["Промпт", "Вариант 1", "Вариант 2"]

    # Keep only the prompts that contain something besides whitespace.
    active_prompts = [
        text
        for text in (p1, p2, p3, p4, p5, p6, p7, p8)
        if text and text.strip() != ""
    ]

    if img1 is None or img2 is None or not active_prompts:
        return pd.DataFrame([["Ошибка", "Загрузите фото и промпты", ""]], columns=columns)

    rows = []
    for text in active_prompts:
        pair = predict_absolute(text, img1, img2)
        if not pair:
            # No scores for this prompt: leave it out of the table.
            continue
        rows.append([text, f"{pair[0]} / 100", f"{pair[1]} / 100"])

    return pd.DataFrame(rows, columns=columns)
# --- GRADIO INTERFACE ---
# NOTE(review): the layout nesting below (image row and button inside the
# left column, results table in the right column) is reconstructed from the
# statement order -- confirm against the deployed UI.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🐒 Punch the Monkey: Global UI Scorer")
    gr.Markdown("Оценивает качество UI/дизайна независимо для каждой картинки по шкале от 1 до 100.")

    with gr.Row():
        # Left column: prompt inputs, the two image slots and the run button.
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### Текстовые промпты для анализа")
                # (label, default value) for each of the eight prompt boxes,
                # created in display order.
                _prompt_fields = (
                    ("Промпт 1", "aesthetic mobile app ui"),
                    ("Промпт 2", "clean minimalist ui"),
                    ("Промпт 3", "professional mobile app interface"),
                    ("Промпт 4", "Apple Human Interface Guidelines"),
                    ("Промпт 5", "cozy simple watercolor icons"),
                    ("Промпт 6", "premium software design"),
                    ("Промпт 7", "high conversion screenshot"),
                    ("Промпт 8", "modern user experience"),
                )
                _prompt_boxes = [
                    gr.Textbox(label=box_label, value=box_default)
                    for box_label, box_default in _prompt_fields
                ]
                # Keep the individual names available at module level.
                pr1, pr2, pr3, pr4, pr5, pr6, pr7, pr8 = _prompt_boxes

            with gr.Row():
                i1 = gr.Image(type="pil", label="Вариант 1")
                i2 = gr.Image(type="pil", label="Вариант 2")

            btn = gr.Button("🚀 Сделать глобальный расчет", variant="primary")

        # Right column: read-only results table.
        with gr.Column(scale=1):
            out_table = gr.Dataframe(
                headers=["Промпт", "Вариант 1", "Вариант 2"],
                interactive=False,
            )

    # Wire the button: eight prompts followed by the two images go to
    # multi_predict, whose DataFrame fills the table.
    btn.click(
        fn=multi_predict,
        inputs=[*_prompt_boxes, i1, i2],
        outputs=out_table,
    )

if __name__ == "__main__":
    demo.launch()