fix(method): handle the case where attn_map is None
Browse files- Examples.md +7 -1
- README.md +1 -1
- conf/x/iconography.yaml +2 -2
- svgdreamer/painter/__init__.py +2 -2
- svgdreamer/pipelines/SVGDreamer_pipeline.py +73 -40
- svgdreamer/token2attn/attn_control.py +1 -1
Examples.md
CHANGED
|
@@ -160,4 +160,10 @@ expressive eyes. <br/>
|
|
| 160 |
|
| 161 |
````shell
|
| 162 |
python svgdreamer.py x=painting "prompt='self portrait of Van Gogh. oil painting. cmyk portrait. multi colored. defiant and beautiful. cmyk. expressive eyes.'" x.num_paths=256 result_path='./logs/VanGogh-Portrait'
|
| 163 |
-
````
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
|
| 161 |
````shell
|
| 162 |
python svgdreamer.py x=painting "prompt='self portrait of Van Gogh. oil painting. cmyk portrait. multi colored. defiant and beautiful. cmyk. expressive eyes.'" x.num_paths=256 result_path='./logs/VanGogh-Portrait'
|
| 163 |
+
````
|
| 164 |
+
|
| 165 |
+
### Case: planet Saturn
|
| 166 |
+
|
| 167 |
+
```shell
|
| 168 |
+
python svgdreamer.py x=iconography-s1 skip_sive=False "prompt='An icon of the planet Saturn. minimal flat 2D vector icon. plain color background. trending on ArtStation.'" token_ind=6 x.sive.bg.num_iter=50 x.sive.fg.num_iter=50 x.vpsd.t_schedule='randint' result_path='./logs/Saturn' multirun=True state.mprec='fp16'
|
| 169 |
+
```
|
README.md
CHANGED
|
@@ -80,7 +80,7 @@ realistic <br/>
|
|
| 80 |
**Script:**
|
| 81 |
|
| 82 |
```shell
|
| 83 |
-
python svgdreamer.py x=iconography skip_sive=False "prompt='an image of Batman. full body action pose, complete detailed body. white background. empty background, high quality, 4K, ultra realistic'" token_ind=4 x.
|
| 84 |
```
|
| 85 |
|
| 86 |
🔹Parameter:
|
|
|
|
| 80 |
**Script:**
|
| 81 |
|
| 82 |
```shell
|
| 83 |
+
python svgdreamer.py x=iconography skip_sive=False "prompt='an image of Batman. full body action pose, complete detailed body. white background. empty background, high quality, 4K, ultra realistic'" token_ind=4 x.vpsd.t_schedule='randint' result_path='./logs/batman' multirun=True
|
| 84 |
```
|
| 85 |
|
| 86 |
🔹Parameter:
|
conf/x/iconography.yaml
CHANGED
|
@@ -41,7 +41,7 @@ sive:
|
|
| 41 |
mask_tau: 0.3 # the threshold used to convert the attention map into a mask
|
| 42 |
bg:
|
| 43 |
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
| 44 |
-
num_iter:
|
| 45 |
num_paths: 256
|
| 46 |
path_schedule: 'repeat' # 'repeat', 'list'
|
| 47 |
schedule_each: 128
|
|
@@ -61,7 +61,7 @@ sive:
|
|
| 61 |
xing_loss_weight: 0.001
|
| 62 |
fg:
|
| 63 |
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
| 64 |
-
num_iter:
|
| 65 |
num_paths: 256 # number of strokes
|
| 66 |
path_schedule: 'repeat' # 'repeat', 'list'
|
| 67 |
schedule_each: 128
|
|
|
|
| 41 |
mask_tau: 0.3 # the threshold used to convert the attention map into a mask
|
| 42 |
bg:
|
| 43 |
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
| 44 |
+
num_iter: 50
|
| 45 |
num_paths: 256
|
| 46 |
path_schedule: 'repeat' # 'repeat', 'list'
|
| 47 |
schedule_each: 128
|
|
|
|
| 61 |
xing_loss_weight: 0.001
|
| 62 |
fg:
|
| 63 |
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
| 64 |
+
num_iter: 50
|
| 65 |
num_paths: 256 # number of strokes
|
| 66 |
path_schedule: 'repeat' # 'repeat', 'list'
|
| 67 |
schedule_each: 128
|
svgdreamer/painter/__init__.py
CHANGED
|
@@ -2,8 +2,8 @@
|
|
| 2 |
# Copyright (c) XiMing Xing. All rights reserved.
|
| 3 |
# Description:
|
| 4 |
|
| 5 |
-
from .painter_params import
|
| 6 |
-
|
| 7 |
from .component_painter_params import CompPainter, CompPainterOptimizer
|
| 8 |
from .loss import xing_loss_fn
|
| 9 |
from .VPSD_pipeline import VectorizedParticleSDSPipeline
|
|
|
|
| 2 |
# Copyright (c) XiMing Xing. All rights reserved.
|
| 3 |
# Description:
|
| 4 |
|
| 5 |
+
from .painter_params import Painter, PainterOptimizer, CosineWithWarmupLRLambda, RandomCoordInit, NaiveCoordInit, \
|
| 6 |
+
SparseCoordInit, get_sdf
|
| 7 |
from .component_painter_params import CompPainter, CompPainterOptimizer
|
| 8 |
from .loss import xing_loss_fn
|
| 9 |
from .VPSD_pipeline import VectorizedParticleSDSPipeline
|
svgdreamer/pipelines/SVGDreamer_pipeline.py
CHANGED
|
@@ -20,8 +20,8 @@ from torchvision import transforms
|
|
| 20 |
from skimage.color import rgb2gray
|
| 21 |
|
| 22 |
from svgdreamer.libs import ModelState, get_optimizer
|
| 23 |
-
from svgdreamer.painter import
|
| 24 |
-
|
| 25 |
from svgdreamer.token2attn.attn_control import EmptyControl, AttentionStore
|
| 26 |
from svgdreamer.token2attn.ptp_utils import view_images
|
| 27 |
from svgdreamer.utils.plot import plot_img, plot_couple, plot_attn, save_image
|
|
@@ -38,8 +38,10 @@ class SVGDreamerPipeline(ModelState):
|
|
| 38 |
# assert
|
| 39 |
assert args.x.style in ["iconography", "pixelart", "low-poly", "painting", "sketch", "ink"]
|
| 40 |
args.skip_sive = True if args.x.style in ["pixelart", "low-poly"] else args.skip_sive
|
| 41 |
-
assert args.x.vpsd.n_particle >= args.x.vpsd.vsd_n_particle
|
| 42 |
-
|
|
|
|
|
|
|
| 43 |
assert args.x.vpsd.n_phi_sample >= 1
|
| 44 |
|
| 45 |
logdir_ = f"sd{args.seed}" \
|
|
@@ -123,15 +125,26 @@ class SVGDreamerPipeline(ModelState):
|
|
| 123 |
self.close(msg="painterly rendering complete.")
|
| 124 |
|
| 125 |
def SIVE_stage(self, text_prompt: str):
|
| 126 |
-
#
|
| 127 |
pipeline = DiffusionPipeline(self.x_cfg.sive_model_cfg, self.args.diffuser, self.device)
|
| 128 |
|
| 129 |
merged_svg_paths = []
|
| 130 |
merged_images = []
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
self.x_cfg.sive_model_cfg,
|
| 136 |
pipeline,
|
| 137 |
text_prompt,
|
|
@@ -139,18 +152,18 @@ class SVGDreamerPipeline(ModelState):
|
|
| 139 |
self.sive_cfg.attn_cfg,
|
| 140 |
self.im_size,
|
| 141 |
self.args.token_ind)
|
| 142 |
-
#
|
| 143 |
select_img = self.target_file_preprocess(select_sample_path.as_posix())
|
| 144 |
self.print(f"load target file from: {select_sample_path.as_posix()}")
|
| 145 |
|
| 146 |
-
#
|
| 147 |
-
fg_img, bg_img, fg_mask, bg_mask = self.extract_object(
|
| 148 |
tau=self.sive_cfg.mask_tau)
|
| 149 |
-
self.print(f"fg_img shape: {fg_img.shape}, bg_img: {bg_img.shape}")
|
| 150 |
|
| 151 |
-
#
|
| 152 |
-
self.print(f"->
|
| 153 |
-
bg_render_path = self.component_rendering(tag=f'{
|
| 154 |
prompt=text_prompt,
|
| 155 |
target_img=bg_img,
|
| 156 |
mask=bg_mask,
|
|
@@ -160,9 +173,14 @@ class SVGDreamerPipeline(ModelState):
|
|
| 160 |
optim_cfg=self.sive_optim,
|
| 161 |
log_png_dir=self.bg_png_logs_dir,
|
| 162 |
log_svg_dir=self.bg_svg_logs_dir)
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
prompt=text_prompt,
|
| 167 |
target_img=fg_img,
|
| 168 |
mask=fg_mask,
|
|
@@ -172,8 +190,16 @@ class SVGDreamerPipeline(ModelState):
|
|
| 172 |
optim_cfg=self.sive_optim,
|
| 173 |
log_png_dir=self.fg_png_logs_dir,
|
| 174 |
log_svg_dir=self.fg_svg_logs_dir)
|
| 175 |
-
|
| 176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
merge_svg_files(
|
| 178 |
svg_path_1=bg_render_path,
|
| 179 |
svg_path_2=fg_render_path,
|
|
@@ -182,11 +208,11 @@ class SVGDreamerPipeline(ModelState):
|
|
| 182 |
out_size=(self.im_size, self.im_size)
|
| 183 |
)
|
| 184 |
|
| 185 |
-
#
|
| 186 |
# Note: you are not allowed to add further paths here
|
| 187 |
if self.sive_cfg.tog.reinit:
|
| 188 |
-
self.print("->
|
| 189 |
-
merged_svg_path = self.refine_rendering(tag=f'{
|
| 190 |
prompt=text_prompt,
|
| 191 |
target_img=select_img,
|
| 192 |
canvas_size=(self.im_size, self.im_size),
|
|
@@ -194,22 +220,21 @@ class SVGDreamerPipeline(ModelState):
|
|
| 194 |
optim_cfg=self.sive_optim,
|
| 195 |
init_svg_path=merged_svg_path)
|
| 196 |
|
| 197 |
-
# svg-to-png
|
| 198 |
-
merged_png_path = self.result_path / f'SIVE_render_final_{
|
| 199 |
cairosvg.svg2png(url=merged_svg_path.as_posix(), write_to=merged_png_path.as_posix())
|
| 200 |
-
|
| 201 |
-
# collect paths
|
| 202 |
-
merged_svg_paths.append(merged_svg_path)
|
| 203 |
merged_images.append(self.target_file_preprocess(merged_png_path))
|
| 204 |
-
|
|
|
|
| 205 |
controller.reset()
|
| 206 |
|
| 207 |
-
self.print(f"Vector Particle {
|
| 208 |
|
| 209 |
-
#
|
| 210 |
del pipeline
|
| 211 |
torch.cuda.empty_cache()
|
| 212 |
-
#
|
| 213 |
self.x_cfg.num_paths = self.sive_cfg.bg.num_paths + self.sive_cfg.fg.num_paths
|
| 214 |
|
| 215 |
return merged_svg_paths, merged_images
|
|
@@ -257,6 +282,9 @@ class SVGDreamerPipeline(ModelState):
|
|
| 257 |
if attention_map is not None:
|
| 258 |
# init first control points by attention_map
|
| 259 |
attn_thresh, select_inds = renderer.attn_init_points(num_paths=sum(path_schedule), mask=mask)
|
|
|
|
|
|
|
|
|
|
| 260 |
# log attention, just once
|
| 261 |
plot_attn(attention_map, attn_thresh, target_img, select_inds,
|
| 262 |
(self.sive_attn_dir / f"attention_{tag}_map.jpg").as_posix())
|
|
@@ -381,14 +409,16 @@ class SVGDreamerPipeline(ModelState):
|
|
| 381 |
plot_img(img, self.refine_dir, fname=f"{tag}_before_refined")
|
| 382 |
|
| 383 |
n_iter = render_cfg.num_iter
|
|
|
|
|
|
|
| 384 |
# build painter optimizer
|
| 385 |
optimizer = CompPainterOptimizer(content_renderer, self.style, n_iter, optim_cfg)
|
| 386 |
# init optimizer
|
| 387 |
optimizer.init_optimizers()
|
| 388 |
|
| 389 |
-
print(f"=> n_point: {len(content_renderer.get_point_params())}, "
|
| 390 |
-
|
| 391 |
-
|
| 392 |
|
| 393 |
step = 0
|
| 394 |
with tqdm(initial=step, total=n_iter, disable=not self.accelerator.is_main_process) as pbar:
|
|
@@ -434,7 +464,8 @@ class SVGDreamerPipeline(ModelState):
|
|
| 434 |
text_prompt: AnyStr,
|
| 435 |
init_svg_path: Union[List[AnyPath], AnyPath] = None,
|
| 436 |
init_image: Union[List[torch.Tensor], torch.Tensor] = None):
|
| 437 |
-
|
|
|
|
| 438 |
return
|
| 439 |
|
| 440 |
# for convenience
|
|
@@ -784,10 +815,12 @@ class SVGDreamerPipeline(ModelState):
|
|
| 784 |
generator=self.g_device)
|
| 785 |
outputs_np = [np.array(img) for img in outputs.images]
|
| 786 |
view_images(outputs_np, save_image=True, fp=gen_sample_path)
|
| 787 |
-
self.print(f"select_sample shape: {outputs_np[0].shape}")
|
| 788 |
|
| 789 |
if attn_init:
|
| 790 |
-
"
|
|
|
|
|
|
|
| 791 |
cross_attention_maps, tokens = \
|
| 792 |
pipeline.get_cross_attention([prompts],
|
| 793 |
controller,
|
|
@@ -862,7 +895,7 @@ class SVGDreamerPipeline(ModelState):
|
|
| 862 |
view_images(reversed_attn_map_vis, save_image=True,
|
| 863 |
fp=self.sive_attn_dir / f'reversed-fusion-attn-{iter}.png')
|
| 864 |
|
| 865 |
-
self.print(f"-> fusion attn_map: {attn_map.shape}")
|
| 866 |
else:
|
| 867 |
attn_map = None
|
| 868 |
inverse_attn = None
|
|
|
|
| 20 |
from skimage.color import rgb2gray
|
| 21 |
|
| 22 |
from svgdreamer.libs import ModelState, get_optimizer
|
| 23 |
+
from svgdreamer.painter import CompPainter, CompPainterOptimizer, xing_loss_fn, Painter, PainterOptimizer, \
|
| 24 |
+
CosineWithWarmupLRLambda, VectorizedParticleSDSPipeline, DiffusionPipeline
|
| 25 |
from svgdreamer.token2attn.attn_control import EmptyControl, AttentionStore
|
| 26 |
from svgdreamer.token2attn.ptp_utils import view_images
|
| 27 |
from svgdreamer.utils.plot import plot_img, plot_couple, plot_attn, save_image
|
|
|
|
| 38 |
# assert
|
| 39 |
assert args.x.style in ["iconography", "pixelart", "low-poly", "painting", "sketch", "ink"]
|
| 40 |
args.skip_sive = True if args.x.style in ["pixelart", "low-poly"] else args.skip_sive
|
| 41 |
+
# assert args.x.vpsd.n_particle >= args.x.vpsd.vsd_n_particle
|
| 42 |
+
if args.x.vpsd.vsd_n_particle > args.x.vpsd.n_particle: args.x.vpsd.vsd_n_particle = args.x.vpsd.n_particle
|
| 43 |
+
# assert args.x.vpsd.n_particle >= args.x.vpsd.phi_n_particle
|
| 44 |
+
if args.x.vpsd.phi_n_particle > args.x.vpsd.n_particle: args.x.vpsd.phi_n_particle = args.x.vpsd.n_particle
|
| 45 |
assert args.x.vpsd.n_phi_sample >= 1
|
| 46 |
|
| 47 |
logdir_ = f"sd{args.seed}" \
|
|
|
|
| 125 |
self.close(msg="painterly rendering complete.")
|
| 126 |
|
| 127 |
def SIVE_stage(self, text_prompt: str):
|
| 128 |
+
# Init diffusion model
|
| 129 |
pipeline = DiffusionPipeline(self.x_cfg.sive_model_cfg, self.args.diffuser, self.device)
|
| 130 |
|
| 131 |
merged_svg_paths = []
|
| 132 |
merged_images = []
|
| 133 |
+
|
| 134 |
+
successful_particles = 0
|
| 135 |
+
cur_idx = 0
|
| 136 |
+
|
| 137 |
+
while successful_particles < self.vpsd_cfg.n_particle:
|
| 138 |
+
if cur_idx >= self.vpsd_cfg.n_particle + 10: # max attempts
|
| 139 |
+
self.print(f"Reached maximum attempts ({cur_idx}). "
|
| 140 |
+
f"Only processed {successful_particles} particles successfully.")
|
| 141 |
+
break
|
| 142 |
+
|
| 143 |
+
self.print(f"Processing particle {cur_idx} "
|
| 144 |
+
f"(successful so far: {successful_particles}/{self.vpsd_cfg.n_particle})")
|
| 145 |
+
select_sample_path = self.result_path / f'select_sample_{cur_idx}.png'
|
| 146 |
+
# Generate sample and attention map
|
| 147 |
+
fg_attn_map, bg_attn_map, controller = self.extract_ldm_attn(cur_idx,
|
| 148 |
self.x_cfg.sive_model_cfg,
|
| 149 |
pipeline,
|
| 150 |
text_prompt,
|
|
|
|
| 152 |
self.sive_cfg.attn_cfg,
|
| 153 |
self.im_size,
|
| 154 |
self.args.token_ind)
|
| 155 |
+
# Load selected file
|
| 156 |
select_img = self.target_file_preprocess(select_sample_path.as_posix())
|
| 157 |
self.print(f"load target file from: {select_sample_path.as_posix()}")
|
| 158 |
|
| 159 |
+
# Get objects by attention map
|
| 160 |
+
fg_img, bg_img, fg_mask, bg_mask = self.extract_object(cur_idx, select_img, fg_attn_map, bg_attn_map,
|
| 161 |
tau=self.sive_cfg.mask_tau)
|
| 162 |
+
# self.print(f"fg_img shape: {fg_img.shape}, bg_img: {bg_img.shape}")
|
| 163 |
|
| 164 |
+
# Background rendering
|
| 165 |
+
self.print(f"-> Background rendering: ")
|
| 166 |
+
bg_render_path = self.component_rendering(tag=f'{cur_idx}_bg',
|
| 167 |
prompt=text_prompt,
|
| 168 |
target_img=bg_img,
|
| 169 |
mask=bg_mask,
|
|
|
|
| 173 |
optim_cfg=self.sive_optim,
|
| 174 |
log_png_dir=self.bg_png_logs_dir,
|
| 175 |
log_svg_dir=self.bg_svg_logs_dir)
|
| 176 |
+
if bg_render_path == 0:
|
| 177 |
+
self.print(f"Background rendering failed for particle {cur_idx}, trying next particle")
|
| 178 |
+
cur_idx += 1
|
| 179 |
+
continue
|
| 180 |
+
|
| 181 |
+
# Foreground rendering
|
| 182 |
+
self.print(f"-> Foreground rendering: ")
|
| 183 |
+
fg_render_path = self.component_rendering(tag=f'{cur_idx}_fg',
|
| 184 |
prompt=text_prompt,
|
| 185 |
target_img=fg_img,
|
| 186 |
mask=fg_mask,
|
|
|
|
| 190 |
optim_cfg=self.sive_optim,
|
| 191 |
log_png_dir=self.fg_png_logs_dir,
|
| 192 |
log_svg_dir=self.fg_svg_logs_dir)
|
| 193 |
+
if fg_render_path == 0:
|
| 194 |
+
self.print(f"Foreground rendering failed for particle {cur_idx}, trying next particle")
|
| 195 |
+
cur_idx += 1
|
| 196 |
+
continue
|
| 197 |
+
|
| 198 |
+
successful_particles += 1
|
| 199 |
+
cur_idx += 1
|
| 200 |
+
|
| 201 |
+
# Merge foreground and background
|
| 202 |
+
merged_svg_path = self.result_path / f'SIVE_render_final_{cur_idx}.svg'
|
| 203 |
merge_svg_files(
|
| 204 |
svg_path_1=bg_render_path,
|
| 205 |
svg_path_2=fg_render_path,
|
|
|
|
| 208 |
out_size=(self.im_size, self.im_size)
|
| 209 |
)
|
| 210 |
|
| 211 |
+
# Foreground and background refinement
|
| 212 |
# Note: you are not allowed to add further paths here
|
| 213 |
if self.sive_cfg.tog.reinit:
|
| 214 |
+
self.print("-> Enable vector graphic refinement:")
|
| 215 |
+
merged_svg_path = self.refine_rendering(tag=f'{cur_idx}_refine',
|
| 216 |
prompt=text_prompt,
|
| 217 |
target_img=select_img,
|
| 218 |
canvas_size=(self.im_size, self.im_size),
|
|
|
|
| 220 |
optim_cfg=self.sive_optim,
|
| 221 |
init_svg_path=merged_svg_path)
|
| 222 |
|
| 223 |
+
# Postprocess: svg-to-png & to tensor
|
| 224 |
+
merged_png_path = self.result_path / f'SIVE_render_final_{cur_idx}.png'
|
| 225 |
cairosvg.svg2png(url=merged_svg_path.as_posix(), write_to=merged_png_path.as_posix())
|
| 226 |
+
merged_svg_paths.append(merged_svg_path) # collect paths
|
|
|
|
|
|
|
| 227 |
merged_images.append(self.target_file_preprocess(merged_png_path))
|
| 228 |
+
|
| 229 |
+
# Clear attention recorder
|
| 230 |
controller.reset()
|
| 231 |
|
| 232 |
+
self.print(f"Vector Particle {cur_idx} Rendering End...\n")
|
| 233 |
|
| 234 |
+
# Free the VRAM
|
| 235 |
del pipeline
|
| 236 |
torch.cuda.empty_cache()
|
| 237 |
+
# Update paths
|
| 238 |
self.x_cfg.num_paths = self.sive_cfg.bg.num_paths + self.sive_cfg.fg.num_paths
|
| 239 |
|
| 240 |
return merged_svg_paths, merged_images
|
|
|
|
| 282 |
if attention_map is not None:
|
| 283 |
# init fist control points by attention_map
|
| 284 |
attn_thresh, select_inds = renderer.attn_init_points(num_paths=sum(path_schedule), mask=mask)
|
| 285 |
+
# Warning: attention map failure
|
| 286 |
+
if len(select_inds) == 0: return 0
|
| 287 |
+
|
| 288 |
# log attention, just once
|
| 289 |
plot_attn(attention_map, attn_thresh, target_img, select_inds,
|
| 290 |
(self.sive_attn_dir / f"attention_{tag}_map.jpg").as_posix())
|
|
|
|
| 409 |
plot_img(img, self.refine_dir, fname=f"{tag}_before_refined")
|
| 410 |
|
| 411 |
n_iter = render_cfg.num_iter
|
| 412 |
+
self.print(f"Total iters: {n_iter}")
|
| 413 |
+
|
| 414 |
# build painter optimizer
|
| 415 |
optimizer = CompPainterOptimizer(content_renderer, self.style, n_iter, optim_cfg)
|
| 416 |
# init optimizer
|
| 417 |
optimizer.init_optimizers()
|
| 418 |
|
| 419 |
+
self.print(f"=> n_point: {len(content_renderer.get_point_params())}, "
|
| 420 |
+
f"n_width: {len(content_renderer.get_width_params())}, "
|
| 421 |
+
f"n_color: {len(content_renderer.get_color_params())}")
|
| 422 |
|
| 423 |
step = 0
|
| 424 |
with tqdm(initial=step, total=n_iter, disable=not self.accelerator.is_main_process) as pbar:
|
|
|
|
| 464 |
text_prompt: AnyStr,
|
| 465 |
init_svg_path: Union[List[AnyPath], AnyPath] = None,
|
| 466 |
init_image: Union[List[torch.Tensor], torch.Tensor] = None):
|
| 467 |
+
# print(f"self.vpsd_cfg.use: {self.vpsd_cfg.use}")
|
| 468 |
+
if self.vpsd_cfg.use is False:
|
| 469 |
return
|
| 470 |
|
| 471 |
# for convenience
|
|
|
|
| 815 |
generator=self.g_device)
|
| 816 |
outputs_np = [np.array(img) for img in outputs.images]
|
| 817 |
view_images(outputs_np, save_image=True, fp=gen_sample_path)
|
| 818 |
+
# self.print(f"select_sample shape: {outputs_np[0].shape}")
|
| 819 |
|
| 820 |
if attn_init:
|
| 821 |
+
self.print(f"\nLDM attn-map logging:")
|
| 822 |
+
|
| 823 |
+
# Cross-attention map
|
| 824 |
cross_attention_maps, tokens = \
|
| 825 |
pipeline.get_cross_attention([prompts],
|
| 826 |
controller,
|
|
|
|
| 895 |
view_images(reversed_attn_map_vis, save_image=True,
|
| 896 |
fp=self.sive_attn_dir / f'reversed-fusion-attn-{iter}.png')
|
| 897 |
|
| 898 |
+
self.print(f"-> fusion attn_map: {attn_map.shape} \n")
|
| 899 |
else:
|
| 900 |
attn_map = None
|
| 901 |
inverse_attn = None
|
svgdreamer/token2attn/attn_control.py
CHANGED
|
@@ -85,7 +85,7 @@ class AttentionStore(AttentionControl):
|
|
| 85 |
self.step_store = self.get_empty_store()
|
| 86 |
|
| 87 |
def get_average_attention(self):
|
| 88 |
-
print(f"step count: {self.cur_step}")
|
| 89 |
average_attention = {
|
| 90 |
key: [item / self.cur_step for item in self.attention_store[key]]
|
| 91 |
for key in self.attention_store
|
|
|
|
| 85 |
self.step_store = self.get_empty_store()
|
| 86 |
|
| 87 |
def get_average_attention(self):
|
| 88 |
+
# print(f"step count: {self.cur_step}")
|
| 89 |
average_attention = {
|
| 90 |
key: [item / self.cur_step for item in self.attention_store[key]]
|
| 91 |
for key in self.attention_store
|