import os, sys, json, math, argparse, glob
from pathlib import Path
from typing import List

import torch
from PIL import Image
import pandas as pd
from tqdm import tqdm
from transformers import (
    AutoProcessor, CLIPModel,
    AutoImageProcessor, AutoModel
)
from datasets import load_dataset
def scale_bbox(bbox, ori_size, target_size):
    """Rescale an (x_min, y_min, x_max, y_max) box from ori_size to target_size,
    clamping the result to the target image bounds."""
    x_min, y_min, x_max, y_max = bbox
    ori_width, ori_height = ori_size
    target_width, target_height = target_size
    width_ratio = target_width / ori_width
    height_ratio = target_height / ori_height
    scaled_x_min = max(0, int(x_min * width_ratio))
    scaled_y_min = max(0, int(y_min * height_ratio))
    scaled_x_max = min(target_width, int(x_max * width_ratio))
    scaled_y_max = min(target_height, int(y_max * height_ratio))
    return [scaled_x_min, scaled_y_min, scaled_x_max, scaled_y_max]
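# Worked example (illustrative numbers, not taken from the benchmark):
#   scale_bbox([100, 50, 300, 250], ori_size=(1000, 500), target_size=(512, 512))
#   width_ratio = 512/1000 = 0.512, height_ratio = 512/500 = 1.024
#   -> [51, 51, 153, 256]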
def encode_clip(imgs: List[Image.Image]) -> torch.Tensor:
    """Return L2-normalized CLIP image features, one row per image."""
    features_list = []
    with torch.no_grad():  # evaluation only; no autograd graph needed
        for img in imgs:
            inputs = clip_processor(images=img, return_tensors="pt").to(device)
            image_features = clip_model.get_image_features(**inputs)
            normalized_features = image_features / image_features.norm(dim=1, keepdim=True)
            features_list.append(normalized_features.squeeze().cpu())
    return torch.stack(features_list)
def encode_dino(imgs: List[Image.Image]) -> torch.Tensor:
    """Return L2-normalized DINOv2 features (mean-pooled patch tokens), one row per image."""
    features_list = []
    with torch.no_grad():  # evaluation only; no autograd graph needed
        for img in imgs:
            inputs = dino_processor(images=img, return_tensors="pt").to(device)
            outputs = dino_model(**inputs)
            image_features = outputs.last_hidden_state.mean(dim=1)
            normalized_features = image_features / image_features.norm(dim=1, keepdim=True)
            features_list.append(normalized_features.squeeze().cpu())
    return torch.stack(features_list)
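# A batched variant is possible since both processors accept lists of images.
# The sketch below is an assumption about throughput, not part of the original
# evaluation, so it is left commented out:
#
#   def encode_clip_batched(imgs: List[Image.Image]) -> torch.Tensor:
#       inputs = clip_processor(images=imgs, return_tensors="pt").to(device)
#       with torch.no_grad():
#           feats = clip_model.get_image_features(**inputs)
#       return (feats / feats.norm(dim=1, keepdim=True)).cpu()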
def cosine(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Assumes a and b are already L2-normalized, so the dot product equals cosine similarity.
    return (a @ b.T).squeeze()
# ------------- Command-line arguments -----------------
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--benchmark_repo", type=str, default="HuiZhang0812/CreatiDesign_benchmark",
                    help="HuggingFace dataset repo containing the benchmark cases")
parser.add_argument("--gen_root", type=str, default="outputs/CreatiDesign_benchmark",
                    help="Root directory for generated images (expects images/<case_id>.jpg underneath)")
parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"])
parser.add_argument("--outfile", type=str,
                    help="Path for the result CSV; defaults to scores.csv under gen_root")
args = parser.parse_args()
| print("handling:", args.gen_root) | |
| if args.outfile is None: | |
| args.outfile = os.path.join(args.gen_root,"scores.csv") | |
| # Convert outfile to Path object | |
| outfile_path = Path(args.outfile) | |
| device = torch.device(args.device if torch.cuda.is_available() else "cpu") | |
| print(f"[INFO] Using device: {device}") | |
# ------------- Loading models -------------------
print("[INFO] loading CLIP...")
clip_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_model.eval()
print("[INFO] loading DINOv2...")
dino_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
dino_model = AutoModel.from_pretrained("facebook/dinov2-base").to(device)
dino_model.eval()
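# The loop below assumes each benchmark record carries a JSON "metadata" string
# (with img_info and subject_annotations) plus a "condition_white_variants"
# list of reference subject images; these field names are taken from the
# accesses in this script, not verified against the dataset card.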
benchmark = load_dataset(args.benchmark_repo, split="test")

DEBUG = True
if DEBUG:
    subject_save_root = os.path.join(args.gen_root, "subject-eval-visual")
    os.makedirs(subject_save_root, exist_ok=True)
records = []
for case in tqdm(benchmark):
    json_data = json.loads(case["metadata"])
    case_info = json_data["img_info"]
    case_id = case_info["img_id"]

    # ---------- Read reference subjects ----------
    ref_imgs = case["condition_white_variants"]
    if len(ref_imgs) == 0:
        print(f"[WARN] {case_id} has no reference subject, skipping")
        continue

    # ---------- Read generated image ----------
    gen_path = os.path.join(args.gen_root, "images", f"{case_id}.jpg")
    gen_img = Image.open(gen_path).convert("RGB")
    # Get width and height of the generated image
    gen_width, gen_height = gen_img.size

    # Sort annotations by bbox index once, so crops align with the references
    subject_annotations = sorted(json_data["subject_annotations"], key=lambda x: x["bbox_idx"])
    ref_bbox_id = [item["bbox_idx"] for item in subject_annotations]
    ref_bbox = [item["bbox"] for item in subject_annotations]
    ori_width, ori_height = case_info["img_width"], case_info["img_height"]
    # Extract the corresponding crops from the generated image
    gen_imgs = []
    for bbox in ref_bbox:
        # Scale the bounding box from the original to the generated resolution
        scaled_bbox = scale_bbox(
            bbox,
            (ori_width, ori_height),
            (gen_width, gen_height)
        )
        # Crop the image region
        x_min, y_min, x_max, y_max = scaled_bbox
        cropped_img = gen_img.crop((x_min, y_min, x_max, y_max))
        gen_imgs.append(cropped_img)

    if DEBUG:
        folder_root = os.path.join(subject_save_root, case_id)
        os.makedirs(folder_root, exist_ok=True)
        # Save cropped images for visual inspection
        for img, img_id in zip(gen_imgs, ref_bbox_id):
            img.save(os.path.join(folder_root, f"{img_id}.png"))
    # ---------- Features ----------
    ref_clip = encode_clip(ref_imgs)   # (n, dim)
    gen_clip = encode_clip(gen_imgs)   # (n, dim)
    ref_dino = encode_dino(ref_imgs)   # (n, dim)
    gen_dino = encode_dino(gen_imgs)   # (n, dim)

    # ---------- Similarity ----------
    clip_sims = torch.nn.functional.cosine_similarity(ref_clip, gen_clip)
    dino_sims = torch.nn.functional.cosine_similarity(ref_dino, gen_dino)
    clip_i = clip_sims.mean().item()
    dino_avg = dino_sims.mean().item()
    m_dino = dino_sims.prod().item()

    records.append(dict(
        case_id=case_id,
        num_subject=len(ref_imgs),
        clip_i=clip_i,
        dino=dino_avg,
        m_dino=m_dino
    ))
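# Note on metrics: clip_i and dino average the per-subject similarities, while
# m_dino multiplies them, so a single poorly preserved subject drags the whole
# case score down (e.g. similarities [0.9, 0.9, 0.3] average to 0.7 but
# multiply to ~0.24). This reading is inferred from the code above, not from
# an external metric definition.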
# ---------------- Result statistics -----------------
df = pd.DataFrame(records).sort_values("case_id")
overall = df[["clip_i", "dino", "m_dino"]].mean().to_dict()
print("\n========== Overall Average ==========")
for k, v in overall.items():
    print(f"{k:>8}: {v:.6f}")
print("=====================================\n")
# Group by number of subjects
df_by_subjects = {}
avg_by_subjects = {}
# Create a subset for each subject count (1-5)
for i in range(1, 6):
    # Filter records with subject count == i
    subset = df[df["num_subject"] == i]
    if len(subset) > 0:
        # Average metrics for this group
        subset_avg = subset[["clip_i", "dino", "m_dino"]].mean().to_dict()
        avg_by_subjects[i] = subset_avg
        # Append an average row to the subset
        avg_row = {"case_id": f"average_subject_{i}", "num_subject": i}
        avg_row.update(subset_avg)
        subset_with_avg = pd.concat([subset, pd.DataFrame([avg_row])], ignore_index=True)
        df_by_subjects[i] = subset_with_avg
        # Print the average for this group
        print(f"\n=== Subject {i} Average (n={len(subset)}) ===")
        for k, v in subset_avg.items():
            print(f"{k:>8}: {v:.6f}")
        # Save the subset
        subject_path = outfile_path.parent / f"{outfile_path.stem}_subject{i}_location_prior{outfile_path.suffix}"
        subset_with_avg.to_csv(subject_path, index=False, float_format="%.6f")
        print(f"[INFO] Subject {i} results written to {subject_path}")
# Save the overall average to CSV
overall_df = pd.DataFrame([overall], index=["overall"])
overall_path = outfile_path.parent / f"{outfile_path.stem}_overall_location_prior{outfile_path.suffix}"
overall_df.to_csv(overall_path, float_format="%.6f")
print(f"[INFO] Overall results written to {overall_path}")

# Write the per-case CSV
df.to_csv(args.outfile, index=False, float_format="%.6f")
print(f"[INFO] Written to {args.outfile}")
# Create a summary table with averages for all groups
if avg_by_subjects:
    # Merge the per-group averages into one table
    stats_rows = []
    for num_subject, avg_dict in avg_by_subjects.items():
        row = {"num_subject": num_subject}
        row.update(avg_dict)
        stats_rows.append(row)
    # Add the overall average
    overall_row = {"num_subject": "all"}
    overall_row.update(overall)
    stats_rows.append(overall_row)
    # Build the summary statistics table
    stats_df = pd.DataFrame(stats_rows)
    stats_path = outfile_path.parent / f"{outfile_path.stem}_stats_location_prior{outfile_path.suffix}"
    stats_df.to_csv(stats_path, index=False, float_format="%.6f")
    print(f"[INFO] All group statistics written to {stats_path}")