Spaces:
Sleeping
Sleeping
| import sys | |
| import os | |
| import torch | |
| from PIL import Image, ImageFile | |
| ImageFile.LOAD_TRUNCATED_IMAGES = True | |
| from class_names import imagenet1k_classnames, facedataset_classnames | |
| import json | |
| from tqdm import tqdm | |
| import pickle | |
| sys.path.append("../") | |
| from models.clip import clip | |
def zeroshot_CLIP_batch(model, preprocess, device, text_inputs, class_names, image_paths, topk_indexes=5):
    """Run zero-shot CLIP classification on one batch of image files.

    Args:
        model: loaded CLIP model exposing ``encode_image`` / ``encode_text``.
        preprocess: CLIP preprocessing transform (PIL image -> tensor).
        device: torch device string the batch tensor is moved to.
        text_inputs: tokenized class prompts, already on ``device``.
        class_names: mapping from class index to human-readable name.
        image_paths: list of image file paths forming one batch.
        topk_indexes: number of top predictions to keep per image.

    Returns:
        A list with one entry per image; each entry is a list of
        ``[class_name, confidence_percent]`` pairs for the top-k
        predictions, confidence rounded to 4 decimal places.
    """
    tensors = []
    for path in image_paths:
        # Open with a context manager so the file handle is released
        # immediately; PIL's lazy loading otherwise keeps descriptors
        # alive and large datasets can exhaust the OS fd limit.
        with Image.open(path) as img:
            tensors.append(preprocess(img))
    batch = torch.stack(tensors).to(device)
    with torch.no_grad():
        image_features = model.encode_image(batch)
        text_features = model.encode_text(text_inputs)
    # Cosine similarity: L2-normalize both modalities, then scaled dot product.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    results = []
    for i in range(similarity.size(0)):
        values, indices = similarity[i].topk(topk_indexes)
        outputs = [
            [class_names[index.item()], round(100 * value.item(), 4)]
            for value, index in zip(values, indices)
        ]
        results.append(outputs)
    return results
def process_images_in_batches(
    model,
    preprocess,
    device,
    text_inputs,
    dataset_dir,
    class_names,
    image_paths,
    batch_size,
    topk_indexes,
    class_label=None,
):
    """Classify ``image_paths`` in fixed-size batches via ``zeroshot_CLIP_batch``.

    Args:
        model, preprocess, device, text_inputs, class_names, topk_indexes:
            forwarded unchanged to ``zeroshot_CLIP_batch``.
        dataset_dir: dataset root; stripped from each path to build the
            result key (relative-style path).
        image_paths: all image paths to process.
        batch_size: number of images per forward pass.
        class_label: optional generator/source label appended to every
            result as ``[class_label, -1]``.

    Returns:
        Dict mapping (path minus ``dataset_dir``) -> top-k prediction list.
    """
    results = {}
    for i in tqdm(
        range(0, len(image_paths), batch_size),
        desc=f"Processing batch of size {batch_size}",
    ):
        batch_paths = image_paths[i : i + batch_size]
        batch_results = zeroshot_CLIP_batch(
            model,
            preprocess,
            device,
            text_inputs,
            class_names,
            batch_paths,
            topk_indexes,
        )
        for path, result in zip(batch_paths, batch_results):
            # Explicit None check: a falsy-but-valid label such as "0" or ""
            # must still be recorded (plain `if class_label:` dropped those).
            if class_label is not None:
                result.append([class_label, -1])
            results[path.replace(dataset_dir, "")] = result
    return results
