Upload DiffSketcher/run_painterly_render.py with huggingface_hub
Browse files
DiffSketcher/run_painterly_render.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Author: ximing
|
| 3 |
+
# Description: the main func of this project.
|
| 4 |
+
# Copyright (c) 2023, XiMing Xing.
|
| 5 |
+
# License: MIT License
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import argparse
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
import random
|
| 12 |
+
from typing import Any, List
|
| 13 |
+
from functools import partial
|
| 14 |
+
|
| 15 |
+
from accelerate.utils import set_seed
|
| 16 |
+
import omegaconf
|
| 17 |
+
|
| 18 |
+
sys.path.append(os.path.split(os.path.abspath(os.path.dirname(__file__)))[0])
|
| 19 |
+
|
| 20 |
+
from libs.engine import merge_and_update_config
|
| 21 |
+
from libs.utils.argparse import accelerate_parser, base_data_parser
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def render_batch_wrap(args: omegaconf.DictConfig,
|
| 25 |
+
seed_range: List,
|
| 26 |
+
pipeline: Any,
|
| 27 |
+
**pipe_args):
|
| 28 |
+
start_time = datetime.now()
|
| 29 |
+
for idx, seed in enumerate(seed_range):
|
| 30 |
+
args.seed = seed # update seed
|
| 31 |
+
print(f"\n-> [{idx}/{len(seed_range)}], "
|
| 32 |
+
f"current seed: {seed}, "
|
| 33 |
+
f"current time: {datetime.now() - start_time}\n")
|
| 34 |
+
pipe = pipeline(args)
|
| 35 |
+
pipe.painterly_rendering(**pipe_args)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def main(args, seed_range):
|
| 39 |
+
args.batch_size = 1 # rendering one SVG at a time
|
| 40 |
+
|
| 41 |
+
args.width = float(args.width)
|
| 42 |
+
|
| 43 |
+
render_batch_fn = partial(render_batch_wrap, args=args, seed_range=seed_range)
|
| 44 |
+
|
| 45 |
+
if args.task == "diffsketcher": # text2sketch
|
| 46 |
+
from pipelines.painter.diffsketcher_pipeline import DiffSketcherPipeline
|
| 47 |
+
|
| 48 |
+
if not args.render_batch:
|
| 49 |
+
pipe = DiffSketcherPipeline(args)
|
| 50 |
+
pipe.painterly_rendering(args.prompt)
|
| 51 |
+
else: # generate many SVG at once
|
| 52 |
+
render_batch_fn(pipeline=DiffSketcherPipeline, prompt=args.prompt)
|
| 53 |
+
|
| 54 |
+
elif args.task == "style-diffsketcher": # text2sketch + style transfer
|
| 55 |
+
from pipelines.painter.diffsketcher_stylized_pipeline import StylizedDiffSketcherPipeline
|
| 56 |
+
|
| 57 |
+
if not args.render_batch:
|
| 58 |
+
pipe = StylizedDiffSketcherPipeline(args)
|
| 59 |
+
pipe.painterly_rendering(args.prompt, args.style_file)
|
| 60 |
+
else: # generate many SVG at once
|
| 61 |
+
render_batch_fn(pipeline=StylizedDiffSketcherPipeline, prompt=args.prompt, style_fpath=args.style_file)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
if __name__ == '__main__':
|
| 65 |
+
parser = argparse.ArgumentParser(
|
| 66 |
+
description="vary style and content painterly rendering",
|
| 67 |
+
parents=[accelerate_parser(), base_data_parser()]
|
| 68 |
+
)
|
| 69 |
+
# flag
|
| 70 |
+
parser.add_argument("-tk", "--task",
|
| 71 |
+
default="diffsketcher", type=str,
|
| 72 |
+
choices=['diffsketcher', 'style-diffsketcher'],
|
| 73 |
+
help="choose a method.")
|
| 74 |
+
# config
|
| 75 |
+
parser.add_argument("-c", "--config",
|
| 76 |
+
required=True, type=str,
|
| 77 |
+
default="",
|
| 78 |
+
help="YAML/YML file for configuration.")
|
| 79 |
+
parser.add_argument("-style", "--style_file",
|
| 80 |
+
default="", type=str,
|
| 81 |
+
help="the path of style img place.")
|
| 82 |
+
# prompt
|
| 83 |
+
parser.add_argument("-pt", "--prompt", default="A horse is drinking water by the lake", type=str)
|
| 84 |
+
parser.add_argument("-npt", "--negative_prompt", default="", type=str)
|
| 85 |
+
# DiffSVG
|
| 86 |
+
parser.add_argument("--print_timing", "-timing", action="store_true",
|
| 87 |
+
help="set print svg rendering timing.")
|
| 88 |
+
# diffuser
|
| 89 |
+
parser.add_argument("--download", action="store_true",
|
| 90 |
+
help="download models from huggingface automatically.")
|
| 91 |
+
parser.add_argument("--force_download", "-download", action="store_true",
|
| 92 |
+
help="force the models to be downloaded from huggingface.")
|
| 93 |
+
parser.add_argument("--resume_download", "-dpm_resume", action="store_true",
|
| 94 |
+
help="download the models again from the breakpoint.")
|
| 95 |
+
# rendering quantity
|
| 96 |
+
# like: python main.py -rdbz -srange 100 200
|
| 97 |
+
parser.add_argument("--render_batch", "-rdbz", action="store_true")
|
| 98 |
+
parser.add_argument("-srange", "--seed_range",
|
| 99 |
+
required=False, nargs='+',
|
| 100 |
+
help="Sampling quantity.")
|
| 101 |
+
# visual rendering process
|
| 102 |
+
parser.add_argument("-mv", "--make_video", action="store_true",
|
| 103 |
+
help="make a video of the rendering process.")
|
| 104 |
+
parser.add_argument("-frame_freq", "--video_frame_freq",
|
| 105 |
+
default=1, type=int,
|
| 106 |
+
help="video frame control.")
|
| 107 |
+
parser.add_argument("-framerate", "--video_frame_rate",
|
| 108 |
+
default=36, type=int,
|
| 109 |
+
help="by adjusting the frame rate, you can control the playback speed of the output video.")
|
| 110 |
+
|
| 111 |
+
args = parser.parse_args()
|
| 112 |
+
|
| 113 |
+
# set the random seed range
|
| 114 |
+
seed_range = None
|
| 115 |
+
if args.render_batch:
|
| 116 |
+
# random sampling without specifying a range
|
| 117 |
+
start_, end_ = 1, 1000000
|
| 118 |
+
if args.seed_range is not None: # specify range sequential sampling
|
| 119 |
+
seed_range_ = list(args.seed_range)
|
| 120 |
+
assert len(seed_range_) == 2 and int(seed_range_[1]) > int(seed_range_[0])
|
| 121 |
+
start_, end_ = int(seed_range_[0]), int(seed_range_[1])
|
| 122 |
+
seed_range = [i for i in range(start_, end_)]
|
| 123 |
+
else:
|
| 124 |
+
# a list of lengths 1000 sampled from the range start_ to end_ (e.g.: [1, 1000000])
|
| 125 |
+
numbers = list(range(start_, end_))
|
| 126 |
+
seed_range = random.sample(numbers, k=1000)
|
| 127 |
+
|
| 128 |
+
args = merge_and_update_config(args)
|
| 129 |
+
|
| 130 |
+
set_seed(args.seed)
|
| 131 |
+
main(args, seed_range)
|