import sys
import os
import json
import pickle

import torch
from PIL import Image, ImageFile
from tqdm import tqdm

# Tolerate partially-downloaded / truncated images instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

from class_names import imagenet1k_classnames, facedataset_classnames

sys.path.append("../")
from models.clip import clip


def zeroshot_CLIP_batch(model, preprocess, device, text_inputs, class_names, image_paths, topk_indexes=5):
    """Run zero-shot CLIP classification on a batch of image files.

    Args:
        model: a loaded CLIP model (``clip.load`` result).
        preprocess: the CLIP image preprocessing transform.
        device: torch device string the model lives on.
        text_inputs: pre-tokenized class prompts (tensor on ``device``).
        class_names: mapping from integer index to class name; indexed by the
            top-k positions of the similarity row.
        image_paths: list of image file paths to classify.
        topk_indexes: number of top predictions to keep per image.

    Returns:
        A list (one entry per image, same order as ``image_paths``) of
        ``[class_name, score]`` pairs, where ``score`` is the softmax
        probability expressed as a percentage rounded to 4 decimals.
    """
    images = []
    for path in image_paths:
        # Context manager so the file handle is released even with thousands
        # of images per run (the bare Image.open(...) comprehension leaked them).
        with Image.open(path) as img:
            images.append(preprocess(img))
    batch = torch.stack(images).to(device)

    with torch.no_grad():
        image_features = model.encode_image(batch)
        # NOTE(review): text features are recomputed on every batch even though
        # text_inputs is constant across batches — consider caching upstream.
        text_features = model.encode_text(text_inputs)

    # L2-normalize so the dot product below is a cosine similarity.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

    results = []
    for i in range(similarity.size(0)):
        values, indices = similarity[i].topk(topk_indexes)
        outputs = [
            [class_names[index.item()], round(100 * value.item(), 4)]
            for value, index in zip(values, indices)
        ]
        results.append(outputs)
    return results


def process_images_in_batches(
    model,
    preprocess,
    device,
    text_inputs,
    dataset_dir,
    class_names,
    image_paths,
    batch_size,
    topk_indexes,
    class_label=None,
):
    """Classify ``image_paths`` in chunks of ``batch_size`` via zeroshot_CLIP_batch.

    Args:
        dataset_dir: prefix stripped from each path to build the result key.
        class_label: optional extra label appended to each result as
            ``[class_label, -1]`` (the -1 marks it as metadata, not a score).

    Returns:
        Dict mapping the dataset-relative image path to its top-k predictions.
    """
    results = {}
    for i in tqdm(
        range(0, len(image_paths), batch_size),
        desc=f"Processing batch of size {batch_size}",
    ):
        batch_paths = image_paths[i : i + batch_size]
        batch_results = zeroshot_CLIP_batch(
            model,
            preprocess,
            device,
            text_inputs,
            class_names,
            batch_paths,
            topk_indexes,
        )
        for path, result in zip(batch_paths, batch_results):
            if class_label:
                result.append([class_label, -1])
            results[path.replace(dataset_dir, "")] = result
    return results


