| import json |
| import os |
| import sys |
| from dataclasses import dataclass, field |
| from glob import glob |
| from typing import Mapping |
|
|
| from PIL import Image |
| from tqdm import tqdm |
|
|
| from laion_face_common import generate_annotation |
|
|
|
|
@dataclass
class RunProgress:
    """Bookkeeping for one dataset-generation run.

    Every attribute is a list that starts out empty.  The instance's
    ``__dict__`` is what gets dumped to (and restored from) the JSON
    checkpoint file, so the attribute names double as the checkpoint keys.
    """

    # Input files that have not been processed yet.
    pending: list = field(default_factory=list)
    # Files that were annotated and written to the prompt file.
    success: list = field(default_factory=list)
    # Files rejected by the image size limits.
    skipped_size: list = field(default_factory=list)
    # Files rejected because their metadata carried an NSFW marker.
    skipped_nsfw: list = field(default_factory=list)
    # Files in which no face was detected at all.
    skipped_noface: list = field(default_factory=list)
    # Files whose detected faces were all removed by the face-size filter.
    skipped_smallface: list = field(default_factory=list)
|
|
|
|
def _save_checkpoint(status, status_filename):
    """Write ``status.__dict__`` as JSON to ``status_filename``.

    The data is first written to a sibling temporary file and then moved into
    place with ``os.replace`` so that an interruption in the middle of the
    write cannot leave a truncated checkpoint behind.
    """
    temp_filename = status_filename + ".tmp"
    with open(temp_filename, 'wt') as fout:
        json.dump(status.__dict__, fout)
    os.replace(temp_filename, status_filename)


def _load_image_metadata(full_filename):
    """Return the dict stored in the ``.json`` file next to the image, or ``{}``.

    The sidecar file has the same path as the image with its extension swapped
    for ``.json``.  A missing, unreadable, malformed or non-dict sidecar is
    treated as "no metadata"; read/parse errors are printed, not raised.
    """
    partial_filename, _extension = os.path.splitext(full_filename)
    candidate_json_fullpath = partial_filename + ".json"
    if not os.path.exists(candidate_json_fullpath):
        return {}
    try:
        with open(candidate_json_fullpath, 'rt', encoding="utf-8") as fin:
            image_metadata = json.load(fin)
    except (OSError, ValueError) as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        print(e)
        return {}
    # The callers use dict methods; anything else (list, null, ...) is unusable.
    if not isinstance(image_metadata, dict):
        return {}
    return image_metadata


def _process_image(
        full_filename,
        pout,
        output_directory,
        annotated_output_directory,
        min_image_size,
        max_image_size,
        min_face_size_pixels,
        prompt_mapping,
):
    """Filter and annotate a single image.

    Returns the name of the ``RunProgress`` list the file belongs in:
    ``"skipped_nsfw"``, ``"skipped_size"``, ``"skipped_noface"``,
    ``"skipped_smallface"`` or ``"success"``.  Only in the ``"success"`` case
    are images saved and a line appended to ``pout``.
    """
    fname = os.path.basename(full_filename)

    # Metadata check first: it is cheap and avoids opening rejected images.
    image_metadata = _load_image_metadata(full_filename)
    nsfw_marker = image_metadata.get("NSFW")
    if nsfw_marker is not None and nsfw_marker.lower() != "unlikely":
        return "skipped_nsfw"

    # A caption in the sidecar metadata wins over the caller-supplied mapping.
    image_prompt = image_metadata.get("caption", prompt_mapping.get(fname, ""))

    # Image.open only reads the header, so the size filter runs before the
    # pixel data is decoded by convert().
    with Image.open(full_filename) as raw_image:
        img_width, img_height = raw_image.size
        if min(img_width, img_height) < min_image_size or max(img_width, img_height) > max_image_size:
            return "skipped_size"
        img = raw_image.convert("RGB")

    empty, annotated, faces_before_filtering, faces_after_filtering = generate_annotation(
        img,
        max_faces=5,
        min_face_size_pixels=min_face_size_pixels,
        return_annotation_data=True
    )
    if faces_before_filtering == 0:
        return "skipped_noface"
    if faces_after_filtering == 0:
        return "skipped_smallface"

    output_filename = os.path.join(output_directory, fname)
    Image.fromarray(empty).save(output_filename)
    if annotated_output_directory:
        Image.fromarray(annotated).save(os.path.join(annotated_output_directory, fname))

    # One JSON object per line.  "source" is the image generated here and
    # "target" is the original input image.
    pout.write(json.dumps({
        "source": output_filename,
        "target": full_filename,
        "prompt": image_prompt,
    }) + "\n")
    pout.flush()
    return "success"


