Spaces:
Sleeping
Sleeping
| from PIL import Image | |
| import torch | |
| from importlib import resources | |
# Root of the resources shipped in the top-level ``assets`` package, obtained
# through ``importlib.resources.files``.
ASSETS_PATH = resources.files("assets")
def jpeg_compressibility(inference_dtype=None, device=None):
    """Build a scorer based on the JPEG-encoded size of each image.

    Args:
        inference_dtype: dtype passed to ``torch.tensor`` for the returned
            tensors (``None`` lets torch infer it).
        device: device passed to ``torch.tensor`` for the returned tensors.

    Returns:
        A callable ``loss_fn(images) -> (loss, rewards)``. ``loss`` holds the
        size of every image after JPEG encoding at quality 95, in units of
        1000 bytes; ``rewards`` is ``-loss``.
    """
    import io

    def loss_fn(images):
        """Encode ``images`` as JPEG and return ``(sizes, -sizes)``.

        ``images`` is either a ``torch.Tensor`` laid out as NCHW or an
        iterable of arrays that ``PIL.Image.fromarray`` accepts directly.
        """
        if isinstance(images, torch.Tensor):
            # JPEG encoding is not differentiable and ``.numpy()`` refuses
            # tensors that require grad, so drop the autograd graph first.
            images = images.detach()
            # A negative value triggers the x / 2 + 0.5 remap
            # ([-1, 1] -> [0, 1]); ``min()`` raises on an empty batch, hence
            # the ``numel`` guard.
            if images.numel() > 0 and images.min() < 0:
                images = ((images / 2) + 0.5).clamp(0, 1)
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC

        sizes = []
        for array in images:
            # The buffer is only needed to measure the encoded length.
            with io.BytesIO() as buffer:
                Image.fromarray(array).save(buffer, format="JPEG", quality=95)
                sizes.append(buffer.tell() / 1000)

        loss = torch.tensor(sizes, dtype=inference_dtype, device=device)
        rewards = -1 * loss
        return loss, rewards

    return loss_fn
def clip_score(
    inference_dtype=None,
    device=None,
    return_loss=False,
):
    """Create a callable that scores image/prompt pairs with ``CLIPScorer``.

    Args:
        inference_dtype: accepted for interface parity; not read here (the
            scorer is always built with ``torch.float32``).
        device: forwarded to the ``CLIPScorer`` constructor.
        return_loss: when true the returned callable yields
            ``(loss, scores)`` with ``loss = -scores``; otherwise it yields
            ``scores`` only.

    Returns:
        A callable taking ``(images, prompts)``.
    """
    from src.smc.scorers.clip_scorer import CLIPScorer

    def _unit_range(batch):
        # Any negative value triggers the x / 2 + 0.5 remap
        # ([-1, 1] -> [0, 1]) followed by clamping to [0, 1].
        return ((batch / 2) + 0.5).clamp(0, 1) if batch.min() < 0 else batch

    scorer = CLIPScorer(dtype=torch.float32, device=device)
    scorer.requires_grad_(False)

    if return_loss:

        def loss_fn(images, prompts):
            scores = scorer(_unit_range(images), prompts)
            return -scores, scores

        return loss_fn

    def _fn(images, prompts):
        return scorer(_unit_range(images), prompts)

    return _fn
