Commit 9a6a590
Parent(s): 2057037

Upload folder using huggingface_hub

Files changed:
- animatediff/models/motion_module.py (+6 -1)
- train.py (+12 -0)
animatediff/models/motion_module.py CHANGED

@@ -308,9 +308,14 @@ class VersatileAttention(CrossAttention):
                 attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
                 attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
 
+        if not hasattr(self, '_use_memory_efficient_attention_xformers'):
+            self._use_memory_efficient_attention_xformers = True
+
+
         # attention, what we cannot get enough of
         if self._use_memory_efficient_attention_xformers:
-            hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+            self.set_use_memory_efficient_attention_xformers(True)
+            # hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
             # Some versions of xformers return output in fp32, cast it back to the dtype of the input
             hidden_states = hidden_states.to(query.dtype)
         else:
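The hunk above defaults the `_use_memory_efficient_attention_xformers` flag when an attention module was built without it, then enables xformers through the public `set_use_memory_efficient_attention_xformers(True)` setter instead of calling the private helper directly. A minimal, self-contained sketch of that lazy-default pattern (the `AttentionLike` class is a hypothetical stand-in, not part of this repository):

import torch

class AttentionLike:
    def use_xformers(self) -> bool:
        # Guard: some code paths may never set this flag in __init__,
        # so default it lazily the first time it is consulted.
        if not hasattr(self, "_use_memory_efficient_attention_xformers"):
            self._use_memory_efficient_attention_xformers = True
        return self._use_memory_efficient_attention_xformers

block = AttentionLike()
print(block.use_xformers())  # True, even though __init__ never set the flag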
train.py CHANGED

@@ -177,6 +177,7 @@ def main(
     for name, module in unet.named_modules():
         if "motion_modules" in name and (train_whole_module or name.endswith(tuple(trainable_modules))):
             for params in module.parameters():
+                print("trainable", name)
                 params.requires_grad = True
 
     if enable_xformers_memory_efficient_attention:
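Hunk 1 only adds a debug print inside the parameter-selection loop. A self-contained sketch of that selection logic on a toy module (the `unet` layout here is invented for illustration; only the names matter):

import torch.nn as nn

# Toy stand-in for the UNet, frozen up front as the training script freezes the base model.
unet = nn.ModuleDict({
    "down_block": nn.ModuleDict({"motion_modules": nn.Linear(4, 4)}),
    "mid_block": nn.Linear(4, 4),
})
unet.requires_grad_(False)

train_whole_module = False
trainable_modules = ["motion_modules"]

for name, module in unet.named_modules():
    if "motion_modules" in name and (train_whole_module or name.endswith(tuple(trainable_modules))):
        for params in module.parameters():
            print("trainable", name)  # the line added by this commit
            params.requires_grad = True

# Only down_block.motion_modules should be trainable (20 parameters here).
print(sum(p.numel() for p in unet.parameters() if p.requires_grad), "trainable parameters")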
@@ -370,10 +371,21 @@ def main(
                 avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean()
                 train_loss += avg_loss.item() / gradient_accumulation_steps
 
+                for name, module in unet.named_modules():
+                    if "motion_modules" in name and (train_whole_module or name.endswith(tuple(trainable_modules))):
+                        for params in module.parameters():
+                            params.requires_grad = True
+
                 # Backpropagate
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
                     accelerator.clip_grad_norm_(unet.parameters(), max_grad_norm)
+
+                # for param in unet.parameters():
+                #     print(param.grad)
+
+
+
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()
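Hunk 2 re-asserts `requires_grad` on the motion modules inside the training step, and the commented-out loop hints at the debugging motive: checking whether gradients actually arrive after `backward()`. A minimal, self-contained version of that check (toy model, not the training script):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # hypothetical stand-in for the unet
loss = model(torch.randn(3, 4)).sum()
loss.backward()

# Uncommented form of the debugging loop sketched in the hunk above.
for name, param in model.named_parameters():
    status = "no grad" if param.grad is None else f"grad norm {param.grad.norm():.4f}"
    print(name, status)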