# adv_grpo/rewards.py
from PIL import Image
import io
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from collections import defaultdict
def jpeg_incompressibility():
def _fn(images, prompts, metadata):
if isinstance(images, torch.Tensor):
images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
images = [Image.fromarray(image) for image in images]
buffers = [io.BytesIO() for _ in images]
for image, buffer in zip(images, buffers):
image.save(buffer, format="JPEG", quality=95)
        sizes = [buffer.tell() / 1000 for buffer in buffers]  # JPEG size in kilobytes
return np.array(sizes), {}
return _fn
def jpeg_compressibility():
jpeg_fn = jpeg_incompressibility()
def _fn(images, prompts, metadata):
rew, meta = jpeg_fn(images, prompts, metadata)
        return -rew / 500, meta  # negate and rescale: smaller JPEG -> higher reward
return _fn
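# Usage sketch (hypothetical tensor, not from this repo): most factories return a
# closure with a common (images, prompts, metadata) interface, e.g.
#   fn = jpeg_compressibility()
#   rewards, _ = fn(torch.rand(4, 3, 256, 256), prompts=None, metadata=None)
#   # rewards: np.ndarray of shape [4]; higher = more JPEG-compressible.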
def aesthetic_score():
from adv_grpo.aesthetic_scorer import AestheticScorer
scorer = AestheticScorer(dtype=torch.float32).cuda()
def _fn(images, prompts, metadata):
if isinstance(images, torch.Tensor):
images = (images * 255).round().clamp(0, 255).to(torch.uint8)
else:
images = images.transpose(0, 3, 1, 2) # NHWC -> NCHW
images = torch.tensor(images, dtype=torch.uint8)
scores = scorer(images)
return scores, {}
return _fn
def clip_score():
from adv_grpo.clip_scorer import ClipScorer
# scorer = ClipScorer(dtype=torch.float32).cuda()
scorer = ClipScorer().cuda()
def _fn(images, prompts, metadata):
if not isinstance(images, torch.Tensor):
images = images.transpose(0, 3, 1, 2) # NHWC -> NCHW
images = torch.tensor(images, dtype=torch.uint8)/255.0
scores = scorer(images, prompts)
return scores, {}
return _fn
def siglip_image_similarity_score(device):
    from transformers import SiglipModel

    # 1. Load the SigLIP model (so400m-patch14-384 recommended).
    scorer = SiglipModel.from_pretrained(
        "google/siglip-so400m-patch14-384"
    ).to(device).to(torch.bfloat16)
    scorer.eval()
    # SigLIP preprocessing mean/std.
    siglip_mean = [0.5, 0.5, 0.5]
    siglip_std = [0.5, 0.5, 0.5]
    # Model input resolution (automatically matches the chosen SigLIP model).
    image_size = scorer.config.vision_config.image_size  # e.g., 224/256/384
    def _preprocess(images):
        # Convert to a float tensor on the target device so the mean/std below match.
        if not isinstance(images, torch.Tensor):
            images = torch.tensor(images, dtype=torch.float32)
        images = images.to(device=device, dtype=torch.float32)
        # 0-255 -> 0-1
        if images.max() > 1.0:
            images = images / 255.0
        # NHWC -> NCHW
        if images.shape[-1] == 3:
            images = images.permute(0, 3, 1, 2)
        # Resize to the SigLIP input resolution.
        images = F.interpolate(
            images,
            size=(image_size, image_size),
            mode="bicubic",
            align_corners=False,
        )
        # Normalize with SigLIP's mean/std.
        mean = torch.tensor(siglip_mean, device=device)[None, :, None, None]
        std = torch.tensor(siglip_std, device=device)[None, :, None, None]
        images = (images - mean) / std
        return images.to(torch.bfloat16)
def _fn(images, ref_images):
# 2. preprocess
images = _preprocess(images)
ref_images = _preprocess(ref_images)
        with torch.no_grad():
            # SigLIP feature extraction; inputs stay in bf16 to match the model
            # weights cast above (a float32 input would mismatch the bf16 model).
            out_img = scorer.vision_model(pixel_values=images)
            out_ref = scorer.vision_model(pixel_values=ref_images)
            emb_images = out_img.pooler_output
            emb_ref = out_ref.pooler_output
# 3. normalize embeddings (L2)
emb_images = emb_images / emb_images.norm(dim=-1, keepdim=True)
emb_ref = emb_ref / emb_ref.norm(dim=-1, keepdim=True)
# 4. cosine similarity
scores = torch.matmul(emb_images, emb_ref.T) # [N,M]
per_img = scores.max(dim=1).values # [N]
return per_img.detach(), {"pairwise": scores.detach()}
return _fn
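# Usage sketch (assumed shapes): both arguments may be [N,H,W,3] arrays or
# [N,3,H,W] tensors, with values in [0,1] or [0,255]:
#   fn = siglip_image_similarity_score("cuda")
#   rewards, info = fn(gen_images, ref_images)
#   # rewards: [N] best-match cosine similarity; info["pairwise"]: the [N,M] matrix.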
def image_similarity_score(device):
    # 1. Load the DINOv2 model (ViT-Base/14 here; swap in ViT-L/14 or ViT-G/14 if desired).
    model = timm.create_model("vit_base_patch14_dinov2.lvd142m", pretrained=True)
    # model = timm.create_model('vit_large_patch16_dinov3_qkvb.lvd1689m', pretrained=True)
    model.eval().to(device)
    def _preprocess(images):
        if not isinstance(images, torch.Tensor):
            images = torch.tensor(images, dtype=torch.float32)
        if images.max() > 1.0:
            images = images / 255.0
        if images.shape[-1] == 3:  # NHWC -> NCHW
            images = images.permute(0, 3, 1, 2)
        # Resize to 518x518, the DINOv2 input resolution.
        images = F.interpolate(images, size=(518, 518), mode="bicubic", align_corners=False)
        # DINOv2 (ImageNet) normalization.
        mean = torch.tensor([0.485, 0.456, 0.406], device=images.device)[None, :, None, None]
        std = torch.tensor([0.229, 0.224, 0.225], device=images.device)[None, :, None, None]
        images = (images - mean) / std
        return images.to(device)
    def _fn(images, ref_images):
        # 2. Preprocess.
        images = _preprocess(images)
        ref_images = _preprocess(ref_images)
        with torch.no_grad():
            emb_images = model(images)   # [N, D]
            emb_ref = model(ref_images)  # [M, D]
        # 3. L2-normalize.
        emb_images = emb_images / emb_images.norm(dim=-1, keepdim=True)
        emb_ref = emb_ref / emb_ref.norm(dim=-1, keepdim=True)
        # 4. Cosine similarity.
        scores = torch.matmul(emb_images, emb_ref.T)  # [N, M]
        per_img = scores.max(dim=1).values  # [N]; use scores.mean(dim=1) for the average instead
        # Return per-image scores; the pairwise matrix goes into the info dict.
        return per_img.detach(), {"pairwise": scores.detach()}
return _fn
def image_similarity_score_eval(device):
    # 1. Load the DINOv2 model (ViT-Base/14 here; swap in ViT-L/14 or ViT-G/14 if desired).
    model = timm.create_model("vit_base_patch14_dinov2.lvd142m", pretrained=True)
    model.eval().to(device)
    def _preprocess(images):
        if not isinstance(images, torch.Tensor):
            images = torch.tensor(images, dtype=torch.float32)
        if images.max() > 1.0:
            images = images / 255.0
        if images.shape[-1] == 3:  # NHWC -> NCHW
            images = images.permute(0, 3, 1, 2)
        # Resize to 518x518, the DINOv2 input resolution.
        images = F.interpolate(images, size=(518, 518), mode="bicubic", align_corners=False)
        # DINOv2 (ImageNet) normalization.
        mean = torch.tensor([0.485, 0.456, 0.406], device=images.device)[None, :, None, None]
        std = torch.tensor([0.229, 0.224, 0.225], device=images.device)[None, :, None, None]
        images = (images - mean) / std
        return images.to(device)
    def _fn(images, ref_images):
        # 2. Preprocess.
        images = _preprocess(images)
        ref_images = _preprocess(ref_images)
        with torch.no_grad():
            emb_images = model(images)   # [N, D]
            emb_ref = model(ref_images)  # [M, D]
        # 3. No L2 normalization here: this eval variant keeps raw embeddings, so
        # the scores below are dot products rather than cosine similarities.
        # 4. Similarity.
        scores = torch.matmul(emb_images, emb_ref.T)  # [N, M]
        per_img = scores.max(dim=1).values  # [N]; use scores.mean(dim=1) for the average instead
        # Return per-image scores, the pairwise matrix, and the raw embeddings.
        return per_img.detach(), {"pairwise": scores.detach()}, emb_images, emb_ref
return _fn
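# Note: multi_score routes "image_similarity_eval" here and stores the returned
# raw embeddings under score_details["feat"] / ["ref_feat"] for offline analysis.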
def dino_cotrain_score(device):
def _preprocess(images):
if not isinstance(images, torch.Tensor):
images = torch.tensor(images, dtype=torch.float32)
if images.max() > 1.0:
images = images / 255.0
if images.shape[-1] == 3: # NHWC -> NCHW
images = images.permute(0, 3, 1, 2)
images = F.interpolate(images, size=(518, 518), mode="bicubic", align_corners=False)
mean = torch.tensor([0.485, 0.456, 0.406], device=images.device)[None, :, None, None]
std = torch.tensor([0.229, 0.224, 0.225], device=images.device)[None, :, None, None]
images = (images - mean) / std
return images.to(device).to(torch.bfloat16)
    def _fn(scorer, head, images, prompts, metadata):
        images = _preprocess(images)
        with torch.no_grad():
            emb = scorer(images)  # [N, D]
            emb = emb / emb.norm(dim=-1, keepdim=True)
        # The head maps each embedding to a scalar reward.
        scores = head(emb).squeeze(-1)  # [N]
        return scores.detach(), {"embeddings": emb.detach()}
return _fn
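# Usage sketch: scorer is a (typically frozen) DINOv2 backbone and head a small
# trainable reward head, both supplied by the caller:
#   fn = dino_cotrain_score(device)
#   rewards, info = fn(backbone, reward_head, images, prompts, metadata)
#   # info["embeddings"]: L2-normalized [N, D] features, reusable for head updates.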
def siglip_cotrain_score(device):
"""
使用方式与原来的 dino_cotrain_score 一致:
reward_fn = siglip_cotrain_score(device)
reward_fn(scorer, head, images, prompts, metadata)
scorer: SigLIPModel (从 HF transformers 加载)
head: 你的 reward head
"""
from torchvision import transforms
tiny_jitter = transforms.ColorJitter(
brightness=0.02,
contrast=0.02,
)
    def _preprocess(images, image_size=224):
        # Convert to tensor.
        if not isinstance(images, torch.Tensor):
            images = torch.tensor(images, dtype=torch.float32)
        # uint8 0-255 -> 0-1
        if images.max() > 1.0:
            images = images / 255.0
        # NHWC -> NCHW
        if images.shape[-1] == 3:
            images = images.permute(0, 3, 1, 2)
        # Apply a slight brightness/contrast perturbation to each image.
        imgs_aug = [tiny_jitter(img) for img in images]
        images = torch.stack(imgs_aug).to(device)
        # Resize to the SigLIP input size (typically 224 / 256 / 384).
        images = F.interpolate(
            images,
            size=(image_size, image_size),
            mode="bicubic",
            align_corners=False,
        )
        # SigLIP's official mean/std (note: different from CLIP's).
        mean = torch.tensor([0.5, 0.5, 0.5], device=device)[None, :, None, None]
        std = torch.tensor([0.5, 0.5, 0.5], device=device)[None, :, None, None]
        images = (images - mean) / std
        return images.to(torch.bfloat16)
    def _fn(scorer, head, images, prompts, metadata):
        """
        scorer: SiglipModel
        head:   reward head
        """
        # Image preprocessing (scorer.config.vision_config.image_size is usually
        # 224; the 512 used here only works if the vision tower can handle the
        # larger input, e.g. via position-embedding interpolation).
        images = _preprocess(images, 512)
        scorer.eval()
        with torch.no_grad():
            # SigLIP feature extraction: pooler_output is the global image
            # embedding, analogous to CLIP's CLS feature.
            vision_out = scorer.vision_model(
                pixel_values=images.to(torch.float32)
            )
            emb = vision_out.pooler_output.to(torch.bfloat16)  # [B, D]
        # The head maps each embedding to a scalar reward.
        scores = head(emb).squeeze(-1)
        return scores.detach(), {"embeddings": emb.detach()}
return _fn
def dino_patch_cotrain_score(device, n_patches=64):
def _preprocess(images):
if not isinstance(images, torch.Tensor):
images = torch.tensor(images, dtype=torch.float32)
if images.max() > 1.0:
images = images / 255.0
if images.shape[-1] == 3: # NHWC -> NCHW
images = images.permute(0, 3, 1, 2)
images = F.interpolate(images, size=(518, 518), mode="bicubic", align_corners=False)
mean = torch.tensor([0.485, 0.456, 0.406], device=images.device)[None, :, None, None]
std = torch.tensor([0.229, 0.224, 0.225], device=images.device)[None, :, None, None]
images = (images - mean) / std
return images.to(device).to(torch.bfloat16)
    def _fn(scorer, head, images, prompts, metadata, cls_weight=0.7):
        images = _preprocess(images)
        with torch.no_grad():
            # Extract all tokens: [B, N+1, D].
            feats = scorer.forward_features(images)
            # --- Split CLS and patch tokens ---
            cls_emb = feats[:, 0, :]     # [B, D]
            patch_emb = feats[:, 1:, :]  # [B, N, D]
            B, N, D = patch_emb.shape
            # --- Randomly sample patches ---
            n_select = min(n_patches, N)
            idx = torch.randint(0, N, (B, n_select), device=device)
            sampled_patches = torch.gather(
                patch_emb, 1, idx.unsqueeze(-1).expand(-1, -1, D)
            )  # [B, n_select, D]
            # --- Normalize ---
            cls_emb = cls_emb / (cls_emb.norm(dim=-1, keepdim=True) + 1e-6)
            sampled_patches = sampled_patches / (sampled_patches.norm(dim=-1, keepdim=True) + 1e-6)
            # --- Score ---
            cls_score = head(cls_emb).squeeze(-1)             # [B]
            patch_scores = head(sampled_patches).squeeze(-1)  # [B, n_select]
            patch_score_mean = patch_scores.mean(dim=1)       # [B]
            # --- Blend into a hybrid reward ---
            hybrid_score = cls_weight * cls_score + (1 - cls_weight) * patch_score_mean
        return hybrid_score.detach(), {
            "cls_score": cls_score.detach(),
            "patch_scores": patch_scores.detach(),
            "patch_indices": idx.detach(),
            "cls_weight": cls_weight,
        }
return _fn
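# Note on sampling: torch.randint above draws patch indices *with* replacement;
# if distinct patches are needed, a per-image torch.randperm(N)[:n_select] would
# be the usual (slightly slower) alternative -- not used here.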
def _get_layer_tokens_timm(model, imgs, layer_ids=(2, 5, 8, 11)):
handles, feats = [], {i: None for i in layer_ids}
    def make_hook(i):
        def hook(_module, _inp, out):
            # For timm ViTs, each block's output is [B, N+1, D].
            feats[i] = out
        return hook
    for i in layer_ids:
        assert 0 <= i < len(model.blocks), f"layer id {i} out of range"
        handles.append(model.blocks[i].register_forward_hook(make_hook(i)))
    # Trigger the forward pass: timm ViTs expose forward_features; otherwise fall back to __call__.
with torch.no_grad():
if hasattr(model, "forward_features"):
_ = model.forward_features(imgs)
else:
_ = model(imgs)
for h in handles:
h.remove()
return [feats[i] for i in layer_ids] # list of [B, N+1, D]
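# The hooks above capture intermediate block outputs without modifying the model:
# each forward pass fills feats[i] with that block's [B, N+1, D] tokens, and the
# handles are removed afterwards so repeated calls do not accumulate.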
# -------- Reward factory: per-layer heads + top-k pooling + fusion --------
def dino_multi_cotrain_score(
device,
    topk_tau=0.2,        # per layer, average the top tau fraction of patch logits
    apply_sigmoid=True,  # squash logits through a sigmoid into [0, 1] rewards
    lambda_cls=0.5,
    zscore=False,        # z-score across the batch (per-group handling is left to the caller)
):
def _preprocess(images):
if not isinstance(images, torch.Tensor):
images = torch.tensor(images, dtype=torch.float32)
if images.max() > 1.0:
images = images / 255.0
if images.shape[-1] == 3: # NHWC -> NCHW
images = images.permute(0, 3, 1, 2)
        images = F.interpolate(images, size=(518, 518), mode="bicubic", align_corners=False)
        # Note: this factory normalizes to [-1, 1] rather than the ImageNet
        # mean/std used by the other DINO preprocessors in this file.
        images = (images - 0.5) / 0.5
return images.to(device)
@torch.no_grad()
    def _fn(scorer, heads, fusion, images, prompts=None, metadata=None, layer_ids=(8,), temperature=0.2):
        """
        scorer : timm ViT-DINOv2 backbone (already .eval() with requires_grad=False)
        heads  : nn.ModuleList, one head per entry of layer_ids
        fusion : fusion module mapping (B, T) -> (B,)
        images : [N,H,W,3] or [N,3,H,W], values in [0,1] or [0,255]
        Returns:
            rewards: [N] (float32)
            aux: dict with per_layer_scores etc. for debugging/visualization
        """
        from torch.nn.parallel import DistributedDataParallel as DDP
        hmod = heads.module if isinstance(heads, DDP) else heads
        fmod = fusion.module if isinstance(fusion, DDP) else fusion
        x = _preprocess(images).to(dtype=next(scorer.parameters()).dtype, device=device)
        # Collect tokens from the requested layers.
        tokens_list = _get_layer_tokens_timm(scorer, x, layer_ids=layer_ids)  # list of [B, N+1, D]
        B = x.size(0)
        T = len(tokens_list)
        per_layer_scores = []
        per_layer_logits = []  # per-layer patch logits before top-k (optional, for debugging)
        per_layer_cls_scores = []
        for t in range(T):
            tokens = tokens_list[t]     # [B, N+1, D]
            patch = tokens[:, 1:]       # [B, N, D], CLS excluded
            class_patch = tokens[:, 0]  # [B, D] CLS token
            Bn, N, D = patch.shape
            # Each head accepts [B, N, D] and outputs [B, N].
            logits_patch = hmod[t](patch).squeeze(-1)  # [B, N]
            per_layer_logits.append(logits_patch)
            # Within-layer top-k pooling.
            k = max(1, int(N * topk_tau))
            pooled = logits_patch.topk(k, dim=1).values.mean(dim=1)  # [B]
            per_layer_scores.append(pooled)
            cls_logit = hmod[t](class_patch).squeeze(-1)  # [B]
            per_layer_cls_scores.append(cls_logit)
        per_layer_scores = torch.stack(per_layer_scores, dim=1)  # [B, T]
        per_layer_cls_scores = torch.stack(per_layer_cls_scores, dim=1)
        # Fuse per-layer scores into final logits.
        logit_patch = fmod(per_layer_scores)  # [B]
        logit_cls = fmod(per_layer_cls_scores)
        # logits = (1.0 - float(lambda_cls)) * logit_patch + float(lambda_cls) * logit_cls  # [B]
        logits = logit_patch  # CLS pathway currently unused
        # Calibrate into a reward.
        rewards = logits
        if apply_sigmoid:
            rewards = torch.sigmoid(rewards / float(temperature))
        if zscore:
            mu = rewards.mean(dim=0, keepdim=True)
            sigma = rewards.std(dim=0, keepdim=True).clamp_min(1e-6)
            rewards = (rewards - mu) / sigma
        # Emit float32 to avoid downstream bf16 mixing issues.
        rewards = rewards.float()
        aux = {
            "per_layer_scores": per_layer_scores.float(),  # [B, T]
            "logits": logits.float(),  # fused logits before calibration
            # The line below can be large (B x T x N); enable only when needed.
            # "per_layer_logits": [lp.float() for lp in per_layer_logits],
        }
return rewards, aux
return _fn
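# Per image, the reward computed above reduces to:
#   reward = sigmoid(fusion([mean of top-k patch logits per layer]) / temperature)
# with k = max(1, int(N * topk_tau)); the CLS pathway is computed but stays
# unused unless the commented lambda_cls blend is restored.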
def pickscore_score(device):
from adv_grpo.pickscore_scorer import PickScoreScorer
scorer = PickScoreScorer(dtype=torch.float32, device=device)
def _fn(images, prompts, metadata):
if isinstance(images, torch.Tensor):
images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
images = [Image.fromarray(image) for image in images]
scores = scorer(prompts, images)
return scores, {}
return _fn
def pickscore_cotrain_score(device):
def _fn(scorer, images, prompts, metadata):
if isinstance(images, torch.Tensor):
images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
images = [Image.fromarray(image) for image in images]
scores = scorer(prompts, images)
return scores, {}
return _fn
def pickscore_score_patch(device):
from adv_grpo.pickscore_scorer_patch import PickScoreScorer
scorer = PickScoreScorer(dtype=torch.float32, device=device)
def _fn(images, prompts, metadata):
if isinstance(images, torch.Tensor):
images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
images = [Image.fromarray(image) for image in images]
scores = scorer(prompts, images)
return scores, {}
return _fn
def discriminator_score(device):
    def _fn(scorer, images, prompts=None, metadata=None):
        # Normalize to [-1, 1].
        if isinstance(images, torch.Tensor):
            if images.max() > 1.5:  # probably 0-255
                images = images / 255.0
            images = (images - 0.5) * 2.0
        else:
            raise ValueError("images must be a torch.Tensor in [B,3,H,W]")
        with torch.no_grad():
            logits = scorer(images.to(device))  # StyleGAN: [B] / [B,1]; PatchGAN: [B,1,H',W']
        if logits.ndim == 1:
            # StyleGAN discriminator, already [B].
            scores = torch.sigmoid(logits)
        elif logits.ndim == 2 and logits.shape[1] == 1:
            # StyleGAN discriminator emitting [B,1].
            scores = torch.sigmoid(logits.squeeze(1))
        elif logits.ndim == 4 and logits.shape[1] == 1:
            # PatchGAN discriminator emitting [B,1,H',W']: average patch probabilities.
            scores = torch.sigmoid(logits).mean(dim=[1, 2, 3])  # -> [B]
        else:
            raise ValueError(f"Unexpected logits shape: {logits.shape}")
        return scores.cpu(), {}
return _fn
def imagereward_score(device):
from adv_grpo.imagereward_scorer import ImageRewardScorer
scorer = ImageRewardScorer(dtype=torch.float32, device=device)
def _fn(images, prompts, metadata):
if isinstance(images, torch.Tensor):
images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
images = [Image.fromarray(image) for image in images]
prompts = [prompt for prompt in prompts]
scores = scorer(prompts, images)
return scores, {}
return _fn
def qwenvl_score(device):
from adv_grpo.qwenvl import QwenVLScorer
scorer = QwenVLScorer(dtype=torch.bfloat16, device=device)
def _fn(images, prompts, metadata):
if isinstance(images, torch.Tensor):
images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
images = [Image.fromarray(image) for image in images]
prompts = [prompt for prompt in prompts]
scores = scorer(prompts, images)
return scores, {}
return _fn
def ocr_score(device):
from adv_grpo.ocr import OcrScorer
scorer = OcrScorer()
def _fn(images, prompts, metadata):
if isinstance(images, torch.Tensor):
images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
scores = scorer(images, prompts)
return scores, {}
return _fn
def video_ocr_score(device):
from adv_grpo.ocr import OcrScorer_video_or_image
scorer = OcrScorer_video_or_image()
def _fn(images, prompts, metadata):
if isinstance(images, torch.Tensor):
if images.dim() == 4 and images.shape[1] == 3:
images = images.permute(0, 2, 3, 1)
elif images.dim() == 5 and images.shape[2] == 3:
images = images.permute(0, 1, 3, 4, 2)
images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
scores = scorer(images, prompts)
return scores, {}
return _fn
def constractive_external(device, beta=0.5, top_n=2):
    # ("constractive" [sic] is kept to match this repo's module and key names.)
    from adv_grpo.pickscore_scorer_constractive import PickScoreScorerConstractive
    scorer = PickScoreScorerConstractive(dtype=torch.float32, device=device)
    def _fn(images, ref_images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
            images = [Image.fromarray(image) for image in images]
            ref_images = (ref_images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            ref_images = ref_images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
            ref_images = [Image.fromarray(image) for image in ref_images]
        # Score generated and reference images.
        scores, ref_scores, img_embeds, ref_img_embeds = scorer(prompts, images, ref_images)
# external anchor
ref_embed = ref_img_embeds.mean(dim=0, keepdim=True) # (1,D)
ext_score = ref_scores.mean()
top_idx = torch.topk(scores, k=min(top_n, len(scores))).indices
hack_scores = scores[top_idx]
hack_embeds = img_embeds[top_idx] # (N,D)
if ext_score >= hack_scores.max():
return scores, {"raw_scores": scores, "ref_scores": ref_scores}
        # Contrastive correction.
sim_to_ext = torch.nn.functional.cosine_similarity(img_embeds, ref_embed)
sim_to_hack = torch.nn.functional.cosine_similarity(
img_embeds.unsqueeze(1), hack_embeds.unsqueeze(0), dim=-1
) # (num_images, N)
sim_to_hack = sim_to_hack.mean(dim=1)
adjusted_scores = scores + beta * (sim_to_ext - sim_to_hack)
return adjusted_scores, {
"raw_scores": scores,
"ref_scores": ref_scores,
"sim_to_ext": sim_to_ext,
"sim_to_hack": sim_to_hack,
"hack_scores": hack_scores
}
return _fn
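# Intuition for the correction above: if the external reference score already
# beats the top-n ("hacked") generations, scores pass through unchanged;
# otherwise each image is nudged toward the reference embedding and away from
# the suspected reward-hacked ones: score + beta * (sim_to_ext - sim_to_hack).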
def deqa_score_remote(device):
"""Submits images to DeQA and computes a reward.
"""
import requests
from requests.adapters import HTTPAdapter, Retry
from io import BytesIO
import pickle
batch_size = 64
url = "http://127.0.0.1:18086"
sess = requests.Session()
retries = Retry(
total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
)
sess.mount("http://", HTTPAdapter(max_retries=retries))
def _fn(images, prompts, metadata):
del prompts
if isinstance(images, torch.Tensor):
images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
images_batched = np.array_split(images, np.ceil(len(images) / batch_size))
all_scores = []
for image_batch in images_batched:
jpeg_images = []
# Compress the images using JPEG
for image in image_batch:
img = Image.fromarray(image)
buffer = BytesIO()
img.save(buffer, format="JPEG")
jpeg_images.append(buffer.getvalue())
            # Build the request payload for the DeQA server.
data = {
"images": jpeg_images,
}
data_bytes = pickle.dumps(data)
            # Send the request to the DeQA server.
response = sess.post(url, data=data_bytes, timeout=120)
response_data = pickle.loads(response.content)
all_scores += response_data["outputs"]
return all_scores, {}
return _fn
def geneval_score(device):
"""Submits images to GenEval and computes a reward.
"""
import requests
from requests.adapters import HTTPAdapter, Retry
from io import BytesIO
import pickle
batch_size = 64
url = "http://127.0.0.1:18085"
sess = requests.Session()
retries = Retry(
total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
)
sess.mount("http://", HTTPAdapter(max_retries=retries))
def _fn(images, prompts, metadatas, only_strict):
del prompts
if isinstance(images, torch.Tensor):
images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
images_batched = np.array_split(images, np.ceil(len(images) / batch_size))
metadatas_batched = np.array_split(metadatas, np.ceil(len(metadatas) / batch_size))
all_scores = []
all_rewards = []
all_strict_rewards = []
all_group_strict_rewards = []
all_group_rewards = []
for image_batch, metadata_batched in zip(images_batched, metadatas_batched):
jpeg_images = []
# Compress the images using JPEG
for image in image_batch:
img = Image.fromarray(image)
buffer = BytesIO()
img.save(buffer, format="JPEG")
jpeg_images.append(buffer.getvalue())
            # Build the request payload for the GenEval server.
data = {
"images": jpeg_images,
"meta_datas": list(metadata_batched),
"only_strict": only_strict,
}
data_bytes = pickle.dumps(data)
            # Send the request to the GenEval server.
response = sess.post(url, data=data_bytes, timeout=120)
response_data = pickle.loads(response.content)
all_scores += response_data["scores"]
all_rewards += response_data["rewards"]
all_strict_rewards += response_data["strict_rewards"]
all_group_strict_rewards.append(response_data["group_strict_rewards"])
all_group_rewards.append(response_data["group_rewards"])
all_group_strict_rewards_dict = defaultdict(list)
all_group_rewards_dict = defaultdict(list)
for current_dict in all_group_strict_rewards:
for key, value in current_dict.items():
all_group_strict_rewards_dict[key].extend(value)
all_group_strict_rewards_dict = dict(all_group_strict_rewards_dict)
for current_dict in all_group_rewards:
for key, value in current_dict.items():
all_group_rewards_dict[key].extend(value)
all_group_rewards_dict = dict(all_group_rewards_dict)
return all_scores, all_rewards, all_strict_rewards, all_group_rewards_dict, all_group_strict_rewards_dict
return _fn
def unifiedreward_score_remote(device):
    """Submits images to a remote UnifiedReward server and computes a reward.
    """
import requests
from requests.adapters import HTTPAdapter, Retry
from io import BytesIO
import pickle
batch_size = 64
url = "http://10.82.120.15:18085"
sess = requests.Session()
retries = Retry(
total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
)
sess.mount("http://", HTTPAdapter(max_retries=retries))
def _fn(images, prompts, metadata):
if isinstance(images, torch.Tensor):
images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
images_batched = np.array_split(images, np.ceil(len(images) / batch_size))
prompts_batched = np.array_split(prompts, np.ceil(len(prompts) / batch_size))
all_scores = []
for image_batch, prompt_batch in zip(images_batched, prompts_batched):
jpeg_images = []
# Compress the images using JPEG
for image in image_batch:
img = Image.fromarray(image)
buffer = BytesIO()
img.save(buffer, format="JPEG")
jpeg_images.append(buffer.getvalue())
            # Build the request payload for the UnifiedReward server.
data = {
"images": jpeg_images,
"prompts": prompt_batch
}
data_bytes = pickle.dumps(data)
            # Send the request to the UnifiedReward server.
response = sess.post(url, data=data_bytes, timeout=120)
print("response: ", response)
print("response: ", response.content)
response_data = pickle.loads(response.content)
all_scores += response_data["outputs"]
return all_scores, {}
return _fn
def unifiedreward_score_sglang(device):
import asyncio
from openai import AsyncOpenAI
import base64
from io import BytesIO
import re
def pil_image_to_base64(image):
buffered = BytesIO()
image.save(buffered, format="PNG")
encoded_image_text = base64.b64encode(buffered.getvalue()).decode("utf-8")
base64_qwen = f"data:image;base64,{encoded_image_text}"
return base64_qwen
def _extract_scores(text_outputs):
scores = []
pattern = r"Final Score:\s*([1-5](?:\.\d+)?)"
for text in text_outputs:
match = re.search(pattern, text)
if match:
try:
scores.append(float(match.group(1)))
except ValueError:
scores.append(0.0)
else:
scores.append(0.0)
return scores
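    # Example (hypothetical model reply): "...looks correct.\nFinal Score: 4.5"
    # parses to 4.5; any reply without a parsable "Final Score:" falls back to 0.0.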
client = AsyncOpenAI(base_url="http://127.0.0.1:17140/v1", api_key="flowgrpo")
async def evaluate_image(prompt, image):
question = f"<image>\nYou are given a text caption and a generated image based on that caption. Your task is to evaluate this image based on two key criteria:\n1. Alignment with the Caption: Assess how well this image aligns with the provided caption. Consider the accuracy of depicted objects, their relationships, and attributes as described in the caption.\n2. Overall Image Quality: Examine the visual quality of this image, including clarity, detail preservation, color accuracy, and overall aesthetic appeal.\nBased on the above criteria, assign a score from 1 to 5 after \'Final Score:\'.\nYour task is provided as follows:\nText Caption: [{prompt}]"
images_base64 = pil_image_to_base64(image)
response = await client.chat.completions.create(
model="UnifiedReward-7b-v1.5",
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": images_base64},
},
{
"type": "text",
"text": question,
},
],
},
],
temperature=0,
)
return response.choices[0].message.content
async def evaluate_batch_image(images, prompts):
tasks = [evaluate_image(prompt, img) for prompt, img in zip(prompts, images)]
results = await asyncio.gather(*tasks)
return results
    def _fn(images, prompts, metadata):
        # Convert tensors to uint8 numpy arrays.
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
        # Convert to PIL images and resize.
        images = [Image.fromarray(image).resize((512, 512)) for image in images]
        # Run the async batch evaluation.
        text_outputs = asyncio.run(evaluate_batch_image(images, prompts))
        score = _extract_scores(text_outputs)
        score = [sc / 5.0 for sc in score]  # rescale 1-5 ratings into (0, 1]
        return score, {}
return _fn
def multi_score(device, score_dict):
score_functions = {
"deqa": deqa_score_remote,
"ocr": ocr_score,
"video_ocr": video_ocr_score,
"imagereward": imagereward_score,
"pickscore": pickscore_score,
"qwenvl": qwenvl_score,
"aesthetic": aesthetic_score,
"jpeg_compressibility": jpeg_compressibility,
"unifiedreward": unifiedreward_score_sglang,
"geneval": geneval_score,
"clipscore": clip_score,
"image_similarity": image_similarity_score,
"image_similarity_eval": image_similarity_score_eval,
"constractive_external": constractive_external,
"discriminator": discriminator_score,
"pickscore_cotrain": pickscore_cotrain_score,
"pickscore_patch":pickscore_score_patch,
"dino_cotrain":dino_cotrain_score,
"dino_multi_cotrain": dino_multi_cotrain_score,
"dino_patch_cotrain": dino_patch_cotrain_score,
"siglip_cotrain": siglip_cotrain_score,
"siglip_image_similarity": siglip_image_similarity_score
}
    score_fns = {}
    for score_name in score_dict:
        factory = score_functions[score_name]
        # Pass the device only to factories whose signature accepts a device argument.
        score_fns[score_name] = factory(device) if "device" in factory.__code__.co_varnames else factory()
# only_strict is only for geneval. During training, only the strict reward is needed, and non-strict rewards don't need to be computed, reducing reward calculation time.
    def _fn(images, prompts, metadata, scorer=None, ref_images=None, only_strict=True, head=None, fusion=None, layer_ids=None, temperature=0.2):
total_scores = []
score_details = {}
for score_name, weight in score_dict.items():
if score_name == "geneval":
scores, rewards, strict_rewards, group_rewards, group_strict_rewards = score_fns[score_name](images, prompts, metadata, only_strict)
score_details['accuracy'] = rewards
score_details['strict_accuracy'] = strict_rewards
for key, value in group_strict_rewards.items():
score_details[f'{key}_strict_accuracy'] = value
for key, value in group_rewards.items():
score_details[f'{key}_accuracy'] = value
elif score_name == "image_similarity":
scores, rewards = score_fns[score_name](images, ref_images)
elif score_name == "siglip_image_similarity":
scores, rewards = score_fns[score_name](images, ref_images)
elif score_name == "image_similarity_eval":
scores, rewards, feat, ref_feat = score_fns[score_name](images, ref_images)
score_details['feat'] = feat
score_details['ref_feat'] = ref_feat
elif score_name == "constractive_external":
scores, rewards = score_fns[score_name](images, prompts, ref_images)
elif score_name == "discriminator":
scores, rewards = score_fns[score_name](scorer, images, prompts, ref_images)
elif score_name == "pickscore_cotrain":
scores, rewards = score_fns[score_name](scorer, images, prompts, metadata)
elif score_name == "dino_cotrain":
scores, rewards = score_fns[score_name](scorer, head, images, prompts, metadata)
elif score_name == "siglip_cotrain":
scores, rewards = score_fns[score_name](scorer, head, images, prompts, metadata)
elif score_name == "dino_multi_cotrain":
scores, rewards = score_fns[score_name](scorer, head, fusion, images, prompts, metadata, layer_ids, temperature)
elif score_name == "dino_patch_cotrain":
scores, rewards = score_fns[score_name](scorer, head, images, prompts, metadata)
elif score_name == "dinov3_patch_cotrain":
scores, rewards = score_fns[score_name](scorer, head, images, prompts, metadata)
else:
scores, rewards = score_fns[score_name](images, prompts, metadata)
score_details[score_name] = scores
weighted_scores = [weight * score for score in scores]
if not total_scores:
total_scores = weighted_scores
else:
total_scores = [total + weighted for total, weighted in zip(total_scores, weighted_scores)]
score_details['avg'] = total_scores
return score_details, {}
return _fn
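# Usage sketch (hypothetical weights): combine several rewards into one scalar.
#   reward_fn = multi_score(device, {"pickscore": 0.7, "jpeg_compressibility": 0.3})
#   details, _ = reward_fn(images, prompts, metadata)
#   # details["avg"] holds the per-image weighted sum used as the training reward.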
def main():
import torchvision.transforms as transforms
image_paths = [
"nasa.jpg",
]
transform = transforms.Compose([
transforms.ToTensor(), # Convert to tensor
])
images = torch.stack([transform(Image.open(image_path).convert('RGB')) for image_path in image_paths])
prompts=[
        'An astronaut’s glove floating in zero-g with "NASA 2049" on the wrist',
]
metadata = {} # Example metadata
score_dict = {
"unifiedreward": 1.0
}
# Initialize the multi_score function with a device and score_dict
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
scoring_fn = multi_score(device, score_dict)
# Get the scores
scores, _ = scoring_fn(images, prompts, metadata)
# Print the scores
print("Scores:", scores)
if __name__ == "__main__":
main()