def aesthetic_score(
    torch_dtype=None,
    aesthetic_target=None,
    grad_scale=0,
    device=None,
    return_loss=False,
):
    """Create a callable that rates images with ``AestheticScorer``.

    Args:
        torch_dtype: accepted for interface parity; not read here (the
            scorer is always built with ``torch.float32``).
        aesthetic_target: optional target score. ``None`` means "maximise",
            i.e. ``loss = -scores``; otherwise the loss is the absolute
            difference ``|scores - aesthetic_target|``.
        grad_scale: factor the loss is multiplied by. With the default of
            ``0`` the returned loss is all zeros.
        device: forwarded to the ``AestheticScorer`` constructor.
        return_loss: when true the returned callable yields
            ``(loss * grad_scale, scores)``; otherwise ``scores`` only.

    Returns:
        A callable taking ``(images, prompts)``; ``prompts`` is ignored.
    """
    from src.smc.scorers.aesthetic_scorer import AestheticScorer

    def _unit_range(batch):
        # Any negative value triggers the x / 2 + 0.5 remap
        # ([-1, 1] -> [0, 1]) followed by clamping to [0, 1].
        return ((batch / 2) + 0.5).clamp(0, 1) if batch.min() < 0 else batch

    scorer = AestheticScorer(dtype=torch.float32, device=device)
    scorer.requires_grad_(False)

    if return_loss:

        def loss_fn(images, prompts):
            scores = scorer(_unit_range(images))
            # L1 distance keeps the targeted loss on the same scale as the
            # plain maximisation loss.
            loss = (
                -1 * scores
                if aesthetic_target is None
                else abs(scores - aesthetic_target)
            )
            return loss * grad_scale, scores

        return loss_fn

    def _fn(images, prompts):
        return scorer(_unit_range(images))

    return _fn
def hps_score(
    inference_dtype=None,
    device=None,
    return_loss=False,
):
    """Create a callable that scores image/prompt pairs with ``HPSv2Scorer``.

    Args:
        inference_dtype: accepted for interface parity; not read here (the
            scorer is always built with ``torch.float32``).
        device: forwarded to the ``HPSv2Scorer`` constructor.
        return_loss: when true the returned callable yields
            ``(loss, scores)`` with ``loss = 1.0 - scores``; otherwise it
            yields ``scores`` only.

    Returns:
        A callable taking ``(images, prompts)``.
    """
    from src.smc.scorers.hpsv2_scorer import HPSv2Scorer

    def _unit_range(batch):
        # Any negative value triggers the x / 2 + 0.5 remap
        # ([-1, 1] -> [0, 1]) followed by clamping to [0, 1].
        return ((batch / 2) + 0.5).clamp(0, 1) if batch.min() < 0 else batch

    scorer = HPSv2Scorer(dtype=torch.float32, device=device)
    scorer.requires_grad_(False)

    if return_loss:

        def loss_fn(images, prompts):
            scores = scorer(_unit_range(images), prompts)
            return 1.0 - scores, scores

        return loss_fn

    def _fn(images, prompts):
        return scorer(_unit_range(images), prompts)

    return _fn
def ImageReward(
    inference_dtype=None,
    device=None,
    return_loss=False,
):
    """Create a callable that scores image/prompt pairs with ``ImageRewardScorer``.

    Args:
        inference_dtype: accepted for interface parity; not read here (the
            scorer is always built with ``torch.float32``).
        device: forwarded to the ``ImageRewardScorer`` constructor.
        return_loss: when true the returned callable yields
            ``(loss, scores)`` with ``loss = -scores``; otherwise it yields
            ``scores`` only.

    Returns:
        A callable taking ``(images, prompts)``.
    """
    from src.smc.scorers.ImageReward_scorer import ImageRewardScorer

    def _unit_range(batch):
        # Any negative value triggers the x / 2 + 0.5 remap
        # ([-1, 1] -> [0, 1]) followed by clamping to [0, 1].
        return ((batch / 2) + 0.5).clamp(0, 1) if batch.min() < 0 else batch

    scorer = ImageRewardScorer(dtype=torch.float32, device=device)
    scorer.requires_grad_(False)

    if return_loss:

        def loss_fn(images, prompts):
            scores = scorer(_unit_range(images), prompts)
            return -scores, scores

        return loss_fn

    def _fn(images, prompts):
        return scorer(_unit_range(images), prompts)

    return _fn
