import json
import os
import queue

import torch
import torch.multiprocessing as multiprocessing
import torchvision.transforms as transforms
from PIL import Image
from torch import autocast
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import InterpolationMode
from tqdm import tqdm
|
|
# Global PyTorch settings. They run at import time, so they also apply inside the
# spawned child processes, which re-import this module.
# Allow TF32 tensor-core math for float32 matmuls and cuDNN convolutions
# (faster, slightly less precise).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Keep autograd anomaly detection off: it is a debugging aid that slows execution.
torch.autograd.set_detect_anomaly(False)
# (The NVTX / autograd profiler context managers only act when entered with
# `with`, so nothing is needed here to keep them disabled.)
# Let cuDNN benchmark and cache the fastest convolution algorithm per input shape.
torch.backends.cudnn.benchmark = True
|
|
|
|
class ImageDataset(Dataset):
    """Dataset over every image (selected by extension) found recursively under a folder.

    Each item is a dict with:

    * ``"image"``: float tensor of shape (3, 448, 448), normalized with the usual
      ImageNet mean/std;
    * ``"image_name"``: the image's file name, extension included;
    * ``"image_root"``: the directory that contains the image.
    """

    def __init__(self, image_folder_path, allowed_extensions):
        # Lower-case extensions without the leading dot, e.g. {"jpg", "png"}; file
        # extensions are lower-cased before the membership test.
        self.allowed_extensions = allowed_extensions
        (
            self.all_image_paths,
            self.all_image_names,
            self.image_base_paths,
        ) = self.get_image_paths(image_folder_path)
        self.train_size = len(self.all_image_paths)
        print(f"Number of images to be tagged: {self.train_size}")

        # For very elongated images (height/width above 2 or below 0.5): scale the
        # short side to 448 and center-crop, instead of squashing the whole image.
        self.thin_transform = transforms.Compose([
            transforms.Resize(448, interpolation=InterpolationMode.BICUBIC),
            transforms.CenterCrop(448),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # For everything else: resize straight to 448x448 (aspect ratio not kept).
        self.normal_transform = transforms.Compose([
            transforms.Resize((448, 448), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def get_image_paths(self, folder_path):
        """Walk *folder_path* recursively and collect files with an allowed extension.

        Returns:
            Three parallel lists: absolute file paths, file names (extension
            included) and the directories containing the files. Traversal order is
            deterministic (directories and files are visited in sorted order).
        """
        image_paths = []
        image_file_names = []
        image_base_paths = []
        for root, dirs, files in os.walk(folder_path):
            dirs.sort()  # in-place sort makes os.walk descend in a stable order
            for file in sorted(files):
                # splitext only treats a real suffix as the extension, so a file
                # literally named "jpg" is not mistaken for an image.
                extension = os.path.splitext(file)[1][1:].lower()
                if extension in self.allowed_extensions:
                    image_paths.append(os.path.abspath(os.path.join(root, file)))
                    # Keep the whole name: cutting at the first "." turned "a.b.jpg"
                    # into "a", so distinct images could share one output file.
                    # Consumers derive the stem with os.path.splitext.
                    image_file_names.append(file)
                    image_base_paths.append(root)
        return image_paths, image_file_names, image_base_paths

    def __len__(self):
        """Return the number of images found."""
        return len(self.all_image_paths)

    def __getitem__(self, index):
        """Load, transform and return the image at *index* with its name and folder."""
        # Image.open is lazy and keeps the file open; convert() returns a detached
        # copy, so the source can be closed right away.
        with Image.open(self.all_image_paths[index]) as source:
            image = source.convert("RGB")
        ratio = image.height / image.width
        is_thin = ratio > 2.0 or ratio < 0.5
        transform = self.thin_transform if is_thin else self.normal_transform
        return {
            'image': transform(image),
            "image_name": self.all_image_names[index],
            "image_root": self.image_base_paths[index],
        }
|
|
|
|
def prepare_model(model_path: str):
    """Load a fully pickled ``torch.nn.Module`` from *model_path*, ready for inference.

    The module is converted to channels-last memory format and switched to eval
    mode. It stays on whatever device ``torch.load`` restores it to.

    SECURITY: the file is unpickled with ``weights_only=False``, which can execute
    arbitrary code. Only load model files from trusted sources.
    """
    # The file holds a whole module (the result is used via .to()/.eval(), not as a
    # state_dict). The restricted weights_only unpickler, which is the default since
    # PyTorch 2.6, rejects such files, so opt out explicitly.
    model = torch.load(model_path, weights_only=False)
    model.to(memory_format=torch.channels_last)  # nn.Module.to modifies the module in place
    return model.eval()
|
|
|
|
| def train(tagging_is_running, model, dataloader, train_data, output_queue): |
| print('Begin tagging') |
| model.eval() |
| counter = 0 |
|
|
| with torch.no_grad(): |
| for i, data in tqdm(enumerate(dataloader), total=int(len(train_data) / dataloader.batch_size)): |
| this_data = data['image'].to("cuda") |
| with autocast(device_type='cuda', dtype=torch.bfloat16): |
| outputs = model(this_data) |
|
|
| probabilities = torch.nn.functional.sigmoid(outputs) |
| output_queue.put((probabilities.to("cpu"), data["image_name"], data["image_root"])) |
|
|
| counter += 1 |
| _ = tagging_is_running.get() |
| print("Tagging finished!") |
|
|
|
|
| def tag_writer(tagging_is_running, output_queue, threshold): |
| with open("tags_8034.json", "r") as f: |
| tags = json.load(f) |
| tags.append("placeholder0") |
| tags = sorted(tags) |
| tag_count = len(tags) |
| assert tag_count == 8035, f"The length of tag list is not correct. Correct: 8035, current: {tag_count}" |
|
|
| while not (tagging_is_running.qsize() > 0 and output_queue.qsize() > 0): |
| tag_probabilities, image_names, image_roots = output_queue.get() |
| tag_probabilities = tag_probabilities.tolist() |
|
|
| for per_image_tag_probabilities, image_name, image_root in zip(tag_probabilities, image_names, image_roots, |
| strict=True): |
| this_image_tags = [] |
| this_image_tag_probabilities = [] |
| for index, per_tag_probability in enumerate(per_image_tag_probabilities): |
| if per_tag_probability > threshold: |
| tag = allowed_tags[index] |
| if "placeholder" not in tag: |
| this_image_tags.append(tag) |
| this_image_tag_probabilities.append(str(int(round(per_tag_probability, 3) * 1000))) |
| output_file = os.path.join(image_root, os.path.splitext(image_name)[0] + ".txt") |
| with open(output_file, "w", encoding="utf-8") as this_output: |
| |
| use_comma_sep = True |
| sep = " " |
| if use_comma_sep: |
| sep = ", " |
| |
| remove_underscores = True |
| if remove_underscores: |
| this_image_tags = map(lambda e: e.replace('_', ' '), this_image_tags) |
| this_output.write(sep.join(this_image_tags)) |
| |
| output_probabilities = False |
| if output_probabilities: |
| this_output.write("\n") |
| this_output.write(sep.join(this_image_tag_probabilities)) |
|
|
|
|
def main(
    image_folder_path="/path/to/img/folder",
    model_path="/path/to/your/model.pth",
    allowed_extensions=frozenset({"jpg", "jpeg", "png", "webp"}),
    batch_size=64,
    threshold=0.3,
    num_workers=12,
):
    """Tag every image under *image_folder_path*, writing one tag file per image.

    The model is loaded onto the GPU in this process. A tagger child process runs
    inference over a DataLoader, and a writer child process turns the results into
    ``.txt`` files next to the images.

    Args:
        image_folder_path: Folder searched recursively for images.
        model_path: Path of the model file passed to ``prepare_model``.
        allowed_extensions: Lower-case file extensions (without the dot) to tag.
        batch_size: Images per inference batch.
        threshold: Exclusive lower bound on a tag's probability for it to be written.
        num_workers: Number of DataLoader worker processes.

    Raises:
        RuntimeError: If CUDA is unavailable or a child process exits with a non-zero
            exit code.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available!")

    # The CUDA runtime does not support the "fork" start method, hence "spawn".
    # set_start_method raises if the method was already set, so skip it when
    # "spawn" is already active (e.g. main() is called a second time).
    if multiprocessing.get_start_method(allow_none=True) != "spawn":
        multiprocessing.set_start_method('spawn')
    output_queue = multiprocessing.Queue()
    tagging_is_running = multiprocessing.Queue(maxsize=5)
    tagging_is_running.put("Running!")

    model = prepare_model(model_path).to("cuda")
    dataset = ImageDataset(image_folder_path, allowed_extensions)
    batched_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
    )

    # NOTE(review): the CUDA model reaches the spawned tagger through
    # torch.multiprocessing's CUDA tensor sharing. This requires the parent to stay
    # alive while the child uses it (it does: see the waits below). PyTorch documents
    # that sharing as unsupported on Windows.
    process_writer = multiprocessing.Process(
        target=tag_writer, args=(tagging_is_running, output_queue, threshold)
    )
    process_writer.start()
    process_tagger = multiprocessing.Process(
        target=train, args=(tagging_is_running, model, batched_loader, dataset, output_queue)
    )
    process_tagger.start()

    # Supervise rather than join blindly. If one child dies, the other could block
    # forever: the writer waiting for results, or the tagger unable to flush
    # output_queue on exit because nobody is reading it.
    while process_tagger.is_alive():
        process_tagger.join(timeout=1.0)
        if process_writer.exitcode not in (None, 0):
            process_tagger.terminate()
    # After a clean tagging run the writer exits once the queue is drained. After a
    # failed one it may wait for results that never come, so bound that wait.
    process_writer.join(timeout=None if process_tagger.exitcode == 0 else 60)
    if process_writer.is_alive():
        process_writer.terminate()
        process_writer.join()

    exit_codes = {"tagger": process_tagger.exitcode, "writer": process_writer.exitcode}
    if any(code != 0 for code in exit_codes.values()):
        raise RuntimeError(f"Tagging failed; child process exit codes: {exit_codes}")


# The guard is required with the "spawn" start method: children re-import this
# module and must not re-run main().
if __name__ == "__main__":
    main()
|
|