def main(
        status_filename: str,
        prompt_filename: str,
        input_glob: str,
        output_directory: str,
        annotated_output_directory: str = "",
        min_image_size: int = 384,
        max_image_size: int = 32766,
        min_face_size_pixels: int = 64,
        prompt_mapping: dict = None,
):
    """Generate face-annotation images and a prompt file for a set of images.

    If ``status_filename`` exists the run resumes from that checkpoint and the
    prompt file is appended to; otherwise the inputs are collected from
    ``input_glob``, the prompt file is created (truncated) and a fresh
    checkpoint is written.

    Args:
        status_filename: Path of the JSON checkpoint file.
        prompt_filename: Path of the JSON-lines file receiving one
            ``{"source", "target", "prompt"}`` record per accepted image.
        input_glob: Glob pattern selecting the input images (new runs only).
        output_directory: Directory the first array returned by
            ``generate_annotation`` is saved to, under the input's file name.
        annotated_output_directory: If non-empty, directory the second array
            returned by ``generate_annotation`` is saved to.
        min_image_size: Images whose shorter side is below this are skipped.
        max_image_size: Images whose longer side exceeds this are skipped.
        min_face_size_pixels: Forwarded to ``generate_annotation``.
        prompt_mapping: Optional ``{file name: prompt}`` fallback used when
            the image has no ``caption`` in its sidecar ``.json`` metadata.

    Raises:
        Whatever ``PIL.Image.open``/``save`` or ``generate_annotation`` raise
        for an input is propagated.  The checkpoint is still written first,
        with the offending file left in ``pending``.
    """
    status = RunProgress()

    resuming = os.path.exists(status_filename)
    if resuming:
        print("Continuing from checkpoint.")
        with open(status_filename, 'rt') as fin:
            status_temp = json.load(fin)
        for k in status.__dict__:
            setattr(status, k, status_temp[k])
    else:
        print("Starting run.")
        status.pending = list(glob(input_glob))

    if prompt_mapping is None:
        prompt_mapping = {}

    # Append when resuming so earlier records are kept; start clean otherwise.
    with open(prompt_filename, 'at' if resuming else 'wt') as pout:
        if not resuming:
            _save_checkpoint(status, status_filename)

        print(f"{len(status.pending)} images remaining")

        step = 0
        try:
            with tqdm(total=len(status.pending)) as pbar:
                while status.pending:
                    # Peek instead of pop: if processing is interrupted the
                    # file is still pending when the checkpoint is written.
                    full_filename = status.pending[-1]
                    outcome = _process_image(
                        full_filename,
                        pout,
                        output_directory,
                        annotated_output_directory,
                        min_image_size,
                        max_image_size,
                        min_face_size_pixels,
                        prompt_mapping,
                    )
                    status.pending.pop()
                    getattr(status, outcome).append(full_filename)
                    pbar.update(1)
                    step += 1

                    # Periodic checkpoint, taken only between images so it
                    # never describes a half-processed file.
                    if step % 100 == 0:
                        _save_checkpoint(status, status_filename)
        finally:
            # Runs on normal completion, Ctrl-C and errors alike, so the
            # checkpoint matches what was actually written to the prompt file
            # and a resumed run does not append duplicate records.
            _save_checkpoint(status, status_filename)

    print("Done!")
    print(f"{len(status.success)} images added to dataset.")
    print(f"{len(status.skipped_size)} images rejected for size.")
    print(f"{len(status.skipped_smallface)} images rejected for having faces too small.")
    print(f"{len(status.skipped_noface)} images rejected for not having faces.")
    print(f"{len(status.skipped_nsfw)} images rejected for NSFW.")
|
|
|
|
# Command-line entry point:
#   python <script> prompt.jsonl "<input glob>" <output dir> [<annotated dir>]
if __name__ == "__main__":
    # Three positional arguments are required (argv[1]..argv[3]), so argv must
    # hold at least four entries; otherwise, or when -h is given, show usage.
    if len(sys.argv) >= 4 and "-h" not in sys.argv:
        prompt_jsonl = sys.argv[1]
        in_glob = sys.argv[2]
        output_dir = sys.argv[3]
        annotation_dir = ""
        if len(sys.argv) > 4:
            annotation_dir = sys.argv[4]
        main("generate_face_poses_checkpoint.json", prompt_jsonl, in_glob, output_dir, annotation_dir)
    else:
        print(f"""Usage:
python {sys.argv[0]} prompt.jsonl target/*.jpg source/ [annotated/]
source and target are slightly confusing in this context. We are writing the image names to prompt.jsonl, so
the naming system has to be consistent with what ControlNet expects. In ControlNet, the source is the input and
target is the output. We are generating source images from targets in this application, so the second argument
should be a folder full of images. The third argument should be 'source', where the images should be places.
Optionally, an 'annotated' directory can be provided. Augmented images will be placed here.

A checkpoint file named 'generate_face_poses_checkpoint.json' will be created in the place where the script is
run. If a run is cancelled, it can be resumed from this checkpoint.

If invoking the script from bash, do not forget to enclose globs with quotes. Example usage:
`python ./tool_generate_face_poses.py ./face_prompt.jsonl "/home/josephcatrambone/training_data/data-mscoco/images/train2017/*" /home/josephcatrambone/training_data/data-mscoco/images/source_2017/`
""")
|
|