def ImageReward_Fk_Steering(
    inference_dtype=None,
    device=None,
    device_load_lock=None,
    return_loss=False,
    bias=None,
):
    """Create a scoring callable backed by the ``ImageReward-v1.0`` model.

    The model is obtained through ``rm_load`` and queried with
    ``score_batched(prompts, images)``.

    Args:
        inference_dtype: accepted for interface parity; not read here.
        device: forwarded to ``rm_load``.
        device_load_lock: forwarded to ``rm_load``.
        return_loss: when true the returned callable yields
            ``(loss, scores)`` with ``loss = -scores``; otherwise it yields
            ``scores`` only.
        bias: optional offset added in place to the scores. It is applied
            only in score-only mode (``return_loss`` false) and only when it
            is truthy, so ``None`` and ``0`` are both skipped.

    Returns:
        A callable taking ``(images, prompts)``.
    """
    from src.smc.scorers.image_reward_utils import rm_load

    def _unit_range(batch):
        # Any negative value triggers the x / 2 + 0.5 remap
        # ([-1, 1] -> [0, 1]) followed by clamping to [0, 1].
        return ((batch / 2) + 0.5).clamp(0, 1) if batch.min() < 0 else batch

    scorer = rm_load(
        "ImageReward-v1.0",
        device=device,
        device_load_lock=device_load_lock,
    )

    if return_loss:

        def loss_fn(images, prompts):
            scores = scorer.score_batched(prompts, _unit_range(images))
            return -scores, scores

        return loss_fn

    def _fn(images, prompts):
        scores = scorer.score_batched(prompts, _unit_range(images))
        if bias:
            scores += bias
        return scores

    return _fn
def PickScore(
    inference_dtype=None,
    device=None,
    return_loss=False,
):
    """Create a callable that scores image/prompt pairs with ``PickScoreScorer``.

    Args:
        inference_dtype: accepted for interface parity; not read here (the
            scorer is always built with ``torch.float32``).
        device: forwarded to the ``PickScoreScorer`` constructor.
        return_loss: when true the returned callable yields
            ``(loss, scores)`` with ``loss = -scores``; otherwise it yields
            ``scores`` only.

    Returns:
        A callable taking ``(images, prompts)``.
    """
    from src.smc.scorers.PickScore_scorer import PickScoreScorer

    def _unit_range(batch):
        # Any negative value triggers the x / 2 + 0.5 remap
        # ([-1, 1] -> [0, 1]) followed by clamping to [0, 1].
        return ((batch / 2) + 0.5).clamp(0, 1) if batch.min() < 0 else batch

    scorer = PickScoreScorer(dtype=torch.float32, device=device)
    scorer.requires_grad_(False)

    if return_loss:

        def loss_fn(images, prompts):
            scores = scorer(_unit_range(images), prompts)
            return -scores, scores

        return loss_fn

    def _fn(images, prompts):
        return scorer(_unit_range(images), prompts)

    return _fn
def color_match_reward(x: torch.Tensor, target_color: torch.Tensor) -> torch.Tensor:
    """Reward images whose *mean* RGB is close to a given target color.

    Args:
        x: ``[B, 3, H, W]`` float images (e.g. in [0, 1] or [0, 255]). The
            tensor does not have to be contiguous, so crops and transposed
            views are accepted.
        target_color: ``[3]`` float tensor holding the desired mean RGB, on
            the same value scale as ``x``.

    Returns:
        ``[B]`` tensor containing the negative squared Euclidean distance
        between each image's mean color and ``target_color``; the maximum
        of 0 is reached when they coincide.
    """
    # Per-image mean color, shape [B, C]. ``flatten`` returns a view when the
    # layout allows it and copies otherwise, whereas ``view`` raises on
    # non-contiguous inputs.
    mean_color = x.flatten(start_dim=2).mean(dim=2)

    # Squared distance in RGB space, summed over the channel axis.
    offset = mean_color - target_color[None, :].to(x.device)
    dist2 = offset.pow(2).sum(dim=1)

    # Negated so that a closer color yields a higher reward.
    return -dist2