| | import numpy as np |
| | import torch |
| |
|
| | from ...utils import is_invisible_watermark_available |
| |
|
| |
|
| | if is_invisible_watermark_available(): |
| | from imwatermark import WatermarkEncoder |
| |
|
| |
|
| | |
# Fixed message embedded by StableDiffusionXLWatermarker (48 bits, top bit set).
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110

# The message unpacked into 0/1 ints, most significant bit first.
# format(..., "b") drops leading zeros; that is harmless here because the top bit is 1.
WATERMARK_BITS = list(map(int, format(WATERMARK_MESSAGE, "b")))
| |
|
| |
|
class StableDiffusionXLWatermarker:
    """Embed the fixed ``WATERMARK_BITS`` invisible watermark into image batches.

    Depends on the optional ``invisible-watermark`` package: ``WatermarkEncoder`` is
    only imported at module level when ``is_invisible_watermark_available()`` is true,
    so constructing this class without the package installed fails with ``NameError``.
    """

    def __init__(self):
        # Configure a single encoder with the bit pattern; it is reused for every image.
        self.watermark = WATERMARK_BITS
        self.encoder = WatermarkEncoder()
        self.encoder.set_watermark("bits", self.watermark)

    def apply_watermark(self, images: torch.Tensor):
        """Return ``images`` with the watermark embedded in every sample.

        Args:
            images: 4-D batch laid out as (batch, channel, height, width), with values
                expected in [-1, 1] (they are mapped to [0, 255] via
                ``255 * (x / 2 + 0.5)``). Channels are assumed to be RGB — TODO confirm
                against callers.

        Returns:
            A CPU tensor in the same (batch, channel, height, width) layout with values
            clamped to [-1, 1]. If the images are too small for the ``dwtDct`` encoder,
            the input tensor itself is returned unchanged.
        """
        height, width = images.shape[-2:]
        # Small images cannot be encoded. Keep the original width cutoff, and also skip
        # batches whose area is below 256x256, which imwatermark's encode() rejects with
        # an error (a width-only check let e.g. 64x512 inputs crash).
        if width < 256 or height * width < 256 * 256:
            return images

        # (B, C, H, W) in [-1, 1]  ->  (B, H, W, C) float32 numpy array in [0, 255].
        images = (255 * (images / 2 + 0.5)).cpu().permute(0, 2, 3, 1).float().numpy()

        # The imwatermark encoder works on OpenCV-style BGR arrays, so flip RGB -> BGR
        # before encoding and BGR -> RGB on the way back out. The contiguous copy avoids
        # handing a negatively-strided view to the encoder.
        images = np.ascontiguousarray(images[..., ::-1])
        encoded = [self.encoder.encode(image, "dwtDct")[..., ::-1] for image in images]

        # np.array copies the reversed (negative-stride) slices into one contiguous
        # batch, which torch.from_numpy requires.
        images = torch.from_numpy(np.array(encoded)).permute(0, 3, 1, 2)

        # Back to [-1, 1]; clamp in case encoding pushed values outside the valid range.
        images = torch.clamp(2 * (images / 255 - 0.5), min=-1.0, max=1.0)
        return images
| |
|