fix(state): fix model state
Browse files- LICENSE +1 -1
- README.md +2 -0
- conf/config.yaml +4 -3
- script/install.sh +3 -3
- svgdreamer/libs/logging.py +1 -1
- svgdreamer/libs/model_state.py +45 -45
LICENSE
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
MIT License
|
| 2 |
|
| 3 |
-
Copyright (c)
|
| 4 |
|
| 5 |
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
of this software and associated documentation files (the "Software"), to deal
|
|
|
|
| 1 |
MIT License
|
| 2 |
|
| 3 |
+
Copyright (c) 2024 XiMing Xing
|
| 4 |
|
| 5 |
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
of this software and associated documentation files (the "Software"), to deal
|
README.md
CHANGED
|
@@ -201,6 +201,8 @@ python svgdreamer.py x=ink "prompt='Big Wild Goose Pagoda. ink style. Minimalist
|
|
| 201 |
#### More Cases
|
| 202 |
|
| 203 |
````shell
|
|
|
|
|
|
|
| 204 |
# Style: painting
|
| 205 |
python svgdreamer.py x=painting "prompt='self portrait of Van Gogh. oil painting. cmyk portrait. multi colored. defiant and beautiful. cmyk. expressive eyes.'" result_path='./logs/VanGogh-Portrait'
|
| 206 |
````
|
|
|
|
| 201 |
#### More Cases
|
| 202 |
|
| 203 |
````shell
|
| 204 |
+
python svgdreamer.py x=iconography "prompt='illustration of an New York City, in the style of propaganda poster, vivid colours, detailed, sunny day, attention to detail, 8k.'" result_path='./logs/NewYorkCity'
|
| 205 |
+
python svgdreamer.py x=iconography "prompt='A colorful German shepherd in vector art. tending on artstation.'" result_path='./logs/GermanShepherd'
|
| 206 |
# Style: painting
|
| 207 |
python svgdreamer.py x=painting "prompt='self portrait of Van Gogh. oil painting. cmyk portrait. multi colored. defiant and beautiful. cmyk. expressive eyes.'" result_path='./logs/VanGogh-Portrait'
|
| 208 |
````
|
conf/config.yaml
CHANGED
|
@@ -11,7 +11,7 @@ skip_sive: True # optimize from scratch without SIVE init
|
|
| 11 |
# Accelerate config
|
| 12 |
state:
|
| 13 |
cpu: False # use cpu
|
| 14 |
-
mprec: no # mixed precision, choices: 'no', 'fp16', 'bf16'
|
| 15 |
|
| 16 |
# Diffusers config
|
| 17 |
diffuser:
|
|
@@ -46,9 +46,10 @@ hydra:
|
|
| 46 |
run:
|
| 47 |
# output directory for normal runs
|
| 48 |
# warning: make sure that the L53-55 of './libs/model_state.py' and 'dir' are modified together
|
| 49 |
-
dir: ./${result_path}/SVGDreamer-${now:%Y-%m-%
|
| 50 |
|
| 51 |
# default settings
|
| 52 |
defaults:
|
| 53 |
- _self_
|
| 54 |
-
- x: ~
|
|
|
|
|
|
| 11 |
# Accelerate config
|
| 12 |
state:
|
| 13 |
cpu: False # use cpu
|
| 14 |
+
mprec: 'no' # mixed precision, choices: 'no', 'fp16', 'bf16'
|
| 15 |
|
| 16 |
# Diffusers config
|
| 17 |
diffuser:
|
|
|
|
| 46 |
run:
|
| 47 |
# output directory for normal runs
|
| 48 |
# warning: make sure that the L53-55 of './libs/model_state.py' and 'dir' are modified together
|
| 49 |
+
dir: ./${result_path}/SVGDreamer-${now:%Y-%m-%d_%H-%M-%S}
|
| 50 |
|
| 51 |
# default settings
|
| 52 |
defaults:
|
| 53 |
- _self_
|
| 54 |
+
- x: ~
|
| 55 |
+
- override hydra/job_logging: disabled # Outputs only to stdout (no log file)
|
script/install.sh
CHANGED
|
@@ -1,12 +1,12 @@
|
|
| 1 |
#!/bin/bash
|
| 2 |
eval "$(conda shell.bash hook)"
|
| 3 |
|
| 4 |
-
conda create --name svgrender python=3.10
|
| 5 |
conda activate svgrender
|
| 6 |
|
| 7 |
echo "The conda environment was successfully created"
|
| 8 |
|
| 9 |
-
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
|
| 10 |
|
| 11 |
echo "Pytorch installation is complete. version: 1.12.1"
|
| 12 |
|
|
@@ -30,7 +30,7 @@ pip install diffusers==0.20.2
|
|
| 30 |
echo "Diffusers installation is complete. version: 0.20.2"
|
| 31 |
# if xformers doesnt install properly with conda try installing with pip using the code below
|
| 32 |
# pip install --pre -U xformers
|
| 33 |
-
conda install xformers -c xformers
|
| 34 |
|
| 35 |
echo "xformers installation is complete."
|
| 36 |
|
|
|
|
| 1 |
#!/bin/bash
|
| 2 |
eval "$(conda shell.bash hook)"
|
| 3 |
|
| 4 |
+
conda create --name svgrender python=3.10 --yes
|
| 5 |
conda activate svgrender
|
| 6 |
|
| 7 |
echo "The conda environment was successfully created"
|
| 8 |
|
| 9 |
+
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch --yes
|
| 10 |
|
| 11 |
echo "Pytorch installation is complete. version: 1.12.1"
|
| 12 |
|
|
|
|
| 30 |
echo "Diffusers installation is complete. version: 0.20.2"
|
| 31 |
# if xformers doesnt install properly with conda try installing with pip using the code below
|
| 32 |
# pip install --pre -U xformers
|
| 33 |
+
conda install xformers -c xformers --yes
|
| 34 |
|
| 35 |
echo "xformers installation is complete."
|
| 36 |
|
svgdreamer/libs/logging.py
CHANGED
|
@@ -8,7 +8,7 @@ import sys
|
|
| 8 |
import errno
|
| 9 |
|
| 10 |
|
| 11 |
-
def
|
| 12 |
logger = PrintLogger(os.path.join(logs_dir, file_name))
|
| 13 |
sys.stdout = logger # record all python print
|
| 14 |
return logger
|
|
|
|
| 8 |
import errno
|
| 9 |
|
| 10 |
|
| 11 |
+
def build_sysout_print_logger(logs_dir: str, file_name: str = "log.txt"):
|
| 12 |
logger = PrintLogger(os.path.join(logs_dir, file_name))
|
| 13 |
sys.stdout = logger # record all python print
|
| 14 |
return logger
|
svgdreamer/libs/model_state.py
CHANGED
|
@@ -5,15 +5,14 @@
|
|
| 5 |
|
| 6 |
from typing import Union, List
|
| 7 |
from pathlib import Path
|
| 8 |
-
from datetime import datetime
|
| 9 |
-
import logging
|
| 10 |
|
| 11 |
-
|
|
|
|
| 12 |
from pprint import pprint
|
| 13 |
import torch
|
| 14 |
from accelerate import Accelerator
|
| 15 |
|
| 16 |
-
from .logging import
|
| 17 |
|
| 18 |
|
| 19 |
class ModelState:
|
|
@@ -31,35 +30,21 @@ class ModelState:
|
|
| 31 |
def __init__(
|
| 32 |
self,
|
| 33 |
args: DictConfig,
|
| 34 |
-
log_path_suffix: str
|
| 35 |
-
ignore_log=False, # whether to create log file or not
|
| 36 |
) -> None:
|
| 37 |
self.args: DictConfig = args
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
# set cfg
|
| 39 |
-
self.state_cfg = args.state
|
| 40 |
self.x_cfg = args.x
|
| 41 |
|
| 42 |
-
"""check valid"""
|
| 43 |
-
mixed_precision = self.state_cfg.get("mprec")
|
| 44 |
-
# Bug: omegaconf convert 'no' to false
|
| 45 |
-
mixed_precision = "no" if type(mixed_precision) == bool else mixed_precision
|
| 46 |
-
|
| 47 |
"""create working space"""
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
results_folder = self.args.get("result_path", None)
|
| 52 |
-
if results_folder is None:
|
| 53 |
-
self.result_path = Path("./workdir") / f"SVGDreamer-{now_time}"
|
| 54 |
-
else:
|
| 55 |
-
self.result_path = Path(results_folder) / f"SVGDreamer-{now_time}"
|
| 56 |
-
|
| 57 |
-
# update result_path: ./runs/{method_name}-{exp_name}/{log_path_suffix}
|
| 58 |
-
# noting: can be understood as "results dir / methods / ablation study / your result"
|
| 59 |
-
if log_path_suffix is not None:
|
| 60 |
-
self.result_path = self.result_path / f"{log_path_suffix}"
|
| 61 |
-
else:
|
| 62 |
-
self.result_path = self.result_path / f"SVGDreamer"
|
| 63 |
|
| 64 |
"""init visualized tracker"""
|
| 65 |
# TODO: monitor with WANDB or TENSORBOARD
|
|
@@ -71,39 +56,30 @@ class ModelState:
|
|
| 71 |
|
| 72 |
"""HuggingFace Accelerator"""
|
| 73 |
self.accelerator = Accelerator(
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
cpu=True if self.state_cfg.cpu else False,
|
| 77 |
log_with=None if len(self.log_with) == 0 else self.log_with,
|
| 78 |
-
project_dir=self.
|
| 79 |
)
|
| 80 |
|
| 81 |
"""logs"""
|
| 82 |
if self.accelerator.is_local_main_process:
|
| 83 |
-
# logging
|
| 84 |
-
self.log = logging.getLogger(__name__)
|
| 85 |
-
|
| 86 |
-
# log results in a folder periodically
|
| 87 |
self.result_path.mkdir(parents=True, exist_ok=True)
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
file_name=f"{now_time}-{args.seed}-log.txt"
|
| 92 |
-
)
|
| 93 |
|
| 94 |
print("==> system args: ")
|
| 95 |
-
|
|
|
|
| 96 |
print(sys_cfg)
|
| 97 |
print("==> yaml config args: ")
|
| 98 |
print(self.x_cfg)
|
| 99 |
|
| 100 |
print("\n***** Model State *****")
|
| 101 |
-
print(f"-> Mixed Precision: {
|
| 102 |
print(f"-> Weight dtype: {self.weight_dtype}")
|
| 103 |
|
| 104 |
-
if self.accelerator.scaler_handler is not None and self.accelerator.scaler_handler.enabled:
|
| 105 |
-
print(f"-> Enabled GradScaler: {self.accelerator.scaler_handler.to_kwargs()}")
|
| 106 |
-
|
| 107 |
print(f"-> Working Space: '{self.result_path}'")
|
| 108 |
|
| 109 |
"""glob step"""
|
|
@@ -251,3 +227,27 @@ class ModelState:
|
|
| 251 |
if len(self.log_with) > 0:
|
| 252 |
self.close_tracker()
|
| 253 |
self.print(msg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
from typing import Union, List
|
| 7 |
from pathlib import Path
|
|
|
|
|
|
|
| 8 |
|
| 9 |
+
import hydra
|
| 10 |
+
from omegaconf import OmegaConf, DictConfig, open_dict
|
| 11 |
from pprint import pprint
|
| 12 |
import torch
|
| 13 |
from accelerate import Accelerator
|
| 14 |
|
| 15 |
+
from .logging import build_sysout_print_logger
|
| 16 |
|
| 17 |
|
| 18 |
class ModelState:
|
|
|
|
| 30 |
def __init__(
|
| 31 |
self,
|
| 32 |
args: DictConfig,
|
| 33 |
+
log_path_suffix: str,
|
|
|
|
| 34 |
) -> None:
|
| 35 |
self.args: DictConfig = args
|
| 36 |
+
|
| 37 |
+
# runtime output directory
|
| 38 |
+
with open_dict(args):
|
| 39 |
+
args.output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
|
| 40 |
+
|
| 41 |
# set cfg
|
|
|
|
| 42 |
self.x_cfg = args.x
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
"""create working space"""
|
| 45 |
+
self.result_path = Path(args.output_dir) # saving path
|
| 46 |
+
self.monitor_dir = self.result_path / 'runs' # monitor path
|
| 47 |
+
self.result_path = self.result_path / f"{log_path_suffix}" # method results path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
"""init visualized tracker"""
|
| 50 |
# TODO: monitor with WANDB or TENSORBOARD
|
|
|
|
| 56 |
|
| 57 |
"""HuggingFace Accelerator"""
|
| 58 |
self.accelerator = Accelerator(
|
| 59 |
+
mixed_precision=args.state.get("mprec"),
|
| 60 |
+
cpu=args.state.get('cpu', False),
|
|
|
|
| 61 |
log_with=None if len(self.log_with) == 0 else self.log_with,
|
| 62 |
+
project_dir=self.monitor_dir,
|
| 63 |
)
|
| 64 |
|
| 65 |
"""logs"""
|
| 66 |
if self.accelerator.is_local_main_process:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
self.result_path.mkdir(parents=True, exist_ok=True)
|
| 68 |
+
# system print recorder
|
| 69 |
+
self.logger = build_sysout_print_logger(logs_dir=self.result_path.as_posix(),
|
| 70 |
+
file_name=f"stdout-print-log.txt")
|
|
|
|
|
|
|
| 71 |
|
| 72 |
print("==> system args: ")
|
| 73 |
+
custom_cfg = OmegaConf.masked_copy(args, ["x"])
|
| 74 |
+
sys_cfg = dictconfig_diff(args, custom_cfg)
|
| 75 |
print(sys_cfg)
|
| 76 |
print("==> yaml config args: ")
|
| 77 |
print(self.x_cfg)
|
| 78 |
|
| 79 |
print("\n***** Model State *****")
|
| 80 |
+
print(f"-> Mixed Precision: {self.accelerator.state.mixed_precision}")
|
| 81 |
print(f"-> Weight dtype: {self.weight_dtype}")
|
| 82 |
|
|
|
|
|
|
|
|
|
|
| 83 |
print(f"-> Working Space: '{self.result_path}'")
|
| 84 |
|
| 85 |
"""glob step"""
|
|
|
|
| 227 |
if len(self.log_with) > 0:
|
| 228 |
self.close_tracker()
|
| 229 |
self.print(msg)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def dictconfig_diff(dict1, dict2):
|
| 233 |
+
"""
|
| 234 |
+
Find the difference between two OmegaConf.DictConfig objects
|
| 235 |
+
"""
|
| 236 |
+
# Convert OmegaConf.DictConfig to regular dictionaries
|
| 237 |
+
dict1 = OmegaConf.to_container(dict1, resolve=True)
|
| 238 |
+
dict2 = OmegaConf.to_container(dict2, resolve=True)
|
| 239 |
+
|
| 240 |
+
# Find the keys that are in dict1 but not in dict2
|
| 241 |
+
diff = {}
|
| 242 |
+
for key in dict1:
|
| 243 |
+
if key not in dict2:
|
| 244 |
+
diff[key] = dict1[key]
|
| 245 |
+
elif dict1[key] != dict2[key]:
|
| 246 |
+
if isinstance(dict1[key], dict) and isinstance(dict2[key], dict):
|
| 247 |
+
nested_diff = dictconfig_diff(dict1[key], dict2[key])
|
| 248 |
+
if nested_diff:
|
| 249 |
+
diff[key] = nested_diff
|
| 250 |
+
else:
|
| 251 |
+
diff[key] = dict1[key]
|
| 252 |
+
|
| 253 |
+
return diff
|