| |
| import numpy as np |
| import cv2 |
| import PIL.Image |
| from scipy.interpolate import griddata |
| import h5py |
| from utils import azi_diff |
| from tqdm import tqdm |
| import os |
| import random |
| import pickle |
| import logging |
| import joblib |
|
|
| def get_image_files(directory): |
| image_extensions = {'.jpg', '.jpeg', '.png'} |
| image_files = [] |
| for root, _, files in os.walk(directory): |
| for file in files: |
| if os.path.splitext(file)[1].lower() in image_extensions: |
| image_files.append(os.path.join(root, file)) |
| return image_files |
|
|
| def load_image_files(class1_dirs, class2_dirs): |
| class1_files = [] |
| for directory in tqdm(class1_dirs): |
| class1_files.extend(get_image_files(directory)) |
|
|
| class2_files = [] |
| for directory in tqdm(class2_dirs): |
| class2_files.extend(get_image_files(directory)) |
|
|
| |
| min_length = min(len(class1_files), len(class2_files)) |
|
|
| random.shuffle(class1_files) |
| random.shuffle(class2_files) |
|
|
| class1_files = class1_files[:min_length] |
| class2_files = class2_files[:min_length] |
|
|
| print(f"Number of files: Real = {len(class1_files)}, Fake = {len(class2_files)}") |
|
|
| return class1_files, class2_files |
|
|
|
|
| def process_and_save_h5(file_label_pairs, patch_num, N, save_interval, joblib_batch_size, output_dir, start_by=0): |
| def process_file(file_label): |
| path, label = file_label |
| try: |
| result = azi_diff(path, patch_num, N) |
| return result, label |
| except Exception as e: |
| logging.error(f"Error processing file {path}: {str(e)}") |
| return None, None |
|
|
| num_files = len(file_label_pairs) |
| num_saves = (num_files - start_by + save_interval - 1) // save_interval |
|
|
| if not os.path.exists(output_dir): |
| os.makedirs(output_dir) |
|
|
| with tqdm(total=num_files - start_by, desc="Processing files", unit="image") as pbar: |
| for save_index in range(num_saves): |
| save_start = start_by + save_index * save_interval |
| save_end = min(save_start + save_interval, num_files) |
| batch_pairs = file_label_pairs[save_start:save_end] |
|
|
| all_rich = [] |
| all_poor = [] |
| all_labels = [] |
| for batch_start in range(0, len(batch_pairs), joblib_batch_size): |
| batch_end = min(batch_start + joblib_batch_size, len(batch_pairs)) |
| small_batch_pairs = batch_pairs[batch_start:batch_end] |
|
|
| processed_data = joblib.Parallel(n_jobs=-1)( |
| joblib.delayed(process_file)(file_label) for file_label in small_batch_pairs |
| ) |
| for data, label in processed_data: |
| if data is not None: |
| all_rich.append(data['total_emb'][0]) |
| all_poor.append(data['total_emb'][1]) |
| all_labels.append(label) |
|
|
| pbar.update(len(small_batch_pairs)) |
|
|
| next_save_start = save_end |
| output_filename = f"{output_dir}/processed_data_{next_save_start}.h5" |
| logging.info(f"Saving {output_filename}") |
|
|
| with h5py.File(output_filename, 'w') as h5file: |
| h5file.create_dataset('rich', data=np.array(all_rich)) |
| h5file.create_dataset('rich', data=np.array(all_poor)) |
| h5file.create_dataset('labels', data=np.array(all_labels)) |
| |
| logging.info(f"Successfully saved {output_filename}") |
|
|
| del all_rich |
| del all_poor |
| del all_labels |
|
|
|
|
|
|
|
|
| load=False |
| class1_dirs = [ |
| "/home/archive/real/", |
| "/home/13k_real/", |
| "/home/AI_detection_dataset/Real_AI_SD_LD_Dataset/train/real/", |
| "/home/AI_detection_dataset/Real_AI_SD_LD_Dataset/test/real/" |
| ] |
|
|
| class2_dirs = [ |
| "/home/archive/fakeV2/fake-v2/", |
| "/home/dalle3/", |
| "/home/AI_detection_dataset/Real_AI_SD_LD_Dataset/train/fake/", |
| "/home/AI_detection_dataset/Real_AI_SD_LD_Dataset/test/fake/" |
| ] |
| output_dir = "/content/drive/MyDrive/h5saves" |
| file_paths_pickle_save_dir='/content/drive/MyDrive/aigc_file_paths.pkl' |
| patch_num = 128 |
| N = 256 |
| save_interval = 2000 |
| joblib_batch_size = 400 |
| start_by = 0 |
|
|
| if load==True: |
| with open(file_paths_pickle_save_dir, 'rb') as file: |
| file_label_pairs=pickle.load(file) |
| print(len(file_label_pairs)) |
| else: |
| class1_files, class2_files = load_image_files(class1_dirs, class2_dirs) |
| file_label_pairs = list(zip(class1_files, [0] * len(class1_files))) + list(zip(class2_files, [1] * len(class2_files))) |
| random.shuffle(file_label_pairs) |
| with open(file_paths_pickle_save_dir, 'wb') as file: |
| pickle.dump(file_label_pairs, file) |
| print(len(file_label_pairs)) |
|
|
| logging.basicConfig(level=logging.INFO) |
| process_and_save_h5(file_label_pairs, patch_num, N, save_interval, joblib_batch_size, output_dir, start_by) |
|
|