def prepare_text_inputs(data_type):
    """Return the layout metadata for a supported dataset.

    Args:
        data_type: either ``"CDDB"`` or ``"TrueFake"``.

    Returns:
        Tuple ``(dataset_structure, multiclass, humans_inside, subsets, classes)``
        where ``humans_inside[i]`` selects face-centric prompts (1) vs
        ImageNet prompts (0) for ``dataset_structure[i]``. ``multiclass`` and
        ``subsets`` are ``None`` for TrueFake.

    Raises:
        ValueError: if ``data_type`` is not recognized.
    """
    if data_type == "CDDB":
        dataset_structure = [
            "whichfaceisreal",
            "stylegan",
            "crn",
            "imle",
            "cyclegan",
            "wild",
            "glow",
            "deepfake",
            "san",
            "stargan_gf",
            "biggan",
            "gaugan",
        ]
        multiclass = [0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0]
        humans_inside = [1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0]
        subsets = ["train", "val"]
        classes = ["0_real", "1_fake"]
        return dataset_structure, multiclass, humans_inside, subsets, classes
    elif data_type == "TrueFake":
        dataset_structure = [
            'Fake/FLUX.1/animals', 'Fake/FLUX.1/faces',
            'Fake/FLUX.1/general', 'Fake/FLUX.1/landscapes',
            'Fake/StableDiffusion1.5/animals', 'Fake/StableDiffusion1.5/faces',
            'Fake/StableDiffusion1.5/general', 'Fake/StableDiffusion1.5/landscapes',
            'Fake/StableDiffusion2/animals', 'Fake/StableDiffusion2/faces',
            'Fake/StableDiffusion2/general', 'Fake/StableDiffusion2/landscapes',
            'Fake/StableDiffusion3/animals', 'Fake/StableDiffusion3/faces',
            'Fake/StableDiffusion3/general', 'Fake/StableDiffusion3/landscapes',
            'Fake/StableDiffusionXL/animals', 'Fake/StableDiffusionXL/faces',
            'Fake/StableDiffusionXL/general', 'Fake/StableDiffusionXL/landscapes',
            'Fake/StyleGAN/images-psi-0.5', 'Fake/StyleGAN/images-psi-0.7',
            'Fake/StyleGAN2/conf-f-psi-0.5', 'Fake/StyleGAN2/conf-f-psi-1',
            'Fake/StyleGAN3/conf-t-psi-0.5', 'Fake/StyleGAN3/conf-t-psi-0.7',
            'Real/FFHQ', 'Real/FORLAB'
        ]
        multiclass = None
        # NOTE(review): humans_inside has 30 entries but dataset_structure has
        # 28 — the trailing entries are never read (indexing is positional),
        # but the lists look out of sync; verify FORLAB (index 27) really
        # should use the face prompts.
        humans_inside = [0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
                         1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
        classes = ["Fake", "Real"]
        subsets = None
        return dataset_structure, multiclass, humans_inside, subsets, classes
    else:
        raise ValueError(f"{data_type} not valid.")


def zeroshot_dataset_batch(dataset_dir, data_type, batch_size=32):
    """Classify every folder of ``dataset_dir`` with zero-shot CLIP and pickle the results.

    Resumes from ``./classes.pkl`` when present (folders whose first image is
    already in the checkpoint are skipped) and writes the merged results to
    ``./classes_nosocial.pkl``.

    Args:
        dataset_dir: dataset root; each ``dataset_structure`` entry is a
            subfolder of it containing images directly.
        data_type: dataset layout name understood by ``prepare_text_inputs``.
        batch_size: number of images per CLIP forward pass.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/16", device)
    dataset_structure, multiclass, humans_inside, subsets, classes = (
        prepare_text_inputs(data_type)
    )

    # Resume from a previous checkpoint if one exists; otherwise start fresh.
    # (The original code crashed with FileNotFoundError on a first run.)
    try:
        with open("./classes.pkl", "rb") as f:
            results = pickle.load(f)
    except FileNotFoundError:
        results = {}
    print(f"Loaded {len(results)} cached predictions")

    for index, folder in enumerate(tqdm(dataset_structure, desc="Processing datasets")):
        print(f"Processing {folder}")
        # Face-centric folders get face-attribute prompts, everything else
        # gets the ImageNet-1k prompt set.
        if humans_inside[index] == 0:
            class_names = imagenet1k_classnames
        else:
            class_names = facedataset_classnames
        text_inputs = torch.cat(
            [clip.tokenize(f"a photo of a {c}") for c in class_names.values()]
        ).to(device)

        folder_path = os.path.join(dataset_dir, folder)
        image_paths = [
            os.path.join(folder_path, img) for img in os.listdir(folder_path)
        ]
        if not image_paths:
            # Empty folder: nothing to classify (also guards image_paths[0]).
            continue
        if image_paths[0].replace(dataset_dir, "") in results:
            # Folder already covered by the checkpoint from a previous run.
            continue

        batch_results = process_images_in_batches(
            model,
            preprocess,
            device,
            text_inputs,
            dataset_dir,
            class_names,
            image_paths,
            batch_size,
            5,
        )
        results.update(batch_results)

    # with open("./DEBUG_classes.json", "w") as f:  # only for fast debug
    #     json.dump(results, f, indent=4)
    with open("./classes_nosocial.pkl", "wb") as f:
        pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL)


def get_JSON_dataset_batch():
    """Entry point: run the TrueFake PreSocial split with the hard-coded paths."""
    # dataroot = "/home/francesco.laiti/datasets/CDDB/"
    # datatype = "CDDB"
    dataroot = "/media/mmlab/Volume2/TrueFake/PreSocial/"
    datatype = "TrueFake"
    batch_size = 2048
    zeroshot_dataset_batch(dataroot, datatype, batch_size)


if __name__ == "__main__":
    get_JSON_dataset_batch()