Spaces:
Runtime error
Runtime error
Update model.py
Browse files
model.py
CHANGED
|
@@ -741,6 +741,50 @@ class StreamMultiDiffusion(nn.Module):
|
|
| 741 |
self.ready_checklist['layers_ready'] = True
|
| 742 |
self.ready_checklist['flushed'] = False
|
| 743 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 744 |
@torch.no_grad()
|
| 745 |
def update_single_layer(
|
| 746 |
self,
|
|
|
|
| 741 |
self.ready_checklist['layers_ready'] = True
|
| 742 |
self.ready_checklist['flushed'] = False
|
| 743 |
|
| 744 |
+
@torch.no_grad()
|
| 745 |
+
def update_masks(
|
| 746 |
+
self,
|
| 747 |
+
masks: Optional[Union[torch.Tensor, Image.Image, List[Image.Image]]] = None,
|
| 748 |
+
mask_strengths: Optional[Union[torch.Tensor, float, List[float]]] = None,
|
| 749 |
+
mask_stds: Optional[Union[torch.Tensor, float, List[float]]] = None,
|
| 750 |
+
) -> None:
|
| 751 |
+
if not self.ready_checklist['background_registered']:
|
| 752 |
+
print('[WARNING] Register background image first! Request ignored.')
|
| 753 |
+
return
|
| 754 |
+
|
| 755 |
+
### Register new masks
|
| 756 |
+
|
| 757 |
+
if isinstance(masks, Image.Image):
|
| 758 |
+
masks = [masks]
|
| 759 |
+
n = len(masks) if masks is not None else 0
|
| 760 |
+
|
| 761 |
+
# Modificiation.
|
| 762 |
+
masks, mask_strengths, mask_stds, original_masks = self.process_mask(masks, mask_strengths, mask_stds)
|
| 763 |
+
|
| 764 |
+
self.counts = masks.sum(dim=0) # (T, 1, h, w)
|
| 765 |
+
self.bg_mask = (1 - self.counts).clip_(0, 1) # (T, 1, h, w)
|
| 766 |
+
self.masks = masks # (p, T, 1, h, w)
|
| 767 |
+
self.mask_strengths = mask_strengths # (p,)
|
| 768 |
+
self.mask_stds = mask_stds # (p,)
|
| 769 |
+
self.original_masks = original_masks # (p, 1, h, w)
|
| 770 |
+
|
| 771 |
+
if p > n:
|
| 772 |
+
# Add more masks: counts and bg_masks are not changed, but only masks are changed.
|
| 773 |
+
self.masks = torch.cat((
|
| 774 |
+
self.masks,
|
| 775 |
+
torch.zeros(
|
| 776 |
+
(p - n, self.batch_size, 1, self.latent_height, self.latent_width),
|
| 777 |
+
dtype=self.dtype,
|
| 778 |
+
device=self.device,
|
| 779 |
+
),
|
| 780 |
+
), dim=0)
|
| 781 |
+
print(f'[WARNING] Detected more prompts ({p}) than masks ({n}). '
|
| 782 |
+
'Automatically adds blank masks for the additional prompts.')
|
| 783 |
+
elif p < n:
|
| 784 |
+
# Warns user to add more prompts.
|
| 785 |
+
print(f'[WARNING] Detected more masks ({n}) than prompts ({p}). '
|
| 786 |
+
'Additional masks are ignored until more prompts are provided.')
|
| 787 |
+
|
| 788 |
@torch.no_grad()
|
| 789 |
def update_single_layer(
|
| 790 |
self,
|