# smc_meissonic/src/smc/rewards.py
# Last commit by cp524: "Fix device load race condition bug" (b49998b)
from PIL import Image
import torch
from importlib import resources
# Root of the resources shipped with the top-level ``assets`` package, as
# returned by ``importlib.resources.files``. Evaluated when this module is
# imported.
ASSETS_PATH = resources.files("assets")
def jpeg_compressibility(inference_dtype=None, device=None):
    """Build a loss function that measures JPEG-encoded size of images.

    Args:
        inference_dtype: dtype of the tensors returned by the loss function.
        device: device on which the returned tensors are created.

    Returns:
        A callable ``loss_fn(images) -> (loss, rewards)`` where ``loss`` holds
        the per-image JPEG size (bytes / 1000, quality 95) and ``rewards`` is
        its negation.
    """
    import io
    import numpy as np  # not referenced below; kept so import behaviour is unchanged

    def loss_fn(images):
        # A negative minimum is treated as "images are in [-1, 1]": map to [0, 1].
        if images.min() < 0:
            images = ((images / 2) + 0.5).clamp(0, 1)

        # Tensors are quantised to uint8 and moved to channel-last numpy arrays.
        if isinstance(images, torch.Tensor):
            quantised = (images * 255).round().clamp(0, 255).to(torch.uint8)
            images = quantised.cpu().numpy().transpose(0, 2, 3, 1)  # NCHW -> NHWC

        # Encode every image into its own in-memory buffer and record the
        # number of bytes written, expressed in thousands of bytes.
        sizes = []
        for array in images:
            encoded = io.BytesIO()
            Image.fromarray(array).save(encoded, format="JPEG", quality=95)
            sizes.append(encoded.tell() / 1000)

        loss = torch.tensor(sizes, dtype=inference_dtype, device=device)
        return loss, -1 * loss

    return loss_fn
def clip_score(
    inference_dtype=None,
    device=None,
    return_loss=False,
):
    """Build a CLIP-based scoring (or loss) function.

    The scorer is always created in ``torch.float32`` with gradients disabled
    on its parameters; ``inference_dtype`` is accepted but not used here.

    Args:
        inference_dtype: unused by this factory.
        device: device handed to ``CLIPScorer``.
        return_loss: when true, return a function producing ``(loss, scores)``
            with ``loss = -scores``; otherwise return a function producing
            only ``scores``.

    Returns:
        A callable taking ``(images, prompts)``.
    """
    from src.smc.scorers.clip_scorer import CLIPScorer

    scorer = CLIPScorer(dtype=torch.float32, device=device)
    scorer.requires_grad_(False)

    def _fn(images, prompts):
        # A negative minimum is treated as "images are in [-1, 1]": map to [0, 1].
        if images.min() < 0:
            images = ((images / 2) + 0.5).clamp(0, 1)
        return scorer(images, prompts)

    if not return_loss:
        return _fn

    def loss_fn(images, prompts):
        scores = _fn(images, prompts)
        return -scores, scores

    return loss_fn
def aesthetic_score(
    torch_dtype=None,
    aesthetic_target=None,
    grad_scale=0,
    device=None,
    return_loss=False,
):
    """Build an aesthetic scoring (or loss) function.

    The scorer is always created in ``torch.float32`` with gradients disabled
    on its parameters; ``torch_dtype`` is accepted but not used here.

    Args:
        torch_dtype: unused by this factory.
        aesthetic_target: optional target score. When ``None`` the loss is
            ``-scores`` (maximisation); otherwise it is the absolute
            difference between the scores and the target.
        grad_scale: multiplier applied to the returned loss. Note that the
            default of ``0`` makes the returned loss zero.
        device: device handed to ``AestheticScorer``.
        return_loss: when true, return a function producing
            ``(loss * grad_scale, scores)``; otherwise a function producing
            only ``scores``.

    Returns:
        A callable taking ``(images, prompts)``; ``prompts`` is ignored by the
        scorer call.
    """
    from src.smc.scorers.aesthetic_scorer import AestheticScorer

    scorer = AestheticScorer(dtype=torch.float32, device=device)
    scorer.requires_grad_(False)

    def _fn(images, prompts):
        # A negative minimum is treated as "images are in [-1, 1]": map to [0, 1].
        if images.min() < 0:
            images = ((images / 2) + 0.5).clamp(0, 1)
        return scorer(images)

    if not return_loss:
        return _fn

    def loss_fn(images, prompts):
        scores = _fn(images, prompts)
        # L1 distance to the target keeps the loss on the same scale as the
        # plain maximisation objective.
        loss = -1 * scores if aesthetic_target is None else abs(scores - aesthetic_target)
        return loss * grad_scale, scores

    return loss_fn
def hps_score(
    inference_dtype=None,
    device=None,
    return_loss=False,
):
    """Build an HPSv2-based scoring (or loss) function.

    The scorer is always created in ``torch.float32`` with gradients disabled
    on its parameters; ``inference_dtype`` is accepted but not used here.

    Args:
        inference_dtype: unused by this factory.
        device: device handed to ``HPSv2Scorer``.
        return_loss: when true, return a function producing ``(loss, scores)``
            with ``loss = 1.0 - scores``; otherwise a function producing only
            ``scores``.

    Returns:
        A callable taking ``(images, prompts)``.
    """
    from src.smc.scorers.hpsv2_scorer import HPSv2Scorer

    scorer = HPSv2Scorer(dtype=torch.float32, device=device)
    scorer.requires_grad_(False)

    def _fn(images, prompts):
        # A negative minimum is treated as "images are in [-1, 1]": map to [0, 1].
        if images.min() < 0:
            images = ((images / 2) + 0.5).clamp(0, 1)
        return scorer(images, prompts)

    if not return_loss:
        return _fn

    def loss_fn(images, prompts):
        scores = _fn(images, prompts)
        return 1.0 - scores, scores

    return loss_fn
