import os
import json
import argparse
from pathlib import Path
from typing import List

import torch
from PIL import Image
import pandas as pd
from tqdm import tqdm
from transformers import (
    AutoProcessor, CLIPModel,
    AutoImageProcessor, AutoModel,
)
from datasets import load_dataset
def scale_bbox(bbox, ori_size, target_size):
    """Scale a bounding box from the original image size to the target size,
    clamping the result to the target image bounds."""
    x_min, y_min, x_max, y_max = bbox
    ori_width, ori_height = ori_size
    target_width, target_height = target_size
    width_ratio = target_width / ori_width
    height_ratio = target_height / ori_height
    scaled_x_min = max(0, int(x_min * width_ratio))
    scaled_y_min = max(0, int(y_min * height_ratio))
    scaled_x_max = min(target_width, int(x_max * width_ratio))
    scaled_y_max = min(target_height, int(y_max * height_ratio))
    return [scaled_x_min, scaled_y_min, scaled_x_max, scaled_y_max]
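
# Illustrative example (numbers are made up, not from the benchmark): scaling
# from a 1024x768 original to a 512x384 generation halves every coordinate:
#   scale_bbox([100, 50, 300, 200], (1024, 768), (512, 384)) -> [50, 25, 150, 100]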
@torch.no_grad()
def encode_clip(imgs: List[Image.Image]) -> torch.Tensor:
    """Encode a list of PIL images into L2-normalized CLIP image features."""
    features_list = []
    for img in imgs:
        inputs = clip_processor(images=img, return_tensors="pt").to(device)
        image_features = clip_model.get_image_features(**inputs)
        # L2-normalize so cosine similarity reduces to a dot product
        normalized_features = image_features / image_features.norm(dim=1, keepdim=True)
        features_list.append(normalized_features.squeeze().cpu())
    return torch.stack(features_list)  # (n, dim)
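
# A minimal batched alternative (a sketch, not the benchmark's reference
# implementation): the CLIP processor also accepts a list of images, so the
# per-image loop can become a single forward pass when memory allows:
#
#   inputs = clip_processor(images=imgs, return_tensors="pt").to(device)
#   feats = clip_model.get_image_features(**inputs)
#   feats = feats / feats.norm(dim=1, keepdim=True)   # (n, dim)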
@torch.no_grad()
def encode_dino(imgs: List[Image.Image]) -> torch.Tensor:
    """Encode a list of PIL images into L2-normalized DINOv2 features
    (mean-pooled over all output tokens)."""
    features_list = []
    for img in imgs:
        inputs = dino_processor(images=img, return_tensors="pt").to(device)
        outputs = dino_model(**inputs)
        # Mean-pool the token embeddings into a single image descriptor
        image_features = outputs.last_hidden_state.mean(dim=1)
        normalized_features = image_features / image_features.norm(dim=1, keepdim=True)
        features_list.append(normalized_features.squeeze().cpu())
    return torch.stack(features_list)  # (n, dim)
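
# Sketch of an alternative pooling, if you prefer the CLS token over mean
# pooling (both are common choices for DINOv2; this script uses mean pooling):
#
#   image_features = outputs.last_hidden_state[:, 0]   # CLS token, (1, dim)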
@torch.no_grad()
def cosine(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Pairwise cosine similarity for already-normalized feature matrices."""
    return (a @ b.T).squeeze()
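
# For normalized inputs of shape (n, dim), cosine(a, b) returns the full
# n x n similarity matrix. The evaluation loop below instead uses
# torch.nn.functional.cosine_similarity, which returns only the row-wise
# similarities of the matched reference/generation pairs, shape (n,).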
# ------------- Command line arguments -----------------
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--benchmark_repo", type=str, default="HuiZhang0812/CreatiDesign_benchmark",
                    help="Hugging Face dataset repo containing the benchmark cases")
parser.add_argument("--gen_root", type=str, default="outputs/CreatiDesign_benchmark",
                    help="Root directory for generated images (expects images/<case_id>.jpg underneath)")
parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"])
parser.add_argument("--outfile", type=str,
                    help="Path for the result CSV; defaults to scores.csv under gen_root")
args = parser.parse_args()
print("handling:", args.gen_root)
if args.outfile is None:
args.outfile = os.path.join(args.gen_root,"scores.csv")
# Convert outfile to Path object
outfile_path = Path(args.outfile)
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {device}")
# ------------- Loading models -------------------
print("[INFO] loading CLIP...")
clip_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_model.eval()
print("[INFO] loading DINOv2...")
dino_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
dino_model = AutoModel.from_pretrained("facebook/dinov2-base").to(device)
dino_model.eval()
benchmark = load_dataset(args.benchmark_repo, split="test")
DEBUG = True
if DEBUG:
    subject_save_root = os.path.join(args.gen_root, "subject-eval-visual")
    os.makedirs(subject_save_root, exist_ok=True)
records = []
for case in tqdm(benchmark):
    json_data = json.loads(case["metadata"])
    case_info = json_data["img_info"]
    case_id = case_info["img_id"]

    # ---------- Read reference subjects ----------
    ref_imgs = case["condition_white_variants"]
    if len(ref_imgs) == 0:
        print(f"[WARN] {case_id} has no reference subject, skipping")
        continue

    # ---------- Read generated image ----------
    gen_path = os.path.join(args.gen_root, "images", f"{case_id}.jpg")
    gen_img = Image.open(gen_path).convert("RGB")
    gen_width, gen_height = gen_img.size

    # Sort subject annotations once by bbox index so that bbox ids and boxes
    # stay aligned with the reference images
    annotations = sorted(json_data["subject_annotations"], key=lambda x: x["bbox_idx"])
    ref_bbox_id = [item["bbox_idx"] for item in annotations]
    ref_bbox = [item["bbox"] for item in annotations]
    ori_width, ori_height = case_info["img_width"], case_info["img_height"]

    # ---------- Crop the corresponding regions from the generated image ----------
    gen_imgs = []
    for bbox in ref_bbox:
        # Scale the bounding box from original to generated resolution
        scaled_bbox = scale_bbox(bbox, (ori_width, ori_height), (gen_width, gen_height))
        x_min, y_min, x_max, y_max = scaled_bbox
        gen_imgs.append(gen_img.crop((x_min, y_min, x_max, y_max)))

    if DEBUG:
        folder_root = os.path.join(subject_save_root, case_id)
        os.makedirs(folder_root, exist_ok=True)
        # Save cropped images for visual inspection
        for img, img_id in zip(gen_imgs, ref_bbox_id):
            img.save(os.path.join(folder_root, f"{img_id}.png"))

    # ---------- Features ----------
    ref_clip = encode_clip(ref_imgs)  # (n, dim)
    gen_clip = encode_clip(gen_imgs)  # (n, dim)
    ref_dino = encode_dino(ref_imgs)  # (n, dim)
    gen_dino = encode_dino(gen_imgs)  # (n, dim)

    # ---------- Similarity ----------
    # Row-wise cosine similarity between matched reference/generation pairs;
    # because the features are already L2-normalized, this is equivalent to
    # (ref * gen).sum(dim=1)
    clip_sims = torch.nn.functional.cosine_similarity(ref_clip, gen_clip)
    dino_sims = torch.nn.functional.cosine_similarity(ref_dino, gen_dino)
    clip_i = clip_sims.mean().item()
    dino_avg = dino_sims.mean().item()
    m_dino = dino_sims.prod().item()
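
    # m_dino, the product of the per-subject DINO similarities, penalizes one
    # badly preserved subject far harder than the mean does. Illustrative
    # numbers (not from the benchmark): sims = [0.9, 0.3] gives mean = 0.6
    # but product = 0.27.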
    records.append(dict(
        case_id=case_id,
        num_subject=len(ref_imgs),
        clip_i=clip_i,
        dino=dino_avg,
        m_dino=m_dino,
    ))
# ---------------- Result statistics -----------------
df = pd.DataFrame(records).sort_values("case_id")
overall = df[["clip_i", "dino", "m_dino"]].mean().to_dict()
print("\n========== Overall Average ==========")
for k, v in overall.items():
    print(f"{k:>8}: {v:.6f}")
print("=====================================\n")
# Group by number of subjects
df_by_subjects = {}
avg_by_subjects = {}
# Create a subset for each subject count (1-5)
for i in range(1, 6):
    subset = df[df["num_subject"] == i]
    if len(subset) > 0:
        # Average metrics for this group
        subset_avg = subset[["clip_i", "dino", "m_dino"]].mean().to_dict()
        avg_by_subjects[i] = subset_avg
        # Append an average row to the subset
        avg_row = {"case_id": f"average_subject_{i}", "num_subject": i}
        avg_row.update(subset_avg)
        subset_with_avg = pd.concat([subset, pd.DataFrame([avg_row])], ignore_index=True)
        df_by_subjects[i] = subset_with_avg
        # Print the group average
        print(f"\n=== Subject {i} Average (n={len(subset)}) ===")
        for k, v in subset_avg.items():
            print(f"{k:>8}: {v:.6f}")
        # Save the subset next to the main outfile
        subject_path = outfile_path.parent / f"{outfile_path.stem}_subject{i}_location_prior{outfile_path.suffix}"
        subset_with_avg.to_csv(subject_path, index=False, float_format="%.6f")
        print(f"[INFO] Subject {i} results written to {subject_path}")
# Save the overall average to its own CSV
overall_df = pd.DataFrame([overall], index=["overall"])
overall_path = outfile_path.parent / f"{outfile_path.stem}_overall_location_prior{outfile_path.suffix}"
overall_df.to_csv(overall_path, float_format="%.6f")
print(f"[INFO] Overall results written to {overall_path}")
# Write the per-case CSV
df.to_csv(args.outfile, index=False, float_format="%.6f")
print(f"[INFO] Per-case results written to {args.outfile}")
# Create a summary table with the averages for all groups
if avg_by_subjects:
    stats_rows = []
    for num_subject, avg_dict in avg_by_subjects.items():
        row = {"num_subject": num_subject}
        row.update(avg_dict)
        stats_rows.append(row)
    # Add the overall average as the final row
    overall_row = {"num_subject": "all"}
    overall_row.update(overall)
    stats_rows.append(overall_row)
    stats_df = pd.DataFrame(stats_rows)
    stats_path = outfile_path.parent / f"{outfile_path.stem}_stats_location_prior{outfile_path.suffix}"
    stats_df.to_csv(stats_path, index=False, float_format="%.6f")
    print(f"[INFO] All group statistics written to {stats_path}")