def prepare_text_inputs(data_type):
    """Return the static layout description for a supported dataset.

    Args:
        data_type: either ``"CDDB"`` or ``"TrueFake"``.

    Returns:
        Tuple ``(dataset_structure, multiclass, humans_inside, subsets, classes)``
        where entries not applicable to the dataset are ``None``.

    Raises:
        ValueError: if ``data_type`` is not a supported dataset name.
    """
    # Proper if/elif/else chain: the original dangling `else` was bound only
    # to the second `if`, which read misleadingly (behavior unchanged).
    if data_type == "CDDB":
        dataset_structure = [
            "whichfaceisreal",
            "stylegan",
            "crn",
            "imle",
            "cyclegan",
            "wild",
            "glow",
            "deepfake",
            "san",
            "stargan_gf",
            "biggan",
            "gaugan",
        ]
        # Parallel per-folder flags (1 = true), aligned by index.
        multiclass = [0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0]
        humans_inside = [1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0]
        subsets = ["train", "val"]
        classes = ["0_real", "1_fake"]
        return dataset_structure, multiclass, humans_inside, subsets, classes
    elif data_type == "TrueFake":
        dataset_structure = [
            'Fake/FLUX.1/animals',
            'Fake/FLUX.1/faces',
            'Fake/FLUX.1/general',
            'Fake/FLUX.1/landscapes',
            'Fake/StableDiffusion1.5/animals',
            'Fake/StableDiffusion1.5/faces',
            'Fake/StableDiffusion1.5/general',
            'Fake/StableDiffusion1.5/landscapes',
            'Fake/StableDiffusion2/animals',
            'Fake/StableDiffusion2/faces',
            'Fake/StableDiffusion2/general',
            'Fake/StableDiffusion2/landscapes',
            'Fake/StableDiffusion3/animals',
            'Fake/StableDiffusion3/faces',
            'Fake/StableDiffusion3/general',
            'Fake/StableDiffusion3/landscapes',
            'Fake/StableDiffusionXL/animals',
            'Fake/StableDiffusionXL/faces',
            'Fake/StableDiffusionXL/general',
            'Fake/StableDiffusionXL/landscapes',
            'Fake/StyleGAN/images-psi-0.5',
            'Fake/StyleGAN/images-psi-0.7',
            'Fake/StyleGAN2/conf-f-psi-0.5',
            'Fake/StyleGAN2/conf-f-psi-1',
            'Fake/StyleGAN3/conf-t-psi-0.5',
            'Fake/StyleGAN3/conf-t-psi-0.7',
            'Real/FFHQ',
            'Real/FORLAB'
        ]
        multiclass = None
        # NOTE(review): humans_inside has 30 entries but dataset_structure has
        # 28; the trailing two zeros are never read (consumers index by folder
        # position). Verify the intended alignment for FFHQ/FORLAB.
        humans_inside = [0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
        classes = ["Fake", "Real"]
        subsets = None
        return dataset_structure, multiclass, humans_inside, subsets, classes
    else:
        raise ValueError(f"{data_type} not valid.")
| # def zeroshot_dataset_batch(dataset_dir, data_type, batch_size=32): | |
| # device = "cuda" if torch.cuda.is_available() else "cpu" | |
| # model, preprocess = clip.load("ViT-B/16", device) | |
| # dataset_structure, multiclass, humans_inside, subsets, classes = ( | |
| # prepare_text_inputs(data_type) | |
| # ) | |
| # results = {} | |
| # for index, folder in enumerate(tqdm(dataset_structure, desc="Processing datasets")): | |
| # if humans_inside[index] == 0: | |
| # text_inputs = torch.cat( | |
| # [ | |
| # clip.tokenize(f"a photo of a {c}") | |
| # for c in imagenet1k_classnames.values() | |
| # ] | |
| # ).to(device) | |
| # class_names = imagenet1k_classnames | |
| # else: | |
| # text_inputs = torch.cat( | |
| # [ | |
| # clip.tokenize(f"a photo of a {c}") | |
| # for c in facedataset_classnames.values() | |
| # ] | |
| # ).to(device) | |
| # class_names = facedataset_classnames | |
| # for subset in subsets: | |
| # subset_path = os.path.join(dataset_dir, folder, subset) | |
| # if multiclass[index] == 1: | |
| # class_labels = os.listdir(subset_path) | |
| # for class_label in class_labels: | |
| # class_path = os.path.join(subset_path, class_label) | |
| # for binary_label in classes: | |
| # image_paths = [ | |
| # os.path.join(class_path, binary_label, img) | |
| # for img in os.listdir( | |
| # os.path.join(class_path, binary_label) | |
| # ) | |
| # ] | |
| # batch_results = process_images_in_batches( | |
| # model, | |
| # preprocess, | |
| # device, | |
| # text_inputs, | |
| # dataset_dir, | |
| # class_names, | |
| # image_paths, | |
| # batch_size, | |
| # 5, | |
| # class_label, | |
| # ) | |
| # results.update(batch_results) | |
| # else: | |
| # for binary_label in classes: | |
| # image_paths = [ | |
| # os.path.join(subset_path, binary_label, img) | |
| # for img in os.listdir(os.path.join(subset_path, binary_label)) | |
| # ] | |
| # batch_results = process_images_in_batches( | |
| # model, | |
| # preprocess, | |
| # device, | |
| # text_inputs, | |
| # dataset_dir, | |
| # class_names, | |
| # image_paths, | |
| # batch_size, | |
| # 5, | |
| # ) | |
| # results.update(batch_results) | |
| # with open("./DEBUG_classes.json", "w") as f: # only for fast debug | |
| # json.dump(results, f, indent=4) | |
| # with open("./classes.pkl", "wb") as f: | |
| # pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL) | |
def zeroshot_dataset_batch(dataset_dir, data_type, batch_size=32):
    """Zero-shot label every image under ``dataset_dir`` with CLIP ViT-B/16.

    Resumes from "./classes.pkl" when present (already-processed folders are
    skipped) and writes the merged results to "./classes_nosocial.pkl".

    Args:
        dataset_dir: dataset root directory.
        data_type: dataset name understood by ``prepare_text_inputs``.
        batch_size: images per CLIP forward pass.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/16", device)
    dataset_structure, multiclass, humans_inside, subsets, classes = (
        prepare_text_inputs(data_type)
    )
    # Resume support: reload earlier results if the checkpoint file exists.
    # The original crashed with FileNotFoundError on a fresh run.
    # NOTE: pickle.load is acceptable only because this file is produced
    # locally by this script; never unpickle untrusted data.
    try:
        with open("./classes.pkl", "rb") as f:
            results = pickle.load(f)
    except FileNotFoundError:
        results = {}
    if len(results) == 0:
        results = {}
    print(len(results))
    for index, folder in enumerate(tqdm(dataset_structure, desc="Processing datasets")):
        print(f"Processing {folder}")
        # Pick the prompt vocabulary: generic ImageNet classes for folders
        # without people, face-specific classes otherwise. The tokenization
        # logic is identical for both, so build it once.
        if humans_inside[index] == 0:
            class_names = imagenet1k_classnames
        else:
            class_names = facedataset_classnames
        text_inputs = torch.cat(
            [clip.tokenize(f"a photo of a {c}") for c in class_names.values()]
        ).to(device)
        folder_path = os.path.join(dataset_dir, folder)
        image_paths = [os.path.join(folder_path, img) for img in os.listdir(folder_path)]
        if not image_paths:
            # Empty folder: nothing to do (original raised IndexError below).
            continue
        if image_paths[0].replace(dataset_dir, "") in results:
            continue  # folder already processed in a previous run
        batch_results = process_images_in_batches(
            model,
            preprocess,
            device,
            text_inputs,
            dataset_dir,
            class_names,
            image_paths,
            batch_size,
            5,
        )
        results.update(batch_results)
    # with open("./DEBUG_classes.json", "w") as f:  # only for fast debug
    #     json.dump(results, f, indent=4)
    with open("./classes_nosocial.pkl", "wb") as f:
        pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL)
def get_JSON_dataset_batch():
    """Entry point: zero-shot label the TrueFake pre-social dataset."""
    # Earlier CDDB configuration, kept for reference:
    #   dataroot = "/home/francesco.laiti/datasets/CDDB/"
    #   datatype = "CDDB"
    dataroot = "/media/mmlab/Volume2/TrueFake/PreSocial/"
    datatype = "TrueFake"
    zeroshot_dataset_batch(dataroot, datatype, batch_size=2048)


if __name__ == "__main__":
    get_JSON_dataset_batch()