import gradio as gr
import torch
import os
import sys
from PIL import Image
import uuid
import huggingface_hub

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from hpsv3.inference import HPSv3RewardInferencer

try:
    import ImageReward as RM
    from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
except ImportError:
    RM = None
    create_model_and_transforms = None
    get_tokenizer = None
    print("ImageReward or HPSv2 dependencies not found. Skipping those models.")

from transformers import AutoProcessor, AutoModel
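
# Compute defaults: CUDA with bfloat16 when available, otherwise CPU/float32.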
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DTYPE = torch.bfloat16 if DEVICE == 'cuda' else torch.float32
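
# Registry of selectable scoring models; "type" selects the load/score code path.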
MODEL_CONFIGS = {
    "HPSv3_7B": {
        "name": "HPSv3 7B",
        "type": "hpsv3"
    },
    "HPSv2": {
        "name": "HPSv2",
        "checkpoint_path": "xswu/HPSv2/HPS_v2.1_compressed.pt",
        "type": "hpsv2"
    },
    "ImageReward": {
        "name": "ImageReward v1.0",
        "checkpoint_path": "ImageReward-v1.0",
        "type": "imagereward"
    },
    "PickScore": {
        "name": "PickScore",
        "checkpoint_path": "yuvalkirstain/PickScore_v1",
        "type": "pickscore"
    },
    "CLIP": {
        "name": "CLIP ViT-H-14",
        "checkpoint_path": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
        "type": "clip"
    }
}
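
# Cache for the active model: only one is kept resident at a time, so loading
# a new model evicts the previous one to bound (GPU) memory use.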
current_models = {}
current_model_name = None


def load_model(model_key, update_status_fn=None):
    """Load the specified model based on the model key."""
    global current_models, current_model_name

    if model_key == current_model_name and model_key in current_models:
        return current_models[model_key]

    if update_status_fn:
        update_status_fn(f"🔄 Loading {MODEL_CONFIGS[model_key]['name']}...")

    current_models.clear()
    torch.cuda.empty_cache()

    config = MODEL_CONFIGS[model_key]

    try:
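        # Each branch yields either an HPSv3 inferencer or a dict bundling a
        # model with its preprocessing utilities.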
        if config["type"] == "hpsv3":
            checkpoint_path = huggingface_hub.hf_hub_download("MizzenAI/HPSv3", 'HPSv3.safetensors', repo_type='model')
            model = HPSv3RewardInferencer(
                device=DEVICE,
                checkpoint_path=checkpoint_path
            )
elif config["type"] == "hpsv2": |
|
|
model_obj, preprocess_train, preprocess_val = create_model_and_transforms( |
|
|
'ViT-H-14', |
|
|
'laion2B-s32B-b79K', |
|
|
precision='amp', |
|
|
device=DEVICE, |
|
|
jit=False, |
|
|
force_quick_gelu=False, |
|
|
force_custom_text=False, |
|
|
force_patch_dropout=False, |
|
|
force_image_size=None, |
|
|
pretrained_image=False, |
|
|
image_mean=None, |
|
|
image_std=None, |
|
|
light_augmentation=True, |
|
|
aug_cfg={}, |
|
|
output_dict=True, |
|
|
with_score_predictor=False, |
|
|
with_region_predictor=False |
|
|
) |
|
|
checkpoint_path = huggingface_hub.hf_hub_download("xswu/HPSv2", 'HPS_v2.1_compressed.pt', repo_type='model') |
|
|
checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False) |
|
|
model_obj.load_state_dict(checkpoint['state_dict']) |
|
|
model_obj = model_obj.to(DEVICE).eval() |
|
|
tokenizer = get_tokenizer('ViT-H-14') |
|
|
model = {"model": model_obj, "preprocess_val": preprocess_val, "tokenizer": tokenizer} |
|
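        # ImageReward ships its own loader and scores (prompt, image) directly.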
        elif config["type"] == "imagereward":
            model = RM.load(config["checkpoint_path"])
        elif config["type"] == "pickscore":
            processor = AutoProcessor.from_pretrained('laion/CLIP-ViT-H-14-laion2B-s32B-b79K')
            model_obj = AutoModel.from_pretrained(config["checkpoint_path"]).eval().to(DEVICE)
            model = {"model": model_obj, "processor": processor}
        elif config["type"] == "clip":
            model_obj = AutoModel.from_pretrained(config["checkpoint_path"]).to(DEVICE)
            processor = AutoProcessor.from_pretrained(config["checkpoint_path"])
            model = {"model": model_obj, "processor": processor}
        else:
            raise ValueError(f"Unknown model type: {config['type']}")

        current_models[model_key] = model
        current_model_name = model_key

        if update_status_fn:
            update_status_fn(f"✅ {MODEL_CONFIGS[model_key]['name']} loaded successfully!")

        return model
    except Exception as e:
        error_msg = f"Error loading model {model_key}: {e}"
        print(error_msg)
        if update_status_fn:
            update_status_fn(f"❌ {error_msg}")
        return None
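

# Shared scoring entry point: every scorer takes parallel lists of image paths
# and prompts and returns one float per (image, prompt) pair.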
def score_with_model(model_key, image_paths, prompts):
    """Score images using the specified model."""
    model = load_model(model_key)
    if model is None:
        raise ValueError(f"Failed to load model {model_key}")

    config = MODEL_CONFIGS[model_key]

    if config["type"] == "hpsv3":
        rewards = model.reward(image_paths, prompts)
        return [reward[0].item() for reward in rewards]
    elif config["type"] == "hpsv2":
        return score_hpsv2_batch(model, image_paths, prompts)
    elif config["type"] == "imagereward":
        return [model.score(prompt, image_path) for prompt, image_path in zip(prompts, image_paths)]
    elif config["type"] == "pickscore":
        return score_pickscore_batch(prompts, image_paths, model["model"], model["processor"])
    elif config["type"] == "clip":
        return score_clip_batch(model["model"], model["processor"], image_paths, prompts)
    else:
        raise ValueError(f"Unknown model type: {config['type']}")
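

# HPSv2 scoring: each image is scored against its own prompt, i.e. the
# diagonal of the image-text similarity matrix.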
def score_hpsv2_batch(model_dict, image_paths, prompts):
    """Score using HPSv2 model."""
    model = model_dict['model']
    preprocess_val = model_dict['preprocess_val']
    tokenizer = model_dict['tokenizer']

    images = [preprocess_val(Image.open(p)).unsqueeze(0)[:, :3, :, :] for p in image_paths]
    images = torch.cat(images, dim=0).to(device=DEVICE)
    texts = tokenizer(prompts).to(device=DEVICE)
    with torch.no_grad():
        outputs = model(images, texts)
        image_features, text_features = outputs["image_features"], outputs["text_features"]
        logits_per_image = image_features @ text_features.T
        hps_scores = torch.diagonal(logits_per_image).cpu()
    return [score.item() for score in hps_scores]
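

# PickScore follows the CLIP recipe: embed image and text, L2-normalize, then
# scale the cosine similarity by the model's learned logit scale.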
def score_pickscore_batch(prompts, image_paths, model, processor):
    """Score using PickScore model."""
    pil_images = [Image.open(p) for p in image_paths]
    image_inputs = processor(
        images=pil_images,
        padding=True,
        truncation=True,
        max_length=77,
        return_tensors="pt",
    ).to(DEVICE)

    text_inputs = processor(
        text=prompts,
        padding=True,
        truncation=True,
        max_length=77,
        return_tensors="pt",
    ).to(DEVICE)

    with torch.no_grad():
        image_embs = model.get_image_features(**image_inputs)
        image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
        text_embs = model.get_text_features(**text_inputs)
        text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
        scores = model.logit_scale.exp() * (text_embs @ image_embs.T)
    return [scores[i, i].cpu().item() for i in range(len(prompts))]
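

# Plain CLIP score: cosine similarity between normalized embeddings, with no
# logit scaling, so values fall roughly in [-1, 1].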
def score_clip_batch(model, processor, image_paths, prompts):
    """Score using CLIP model."""
    pil_images = [Image.open(p) for p in image_paths]
    image_inputs = processor(
        images=pil_images,
        padding=True,
        truncation=True,
        max_length=77,
        return_tensors="pt",
    ).to(DEVICE)

    text_inputs = processor(
        text=prompts,
        padding=True,
        truncation=True,
        max_length=77,
        return_tensors="pt",
    ).to(DEVICE)

    with torch.no_grad():
        image_embs = model.get_image_features(**image_inputs)
        image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
        text_embs = model.get_text_features(**text_inputs)
        text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
        scores = image_embs @ text_embs.T
    return [scores[i, i].cpu().item() for i in range(len(prompts))]
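

# Warm up the default model at import time so the first request is fast.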
print("Loading default HPSv3 model...")
load_model("HPSv3_7B")
print("Model loaded successfully.")


def get_score_interpretation(score):
    """Returns a color-coded qualitative interpretation of the score."""
    if score is None:
        return ""

    if score < 0:
        color = "#ef4444"
        bg_color = "rgba(239, 68, 68, 0.1)"
        icon = "❌"
        feedback = "Poor Quality"
        comment = "The image has significant quality issues or doesn't match the prompt well."
    elif score < 5:
        color = "#f59e0b"
        bg_color = "rgba(245, 158, 11, 0.1)"
        icon = "⚠️"
        feedback = "Needs Improvement"
        comment = "The image is acceptable but could be enhanced in quality or prompt alignment."
    elif score < 10:
        color = "#10b981"
        bg_color = "rgba(16, 185, 129, 0.1)"
        icon = "✅"
        feedback = "Good Quality"
        comment = "A well-crafted image that aligns nicely with the given prompt."
    else:
        color = "#06d6a0"
        bg_color = "rgba(6, 214, 160, 0.1)"
        icon = "⭐"
        feedback = "Excellent!"
        comment = "Outstanding quality and perfect alignment with the prompt."

    return f"""
    <div style='
        background: {bg_color};
        border: 2px solid {color};
        border-radius: 16px;
        padding: 20px;
        text-align: center;
        margin: 10px 0;
    '>
        <div style='font-size: 2rem; margin-bottom: 8px;'>{icon}</div>
        <h3 style='color: {color}; font-size: 1.4rem; font-weight: 700; margin: 8px 0;'>{feedback}</h3>
        <p style='color: #666; font-size: 0.95rem; margin: 0; line-height: 1.4;'>{comment}</p>
    </div>
    """


def handle_model_change(model_key):
    """Handle model selection change."""
    global current_model_name

    if model_key != current_model_name:
        yield f"🔄 Loading {MODEL_CONFIGS[model_key]['name']}..."
        model = load_model(model_key)
        if model is not None:
            yield f"✅ Current model: {MODEL_CONFIGS[model_key]['name']}"
        else:
            yield f"❌ Failed to load {MODEL_CONFIGS[model_key]['name']}"
    else:
        yield f"✅ Current model: {MODEL_CONFIGS[model_key]['name']}"
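

# Gradio delivers images as numpy arrays, but the scorers expect file paths,
# so inputs are round-tripped through temporary PNGs and removed afterwards.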
def predict_score(image, prompt, model_name):
    """Takes Gradio inputs and returns the score, interpretation, and status."""
    if image is None:
        return None, "", "❌ Error: Please upload an image."
    if not prompt or not prompt.strip():
        return None, "", "❌ Error: Please enter a prompt."

    temp_dir = "temp_images_for_gradio"
    os.makedirs(temp_dir, exist_ok=True)
    temp_path = os.path.join(temp_dir, f"{uuid.uuid4()}.png")

    try:
        Image.fromarray(image).save(temp_path)
        scores = score_with_model(model_name, [temp_path], [prompt])
        score = round(scores[0], 4)
        interpretation = get_score_interpretation(score)
        return score, interpretation, "✅ Analysis completed successfully!"
    except Exception as e:
        print(f"An error occurred during inference: {e}")
        return None, "", f"❌ Processing error: {e}"
    finally:
        if os.path.exists(temp_path):
            os.remove(temp_path)
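

# Score both images against the same prompt in a single batch so the two
# numbers are directly comparable.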
def compare_images(image1, image2, prompt, model_name):
    """Compare two images and determine which one is better based on the prompt."""
    if image1 is None or image2 is None:
        return None, None, "", "❌ Error: Please upload both images."
    if not prompt or not prompt.strip():
        return None, None, "", "❌ Error: Please enter a prompt."

    temp_dir = "temp_images_for_gradio"
    os.makedirs(temp_dir, exist_ok=True)
    temp_path1 = os.path.join(temp_dir, f"{uuid.uuid4()}_img1.png")
    temp_path2 = os.path.join(temp_dir, f"{uuid.uuid4()}_img2.png")

    try:
        Image.fromarray(image1).save(temp_path1)
        Image.fromarray(image2).save(temp_path2)

        scores = score_with_model(model_name, [temp_path1, temp_path2], [prompt, prompt])
        score1 = round(scores[0], 4)
        score2 = round(scores[1], 4)

        if score1 > score2:
            winner_text = f"🏆 **Image 1 is better!**\n\nImage 1 Score: **{score1}**\nImage 2 Score: **{score2}**\n\nDifference: **+{round(score1 - score2, 4)}**"
        elif score2 > score1:
            winner_text = f"🏆 **Image 2 is better!**\n\nImage 1 Score: **{score1}**\nImage 2 Score: **{score2}**\n\nDifference: **+{round(score2 - score1, 4)}**"
        else:
            winner_text = f"🤝 **It's a tie!**\n\nBoth images scored: **{score1}**"

        return score1, score2, winner_text, "✅ Comparison completed successfully!"

    except Exception as e:
        print(f"An error occurred during comparison: {e}")
        return None, None, "", f"❌ Processing error: {e}"
    finally:
        if os.path.exists(temp_path1):
            os.remove(temp_path1)
        if os.path.exists(temp_path2):
            os.remove(temp_path2)
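

# UI: one shared model selector above two tabs (single-image scoring and
# pairwise comparison).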
with gr.Blocks(theme=gr.themes.Soft(), title="HPSv3 - Human Preference Score v3") as demo:
    gr.HTML("""
    <div style="text-align: center; margin-bottom: 20px;">
        <h1>🎨 HPSv3: Human Preference Score v3</h1>
        <p>Evaluate image quality and alignment with prompts using multiple models.</p>
        <p><a href="https://mizzenai.github.io/HPSv3.project/" target="_blank">🌐 Project Website</a> |
        <a href="https://huggingface.co/papers/2508.03789" target="_blank">📄 Paper</a> |
        <a href="https://github.com/MizzenAI/HPSv3" target="_blank">💻 Code</a></p>
    </div>
    """)

    with gr.Row():
        model_selector = gr.Dropdown(
            choices=[(config["name"], key) for key, config in MODEL_CONFIGS.items()],
            value="HPSv3_7B",
            label="🤖 Select Model",
        )
        model_status = gr.Textbox(
            label="Model Status",
            value=f"✅ Current model: {MODEL_CONFIGS['HPSv3_7B']['name']}",
            interactive=False,
            scale=2
        )

    with gr.Tabs():
        with gr.TabItem("📊 Image Scoring"):
            with gr.Row(equal_height=False):
                with gr.Column(scale=2):
                    with gr.Group():
                        gr.Markdown("### 🖼️ **Upload & Describe**")
                        image_input = gr.Image(
                            type="numpy",
                            label="Upload Image",
                            height=450
                        )
                        prompt_input = gr.Textbox(
                            label="Prompt Description",
                            placeholder="Describe what the image should represent...",
                            lines=3,
                            max_lines=5
                        )

                with gr.Column(scale=1):
                    with gr.Group():
                        gr.Markdown("### 🎯 **Quality Assessment**")
                        score_output = gr.Number(
                            label="Score",
                            elem_id="score-output",
                            precision=4
                        )
                        interpretation_output = gr.Markdown(label="")
                        status_output = gr.Textbox(
                            label="Status",
                            interactive=False
                        )
                        submit_button = gr.Button(
                            "🚀 Run Evaluation",
                            variant="primary",
                            size="lg"
                        )

            submit_button.click(
                fn=predict_score,
                inputs=[image_input, prompt_input, model_selector],
                outputs=[score_output, interpretation_output, status_output]
            )

            with gr.Group():
                gr.Examples(
                    examples=[
                        ["assets/example1.png", "cute chibi anime cartoon fox, smiling wagging tail with a small cartoon heart above sticker, high resolution, vibrant colors"],
                        ["assets/example2.png", "cute chibi anime cartoon fox, smiling wagging tail with a small cartoon heart above sticker, high resolution, vibrant colors"],
                    ],
                    inputs=[image_input, prompt_input],
                    outputs=[score_output, interpretation_output, status_output],
                    fn=lambda img, prompt: predict_score(img, prompt, "HPSv3_7B"),
                    cache_examples=False
                )

        with gr.TabItem("⚖️ Image Comparison"):
            with gr.Row(equal_height=False):
                with gr.Column(scale=2):
                    with gr.Group():
                        gr.Markdown("### 🖼️ **Upload Images & Prompt**")
                        with gr.Row():
                            image1_input = gr.Image(
                                type="numpy",
                                label="Image 1",
                                height=300
                            )
                            image2_input = gr.Image(
                                type="numpy",
                                label="Image 2",
                                height=300
                            )
                        prompt_compare_input = gr.Textbox(
                            label="Prompt Description",
                            placeholder="Describe what the images should represent...",
                            lines=3,
                            max_lines=5
                        )

                with gr.Column(scale=1):
                    with gr.Group():
                        gr.Markdown("### 🎯 **Comparison Results**")
                        score1_output = gr.Number(
                            label="Image 1 Score",
                            precision=4
                        )
                        score2_output = gr.Number(
                            label="Image 2 Score",
                            precision=4
                        )
                        comparison_result = gr.Markdown(label="Winner")
                        status_compare_output = gr.Textbox(
                            label="Status",
                            interactive=False
                        )

                        compare_button = gr.Button(
                            "⚖️ Compare Images",
                            variant="primary",
                            size="lg"
                        )

            compare_button.click(
                fn=compare_images,
                inputs=[image1_input, image2_input, prompt_compare_input, model_selector],
                outputs=[score1_output, score2_output, comparison_result, status_compare_output]
            )

            with gr.Group():
                gr.Examples(
                    examples=[
                        ["assets/example1.png", "assets/example2.png", "cute chibi anime cartoon fox, smiling wagging tail with a small cartoon heart above sticker, high resolution, vibrant colors"],
                        ["assets/example2.png", "assets/example1.png", "cute chibi anime cartoon fox, smiling wagging tail with a small cartoon heart above sticker, high resolution, vibrant colors"],
                    ],
                    inputs=[image1_input, image2_input, prompt_compare_input],
                    outputs=[score1_output, score2_output, comparison_result, status_compare_output],
                    fn=lambda img1, img2, prompt: compare_images(img1, img2, prompt, "HPSv3_7B"),
                    cache_examples=False
                )

    model_selector.change(
        fn=handle_model_change,
        inputs=[model_selector],
        outputs=[model_status]
    )
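

# Bind to 0.0.0.0 so the demo is reachable from other hosts/containers.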
def main():
    """Main function to launch the demo."""
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        favicon_path=None,
        show_error=True,
    )


if __name__ == "__main__":
    main()