Update generate.py
Browse files- generate.py +4 -4
generate.py
CHANGED
|
@@ -12,8 +12,8 @@ from .utils import prepare_control, load_latent, load_video, prepare_depth, save
|
|
| 12 |
from .pnp_utils import register_time, register_attention_control, register_conv_control
|
| 13 |
|
| 14 |
# will cause an issue
|
| 15 |
-
from . import vidtome
|
| 16 |
-
|
| 17 |
# suppress partial model loading warning
|
| 18 |
logging.set_verbosity_error()
|
| 19 |
|
|
@@ -95,7 +95,7 @@ class Generator(nn.Module):
|
|
| 95 |
self.pipe.load_lora_weights(**gene_config.lora)
|
| 96 |
|
| 97 |
def activate_vidtome(self):
|
| 98 |
-
|
| 99 |
seed = self.seed, batch_size = self.batch_size, align_batch = self.use_pnp or self.align_batch, global_rand = self.global_rand)
|
| 100 |
|
| 101 |
@torch.no_grad()
|
|
@@ -234,7 +234,7 @@ class Generator(nn.Module):
|
|
| 234 |
def post_iter(self, x, t):
|
| 235 |
if self.merge_global:
|
| 236 |
# Reset global tokens
|
| 237 |
-
|
| 238 |
|
| 239 |
@torch.no_grad()
|
| 240 |
def pred_noise(self, x, cond, t, batch_idx=None):
|
|
|
|
| 12 |
from .pnp_utils import register_time, register_attention_control, register_conv_control
|
| 13 |
|
| 14 |
# will cause an issue
|
| 15 |
+
# from . import vidtome
|
| 16 |
+
from .vidtome import update_patch, update_patch
|
| 17 |
# suppress partial model loading warning
|
| 18 |
logging.set_verbosity_error()
|
| 19 |
|
|
|
|
| 95 |
self.pipe.load_lora_weights(**gene_config.lora)
|
| 96 |
|
| 97 |
def activate_vidtome(self):
|
| 98 |
+
apply_patch(self.pipe, self.local_merge_ratio, self.merge_global, self.global_merge_ratio,
|
| 99 |
seed = self.seed, batch_size = self.batch_size, align_batch = self.use_pnp or self.align_batch, global_rand = self.global_rand)
|
| 100 |
|
| 101 |
@torch.no_grad()
|
|
|
|
| 234 |
def post_iter(self, x, t):
|
| 235 |
if self.merge_global:
|
| 236 |
# Reset global tokens
|
| 237 |
+
update_patch(self.pipe, global_tokens = None)
|
| 238 |
|
| 239 |
@torch.no_grad()
|
| 240 |
def pred_noise(self, x, cond, t, batch_idx=None):
|