def ImageReward(
    inference_dtype=None,
    device=None,
    return_loss=False,
):
    """Build an ImageReward-based scoring (or loss) function.

    The scorer is always created in ``torch.float32`` with gradients disabled
    on its parameters; ``inference_dtype`` is accepted but not used here.

    Args:
        inference_dtype: unused by this factory.
        device: device handed to ``ImageRewardScorer``.
        return_loss: when true, return a function producing ``(loss, scores)``
            with ``loss = -scores``; otherwise a function producing only
            ``scores``.

    Returns:
        A callable taking ``(images, prompts)``.
    """
    from src.smc.scorers.ImageReward_scorer import ImageRewardScorer

    scorer = ImageRewardScorer(dtype=torch.float32, device=device)
    scorer.requires_grad_(False)

    def _fn(images, prompts):
        # A negative minimum is treated as "images are in [-1, 1]": map to [0, 1].
        if images.min() < 0:
            images = ((images / 2) + 0.5).clamp(0, 1)
        return scorer(images, prompts)

    if not return_loss:
        return _fn

    def loss_fn(images, prompts):
        scores = _fn(images, prompts)
        return -scores, scores

    return loss_fn
def ImageReward_Fk_Steering(
    inference_dtype=None,
    device=None,
    device_load_lock=None,
    return_loss=False,
    bias=None,
):
    """Build a scoring (or loss) function backed by ``rm_load("ImageReward-v1.0")``.

    Args:
        inference_dtype: unused by this factory.
        device: device forwarded to ``rm_load``.
        device_load_lock: forwarded to ``rm_load`` unchanged.
        return_loss: when true, return a function producing ``(loss, scores)``
            with ``loss = -scores``; otherwise a function producing only
            ``scores``.
        bias: optional offset added to the scores. It is only applied when
            ``return_loss`` is false; the loss variant ignores it.

    Returns:
        A callable taking ``(images, prompts)``. Note that the underlying
        ``score_batched`` call receives the arguments as ``(prompts, images)``.
    """
    from src.smc.scorers.image_reward_utils import rm_load

    scorer = rm_load("ImageReward-v1.0", device=device, device_load_lock=device_load_lock)

    def _score(images, prompts):
        # A negative minimum is treated as "images are in [-1, 1]": map to [0, 1].
        if images.min() < 0:
            images = ((images / 2) + 0.5).clamp(0, 1)
        return scorer.score_batched(prompts, images)

    if not return_loss:
        def _fn(images, prompts):
            scores = _score(images, prompts)
            # Compare against None rather than relying on truthiness, which
            # raises for multi-element tensors/arrays. Add out of place so the
            # object returned by the scorer is never mutated.
            if bias is not None:
                scores = scores + bias
            return scores
        return _fn

    def loss_fn(images, prompts):
        scores = _score(images, prompts)
        return -scores, scores

    return loss_fn
def PickScore(
    inference_dtype=None,
    device=None,
    return_loss=False,
):
    """Build a PickScore-based scoring (or loss) function.

    The scorer is always created in ``torch.float32`` with gradients disabled
    on its parameters; ``inference_dtype`` is accepted but not used here.

    Args:
        inference_dtype: unused by this factory.
        device: device handed to ``PickScoreScorer``.
        return_loss: when true, return a function producing ``(loss, scores)``
            with ``loss = -scores``; otherwise a function producing only
            ``scores``.

    Returns:
        A callable taking ``(images, prompts)``.
    """
    from src.smc.scorers.PickScore_scorer import PickScoreScorer

    scorer = PickScoreScorer(dtype=torch.float32, device=device)
    scorer.requires_grad_(False)

    def _fn(images, prompts):
        # Debug aid (disabled):
        # from src.plot_utils import save_batch_images
        # save_batch_images(images, "output_SMC")
        # A negative minimum is treated as "images are in [-1, 1]": map to [0, 1].
        if images.min() < 0:
            images = ((images / 2) + 0.5).clamp(0, 1)
        return scorer(images, prompts)

    if not return_loss:
        return _fn

    def loss_fn(images, prompts):
        scores = _fn(images, prompts)
        return -scores, scores

    return loss_fn
def color_match_reward(x: torch.Tensor, target_color: torch.Tensor) -> torch.Tensor:
    """Reward images whose *mean* RGB comes close to a given target color.

    Args:
        x: ``[B, 3, H, W]`` float images (e.g. in [0, 1] or [0, 255]). The
            tensor may be non-contiguous (for example a spatial crop or a
            transposed view).
        target_color: ``[3]`` float tensor with the desired RGB mean, on the
            same scale as ``x``.

    Returns:
        ``[B]`` tensor holding the negative squared distance between each
        image's mean color and ``target_color``; higher means closer.

    Raises:
        ValueError: if ``x`` does not have exactly four dimensions.
    """
    # Unpacking also validates that x is 4-D.
    B, C, _, _ = x.shape
    # Per-image mean color vector [B, C]. reshape (rather than view) is used
    # so inputs whose spatial dims cannot be flattened as a view still work.
    mean_color = x.reshape(B, C, -1).mean(dim=2)
    # Squared distance in RGB space.
    dist2 = (mean_color - target_color[None, :].to(x.device)).pow(2).sum(dim=1)
    # Negative distance: higher reward for closer color.
    return -dist2