import os import random from pathlib import Path import numpy as np import safetensors # import torch import torch.nn.functional as F import transformers from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed from PIL import Image from torch.utils.data import Dataset from torchvision import transforms from tqdm.auto import tqdm from transformers import AutoTokenizer, CLIPTextModel from safetensors.torch import load_file import diffusers # from diffusers.pipelines import BlipDiffusionPipeline from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DiffusionPipeline from diffusers.loaders import AttnProcsLayers from diffusers.models.attention_processor import CustomDiffusionAttnProcessor from diffusers.optimization import get_scheduler from diffusers.utils import load_image import streamlit as st import io import streamlit as st # 用于创建交互式网页UI import io # 处理文件流(后面用来生成下载按钮) # 设置页面标题和布局 st.set_page_config(page_title="Fine-tuning style diffusion", layout="wide") st.title("Fine-tuning style diffusion 推理 Demo") st.write("支持 **A reference.(风格) + 文本*") st.write("只是训练了一个提示词 'A reference.'") st.write("即使用该提示词时以十二生肖为主要元素进行新年图片风格的生成,例如使用一下提示词") st.write("A reference. New Year image with a rabbit as the main element, in a 2D or anime style, and a festive background") device = "cuda" if torch.cuda.is_available() else "cpu" dtype = torch.float16 # ========================== # 模型加载(缓存) # ========================== @st.cache_resource def load_models(): model_path = "./stable-diffusion-v1-5" tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer") text_encoder = CLIPTextModel.from_pretrained( model_path, subfolder="text_encoder", torch_dtype=torch.float16 ).to(device) vae = AutoencoderKL.from_pretrained( model_path, subfolder="vae", torch_dtype=torch.float16 ).to(device) unet = UNet2DConditionModel.from_pretrained( model_path, subfolder="unet", torch_dtype=torch.float16 ).to(device) attn_path = "output/pytorch_custom_diffusion_weights.bin" state_dict = torch.load(attn_path, map_location="cpu") unet.load_attn_procs(state_dict) token_path = "output/learned_embeds.safetensors" try: new_embed = torch.load(token_path) token_id = tokenizer.convert_tokens_to_ids("") text_encoder.get_input_embeddings().weight.data[token_id] = new_embed print("Loaded token embedding") except: print("No trained token found") scheduler = DDPMScheduler.from_pretrained( model_path, subfolder="scheduler" ) unet.enable_xformers_memory_efficient_attention() return tokenizer, text_encoder, vae, unet, scheduler tokenizer, text_encoder, vae, unet, scheduler = load_models() prompt = st.text_input( "Prompt", "A reference." ) # 调整参数 steps = st.slider("Steps", 10, 320, 100) guidance = st.slider("Guidance", 1.0, 18.0, 6.0) # ========================== # 图像预处理 # ========================== def preprocess(image): # 调整图像,转换为tensor(张量)并归一化到[-1,1] transform = transforms.Compose([ transforms.Resize((512,512)), transforms.ToTensor(), transforms.Normalize([0.5],[0.5]) ]) # 增加batch维度 return transform(image).unsqueeze(0) # ========================== # diffusion 推理 # ========================== def generate(prompt): with torch.no_grad(): # 文本向量化 text_input = tokenizer( prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt" ).to(device) text_emb = text_encoder(text_input.input_ids)[0] # 无条件 embedding; uncond_input = tokenizer( "", padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt" ).to(device) uncond_emb = text_encoder(uncond_input.input_ids)[0] text_emb = torch.cat([uncond_emb, text_emb], dim=0) # 初始化噪声潜变量 latents = torch.randn( (1,4,64,64), device=device, dtype=torch.float16 ) # 设置diffusion时间步 scheduler.set_timesteps(steps) # ---------------- # diffusion loop # ---------------- # 采用 for t in scheduler.timesteps: # 为什么要拼接两份 latent_model_input = torch.cat([latents]*2) noise_pred = unet( latent_model_input, t, encoder_hidden_states=text_emb ).sample noise_uncond, noise_text = noise_pred.chunk(2) noise_pred = noise_uncond + guidance * ( noise_text - noise_uncond ) # 调度程序/潜在的 latents = scheduler.step( noise_pred, t, latents ).prev_sample # ---------------- # decode image;解码图像 # ---------------- # 解码生成图像;将latent解码成[0,1]的RGB图像 latents = latents / vae.config.scaling_factor image = vae.decode(latents).sample image = (image/2 + 0.5).clamp(0,1) # 转成numpy数组,再用PIL转成可展示的图像 image = image.cpu().permute(0,2,3,1).numpy()[0] image = (image*255).astype(np.uint8) return Image.fromarray(image) if st.button("Generate"): with st.spinner("Generating..."): image = generate(prompt) st.image(image,caption="Result",width=512) buf = io.BytesIO() image.save(buf,format="PNG") st.download_button( "Download", buf.getvalue(), "result.png" )