# neochar/utils.py — author: lqume (commit ba80d6e, verified)
# Fixed tokenizer bug and gibberish image generation.
# Uses AutoTokenizer instead of MT5Tokenizer.
import os
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from PIL import Image as PILImage, ImageDraw, ImageFont
from imwatermark import WatermarkEncoder
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from diffusers.utils.torch_utils import randn_tensor
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from typing import List, Optional, Tuple, Union
# Device and precision selection: half precision on GPU, full on CPU.
_has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _has_cuda else "cpu")
torch_dtype = torch.float16 if _has_cuda else torch.float32

model_name = "google/mt5-small"  # or base / large / etc.

# Slow SentencePiece tokenizer in legacy mode: sidesteps subtle fast/slow
# tokenizer differences and the "new vs legacy" mismatch warnings.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, legacy=True)

# Frozen mT5 text encoder used to condition the diffusion model.
encoder_model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    torch_dtype=torch_dtype,  # float32 / float16 / bfloat16 as desired
    device_map="auto",        # let accelerate place the weights (or move manually)
)
encoder_model.eval()
class QPipeline(DiffusionPipeline):
    """Text-conditioned diffusion pipeline that generates glyph images.

    Prompts are encoded with the module-level mT5 encoder and fed to the
    UNet as cross-attention conditioning; PIL outputs are upscaled and
    invisibly watermarked before being returned.
    """

    def __init__(self, unet, scheduler):
        super().__init__()
        # Registration lets DiffusionPipeline handle save/load/device moves.
        self.register_modules(unet=unet, scheduler=scheduler)

    def add_watermark(self, img: PILImage.Image) -> PILImage.Image:
        """Upscale ``img`` to 256x256 and embed an invisible DWT-DCT watermark.

        The watermark payload comes from the WATERMARK_URL environment
        variable (default "hf.co/lqume/new-hanzi"). Returns a new PIL image.
        """
        # Resize image to 256, as 128 is too small for watermark
        img = img.resize((256, 256), resample=PILImage.BICUBIC)
        watermark_str = os.getenv("WATERMARK_URL", "hf.co/lqume/new-hanzi")
        encoder = WatermarkEncoder()
        encoder.set_watermark('bytes', watermark_str.encode('utf-8'))
        # Convert PIL image to NumPy array, ensuring 3-channel RGB.
        # NOTE(review): imwatermark follows OpenCV's BGR convention; an RGB
        # array still round-trips as long as decoding uses the same order.
        img_np = np.asarray(img.convert("RGB"))
        watermarked_np = encoder.encode(img_np, 'dwtDct')
        # Convert back to PIL
        return PILImage.fromarray(watermarked_np)

    @torch.no_grad()
    def __call__(
        self,
        texts: List[str],
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 20,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple[List[PILImage.Image]]]:
        """Generate one image per prompt in ``texts``.

        Args:
            texts: Prompts to condition on; one image is produced per prompt.
            batch_size: Ignored — kept only for backward compatibility; the
                effective batch size is always ``len(texts)``.
            generator: RNG(s) for the initial noise and scheduler steps.
            num_inference_steps: Number of denoising iterations.
            output_type: "pil" for watermarked PIL images; anything else
                returns a raw (B, H, W, C) float numpy array (no watermark).
            return_dict: Return an ``ImagePipelineOutput`` instead of a tuple.

        Raises:
            ValueError: If ``texts`` is empty.
        """
        if not texts:
            raise ValueError("`texts` must contain at least one prompt.")
        # The prompts drive the effective batch size; the `batch_size`
        # argument is retained only so existing callers keep working.
        batch_size = len(texts)

        # Tokenize input text (fixed-length padding so masks are uniform).
        tokenized = tokenizer(
            texts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=48
        )
        input_ids = tokenized["input_ids"].to(device=device, dtype=torch.long)
        attention_mask = tokenized["attention_mask"].to(device=device, dtype=torch.long)
        # Encode prompts with the frozen mT5 encoder.
        encoded = encoder_model.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # Prepare the initial Gaussian latents shaped for the UNet.
        if isinstance(self.unet.config.sample_size, int):
            image_shape = (
                batch_size,
                self.unet.config.in_channels,
                self.unet.config.sample_size,
                self.unet.config.sample_size,
            )
        else:
            image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)
        image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=torch_dtype)

        # BUGFIX: the text encoder is placed via device_map="auto" and may
        # live on a different device than the pipeline; align the
        # conditioning tensors with the latents before the denoising loop
        # (a no-op when both already share device and dtype).
        hidden_states = encoded.last_hidden_state.to(device=image.device, dtype=image.dtype)
        encoder_mask = attention_mask.to(image.device).bool()

        # Run the denoising loop.
        self.scheduler.set_timesteps(num_inference_steps)
        for timestep in self.progress_bar(self.scheduler.timesteps):
            noise_pred = self.unet(
                image,
                timestep,
                encoder_hidden_states=hidden_states,
                encoder_attention_mask=encoder_mask,
                return_dict=False
            )[0]
            image = self.scheduler.step(noise_pred, timestep, image, generator=generator, return_dict=False)[0]

        # Final post-processing: clamp to [0, 1], channels last, numpy.
        # NOTE(review): assumes the model emits images in [0, 1]; a model
        # trained on [-1, 1] would need (image / 2 + 0.5) here — confirm.
        image = image.clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)
            image = [self.add_watermark(img) for img in image]
        if not return_dict:
            return (image,)
        return ImagePipelineOutput(images=image)