Spaces:
Running
Running
| # -*- coding: utf-8 -*- | |
| # Author: ximing xing | |
| # Description: the main func of this project. | |
| # Copyright (c) 2023, XiMing Xing. | |
| import os | |
| import sys | |
| from functools import partial | |
| from accelerate.utils import set_seed | |
| import hydra | |
| import omegaconf | |
| sys.path.append(os.path.split(os.path.abspath(os.path.dirname(__file__)))[0]) | |
| from pytorch_svgrender.utils import render_batch_wrap, get_seed_range | |
# Method names accepted in `cfg.x.method`; main() asserts membership in this
# list before dispatching to a pipeline.
METHODS = [
    "diffvg",
    "live",
    "vectorfusion",
    "clipasso",
    "clipascene",
    "diffsketcher",
    "stylediffsketcher",
    "clipdraw",
    "styleclipdraw",
    "wordasimage",
    "clipfont",
    "svgdreamer",
]
def main(cfg: omegaconf.DictConfig):
    """Run the rendering pipeline selected by ``cfg.x.method``.

    The run is seeded with ``cfg.seed`` and then dispatched to exactly one
    pipeline. Each pipeline module is imported inside its own branch, so only
    the module of the selected method is imported.

    For the methods that support it (vectorfusion, svgdreamer, clipfont,
    diffsketcher, stylediffsketcher), a truthy ``cfg.multirun`` hands the
    pipeline class to ``render_batch_wrap`` together with ``cfg`` and the seed
    range from ``get_seed_range(cfg.srange)`` instead of running it once.

    Args:
        cfg: Project configuration. Fields read here: ``x.method``, ``seed``,
            ``multirun``, ``srange`` (only when ``multirun`` is truthy),
            ``target``, ``prompt``, ``x.word``, ``x.optim_letter`` and,
            optionally, ``style_file``.

    Raises:
        AssertionError: If ``cfg.x.method`` is not listed in ``METHODS``.
    """
    # Debug aid: print(omegaconf.OmegaConf.to_yaml(cfg)) dumps the config.
    flag = cfg.x.method
    assert flag in METHODS, f"{flag} is not currently supported!"

    # seed prepare
    set_seed(cfg.seed)
    seed_range = get_seed_range(cfg.srange) if cfg.multirun else None

    # render function used by the multirun branches below
    render_batch_fn = partial(render_batch_wrap, cfg=cfg, seed_range=seed_range)

    if flag == "diffvg":  # img2svg
        from pytorch_svgrender.pipelines.DiffVG_pipeline import DiffVGPipeline

        pipe = DiffVGPipeline(cfg)
        pipe.painterly_rendering(cfg.target)

    elif flag == "live":  # img2svg
        from pytorch_svgrender.pipelines.LIVE_pipeline import LIVEPipeline

        pipe = LIVEPipeline(cfg)
        pipe.painterly_rendering(cfg.target)

    elif flag == "vectorfusion":  # text2svg
        from pytorch_svgrender.pipelines.VectorFusion_pipeline import VectorFusionPipeline

        if not cfg.multirun:
            pipe = VectorFusionPipeline(cfg)
            pipe.painterly_rendering(cfg.prompt)
        else:  # generate many SVG at once
            render_batch_fn(pipeline=VectorFusionPipeline, text_prompt=cfg.prompt)

    elif flag == "svgdreamer":  # text2svg
        from pytorch_svgrender.pipelines.SVGDreamer_pipeline import SVGDreamerPipeline

        if not cfg.multirun:
            pipe = SVGDreamerPipeline(cfg)
            pipe.painterly_rendering(cfg.prompt)
        else:  # generate many SVG at once
            render_batch_fn(pipeline=SVGDreamerPipeline, text_prompt=cfg.prompt, target_file=None)

    elif flag == "wordasimage":  # text2font
        from pytorch_svgrender.pipelines.WordAsImage_pipeline import WordAsImagePipeline

        pipe = WordAsImagePipeline(cfg)
        pipe.painterly_rendering(cfg.x.word, cfg.prompt, cfg.x.optim_letter)

    elif flag == "clipasso":  # img2sketch
        from pytorch_svgrender.pipelines.CLIPasso_pipeline import CLIPassoPipeline

        pipe = CLIPassoPipeline(cfg)
        pipe.painterly_rendering(cfg.target)

    elif flag == 'clipascene':
        from pytorch_svgrender.pipelines.CLIPascene_pipeline import CLIPascenePipeline

        pipe = CLIPascenePipeline(cfg)
        pipe.painterly_rendering(cfg.target)

    elif flag == "clipdraw":  # text2svg
        from pytorch_svgrender.pipelines.CLIPDraw_pipeline import CLIPDrawPipeline

        pipe = CLIPDrawPipeline(cfg)
        pipe.painterly_rendering(cfg.prompt)

    elif flag == "clipfont":  # text and font to font
        from pytorch_svgrender.pipelines.CLIPFont_pipeline import CLIPFontPipeline

        if not cfg.multirun:
            pipe = CLIPFontPipeline(cfg)
            pipe.painterly_rendering(svg_path=cfg.target, prompt=cfg.prompt)
        else:  # generate many SVG at once
            render_batch_fn(pipeline=CLIPFontPipeline, svg_path=cfg.target, prompt=cfg.prompt)

    elif flag == "styleclipdraw":  # text to stylized svg
        from pytorch_svgrender.pipelines.StyleCLIPDraw_pipeline import StyleCLIPDrawPipeline

        pipe = StyleCLIPDrawPipeline(cfg)
        pipe.painterly_rendering(cfg.prompt, style_fpath=cfg.target)

    elif flag == "diffsketcher":  # text2sketch
        from pytorch_svgrender.pipelines.DiffSketcher_pipeline import DiffSketcherPipeline

        if not cfg.multirun:
            pipe = DiffSketcherPipeline(cfg)
            pipe.painterly_rendering(cfg.prompt)
        else:  # generate many SVG at once
            render_batch_fn(pipeline=DiffSketcherPipeline, prompt=cfg.prompt)

    elif flag == "stylediffsketcher":  # text2sketch + style transfer
        from pytorch_svgrender.pipelines.DiffSketcher_stylized_pipeline import StylizedDiffSketcherPipeline

        if not cfg.multirun:
            pipe = StylizedDiffSketcherPipeline(cfg)
            pipe.painterly_rendering(cfg.prompt, style_fpath=cfg.target)
        else:  # generate many SVG at once
            # The single-run branch takes the style image from cfg.target.
            # Honour an explicit `style_file` key when the config has one,
            # otherwise fall back to cfg.target so both branches agree and a
            # config without `style_file` does not fail here.
            style_fpath = cfg.style_file if "style_file" in cfg else cfg.target
            render_batch_fn(pipeline=StylizedDiffSketcherPipeline, prompt=cfg.prompt, style_fpath=style_fpath)
# Script entry point: run main() only when this file is executed directly.
# NOTE(review): main() is called here with no arguments, but main requires a
# `cfg` parameter and no decorator that would supply it (e.g. hydra.main;
# hydra is imported at the top of this file) is visible on main. Confirm how
# the configuration is meant to be injected.
if __name__ == '__main__':
    main()