| | import os |
| | import numpy as np |
| | import torch |
| | from torch import nn |
| | from torch.utils.data import Dataset, DataLoader |
| | from torchvision import transforms as T |
| | from PIL import Image as PILImage, ImageDraw, ImageFont |
| | from imwatermark import WatermarkEncoder |
| |
|
| | from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput |
| | from diffusers.utils.torch_utils import randn_tensor |
| | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
| | from typing import List, Optional, Tuple, Union |
| |
|
| | |
# Pick the GPU when one is visible; use fp16 weights there and fp32 on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
torch_dtype = torch.float16 if device.type == "cuda" else torch.float32

# Hugging Face checkpoint shared by the tokenizer and the text encoder below.
model_name = "google/mt5-small"

# Slow (non-"fast") tokenizer; legacy=True keeps the legacy T5 tokenization behaviour.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, legacy=True)

# Seq2seq model whose weights are placed by device_map="auto"; QPipeline below only
# calls its `.encoder`. eval() returns the module itself, so it can be chained here.
encoder_model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    torch_dtype=torch_dtype,
    device_map="auto",
).eval()
| |
|
class QPipeline(DiffusionPipeline):
    """Text-to-image diffusion pipeline conditioned on mT5 encoder states.

    Prompts are tokenized with the module-level ``tokenizer`` and embedded by
    ``encoder_model.encoder``. The resulting hidden states condition ``unet``
    while ``scheduler`` drives the denoising loop. PIL outputs are resized to
    256x256 and carry an invisible ``imwatermark`` watermark.
    """

    def __init__(self, unet, scheduler):
        super().__init__()
        # Registered so DiffusionPipeline can track them (e.g. for ``self.device``).
        self.register_modules(unet=unet, scheduler=scheduler)

    def add_watermark(self, img: PILImage.Image) -> PILImage.Image:
        """Return a 256x256 RGB copy of ``img`` with an invisible watermark embedded.

        The payload is the UTF-8 encoding of the ``WATERMARK_URL`` environment
        variable (default ``"hf.co/lqume/new-hanzi"``). It is embedded with the
        ``dwtDct`` method of ``imwatermark.WatermarkEncoder``.

        Args:
            img: Image to watermark; any PIL mode that can be converted to RGB.

        Returns:
            A new RGB ``PIL.Image.Image`` of size 256x256.
        """
        # Every watermarked output is normalised to 256x256 before encoding.
        img = img.resize((256, 256), resample=PILImage.BICUBIC)

        watermark_str = os.getenv("WATERMARK_URL", "hf.co/lqume/new-hanzi")
        encoder = WatermarkEncoder()
        encoder.set_watermark("bytes", watermark_str.encode("utf-8"))

        # imwatermark follows OpenCV's BGR channel order. Flip RGB -> BGR before
        # encoding so the mark lands in the channels standard decoders read
        # (e.g. on cv2.imread output), then flip back for PIL.
        rgb = np.asarray(img.convert("RGB"))
        bgr = np.ascontiguousarray(rgb[:, :, ::-1])
        watermarked_bgr = encoder.encode(bgr, "dwtDct")
        watermarked_rgb = np.ascontiguousarray(watermarked_bgr[:, :, ::-1])

        return PILImage.fromarray(watermarked_rgb)

    @torch.no_grad()
    def __call__(
        self,
        texts: Union[str, List[str]],
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 20,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple[List[PILImage.Image]]]:
        """Generate one image per prompt.

        Args:
            texts: A prompt string or a list of prompt strings. A bare string is
                treated as a single prompt.
            batch_size: Ignored; kept for backward compatibility. The batch size
                is always the number of prompts.
            generator: Optional torch generator(s). It is forwarded both to
                ``randn_tensor`` for the initial noise and to ``scheduler.step``.
            num_inference_steps: Number of timesteps passed to
                ``scheduler.set_timesteps``.
            output_type: ``"pil"`` returns watermarked 256x256 PIL images. Any
                other value returns a float32 NumPy array laid out as
                (batch, height, width, channels) with values clamped to [0, 1].
            return_dict: If False, return a 1-tuple ``(images,)`` instead of an
                ``ImagePipelineOutput``.

        Returns:
            ``ImagePipelineOutput(images=...)``, or ``(images,)`` when
            ``return_dict`` is False.

        Raises:
            ValueError: If ``texts`` contains no prompts.
        """
        if isinstance(texts, str):
            # Without this, a bare string would size the latent batch by its
            # character count while the tokenizer treats it as one prompt.
            texts = [texts]
        if len(texts) == 0:
            raise ValueError("`texts` must contain at least one prompt.")

        # One image per prompt; the `batch_size` argument is intentionally unused.
        batch_size = len(texts)

        # Every prompt is padded/truncated to exactly 48 tokens.
        tokenized = tokenizer(
            texts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=48,
        )
        input_ids = tokenized["input_ids"].to(device=device, dtype=torch.long)
        attention_mask = tokenized["attention_mask"].to(device=device, dtype=torch.long)

        # Only the encoder stack of the seq2seq model is used for conditioning.
        encoded = encoder_model.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # The text encoder runs on the module-level `device`/`torch_dtype`, while the
        # UNet lives on `self.device` with its own dtype. These can differ, e.g. if
        # the pipeline was never moved to CUDA or the UNet was loaded in fp32.
        # Align the conditioning with the UNet once, outside the denoising loop.
        unet_dtype = self.unet.dtype
        encoder_hidden_states = encoded.last_hidden_state.to(device=self.device, dtype=unet_dtype)
        encoder_attention_mask = attention_mask.to(device=self.device).bool()

        # `sample_size` is either a single int (square) or a (height, width) pair.
        sample_size = self.unet.config.sample_size
        if isinstance(sample_size, int):
            spatial_shape = (sample_size, sample_size)
        else:
            spatial_shape = tuple(sample_size)
        image_shape = (batch_size, self.unet.config.in_channels, *spatial_shape)

        image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=unet_dtype)

        self.scheduler.set_timesteps(num_inference_steps)

        for timestep in self.progress_bar(self.scheduler.timesteps):
            noise_pred = self.unet(
                image,
                timestep,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]

            image = self.scheduler.step(
                noise_pred, timestep, image, generator=generator, return_dict=False
            )[0]

        # The sample is clamped directly to [0, 1] (no [-1, 1] rescale is applied here).
        # Cast to float32 before leaving torch so NumPy/PIL conversion never sees fp16.
        image = image.clamp(0, 1).cpu().permute(0, 2, 3, 1).float().numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)
            image = [self.add_watermark(img) for img in image]

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
| |
|
| |
|