Delete apps/gradio_app/old-image_generator.py
apps/gradio_app/old-image_generator.py
DELETED
@@ -1,77 +0,0 @@
-import torch
-from PIL import Image
-import numpy as np
-from transformers import CLIPTextModel, CLIPTokenizer
-from diffusers import (
-    AutoencoderKL, UNet2DConditionModel,
-    PNDMScheduler, StableDiffusionPipeline
-)
-
-from tqdm import tqdm
-from .config_loader import load_model_configs
-
-def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed,
-                   random_seed, use_lora, finetune_model_id, lora_model_id, base_model_id,
-                   lora_scale, config_path, device, dtype):
-    if not prompt or height % 8 != 0 or width % 8 != 0 or num_inference_steps not in range(1, 101) or \
-            guidance_scale < 1.0 or guidance_scale > 20.0 or seed < 0 or seed > 4294967295 or \
-            (use_lora and (lora_scale < 0.0 or lora_scale > 2.0)):
-        return None, "Invalid input parameters."
-
-    model_configs = load_model_configs(config_path)
-    finetune_model_path = model_configs.get(finetune_model_id, {}).get('local_dir', finetune_model_id)
-    lora_model_path = model_configs.get(lora_model_id, {}).get('local_dir', lora_model_id)
-    base_model_path = model_configs.get(base_model_id, {}).get('local_dir', base_model_id)
-
-    generator = torch.Generator(device=device).manual_seed(torch.randint(0, 4294967295, (1,)).item() if random_seed else int(seed))
-
-    try:
-        if use_lora:
-            # Load base pipeline
-            pipe = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=dtype, use_safetensors=True)
-
-            # Add LoRA weights with specified rank and scale
-            pipe.load_lora_weights(lora_model_path, adapter_name="ghibli-lora",
-                                   lora_scale=lora_scale)
-
-            pipe = pipe.to(device)
-            vae, tokenizer, text_encoder, unet, scheduler = pipe.vae, pipe.tokenizer, pipe.text_encoder, pipe.unet, PNDMScheduler.from_config(pipe.scheduler.config)
-        else:
-            vae = AutoencoderKL.from_pretrained(finetune_model_path, subfolder="vae", torch_dtype=dtype).to(device)
-            tokenizer = CLIPTokenizer.from_pretrained(finetune_model_path, subfolder="tokenizer")
-            text_encoder = CLIPTextModel.from_pretrained(finetune_model_path, subfolder="text_encoder", torch_dtype=dtype).to(device)
-            unet = UNet2DConditionModel.from_pretrained(finetune_model_path, subfolder="unet", torch_dtype=dtype).to(device)
-            scheduler = PNDMScheduler.from_pretrained(finetune_model_path, subfolder="scheduler")
-
-        text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
-        text_embeddings = text_encoder(text_input.input_ids.to(device))[0].to(dtype=dtype)
-
-        uncond_input = tokenizer([""] * 1, padding="max_length", max_length=text_input.input_ids.shape[-1], return_tensors="pt")
-        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0].to(dtype=dtype)
-        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-
-        latents = torch.randn((1, unet.config.in_channels, height // 8, width // 8), generator=generator, dtype=dtype, device=device)
-        scheduler.set_timesteps(num_inference_steps)
-        latents = latents * scheduler.init_noise_sigma
-
-        for t in tqdm(scheduler.timesteps, desc="Generating image"):
-            latent_model_input = torch.cat([latents] * 2)
-            latent_model_input = scheduler.scale_model_input(latent_model_input, t)
-            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-            latents = scheduler.step(noise_pred, t, latents).prev_sample
-
-        image = vae.decode(latents / vae.config.scaling_factor).sample
-        image = (image / 2 + 0.5).clamp(0, 1).detach().cpu().permute(0, 2, 3, 1).numpy()
-        pil_image = Image.fromarray((image[0] * 255).round().astype("uint8"))
-
-        if use_lora:
-            del pipe
-        else:
-            del vae, tokenizer, text_encoder, unet, scheduler
-        torch.cuda.empty_cache()
-
-        return pil_image, f"Generated image successfully! Seed used: {seed}"
-    except Exception as e:
-        return None, f"Failed to generate image: {e}"
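For context, the deleted module exposed a single generate_image helper that ran a manual Stable Diffusion denoising loop with classifier-free guidance, combining the unconditional and text-conditioned noise predictions as uncond + guidance_scale * (text - uncond). Below is a minimal sketch of how it was typically invoked; the import path, prompt, model IDs, and config path are hypothetical placeholders, not values taken from this repository:

import torch
# Hypothetical import of the removed helper; adjust to its replacement module.
# from apps.gradio_app.image_generator import generate_image

use_cuda = torch.cuda.is_available()
image, status = generate_image(
    prompt="a quiet hillside town, ghibli style",     # placeholder prompt
    height=512, width=512,                            # must be multiples of 8
    num_inference_steps=50,                           # validated to 1..100
    guidance_scale=7.5,                               # validated to 1.0..20.0
    seed=42, random_seed=False,                       # fixed seed for reproducibility
    use_lora=True,
    finetune_model_id="finetuned-sd-id",              # hypothetical config key
    lora_model_id="ghibli-lora-id",                   # hypothetical config key
    base_model_id="runwayml/stable-diffusion-v1-5",   # assumed SD 1.x base
    lora_scale=0.8,                                   # validated to 0.0..2.0
    config_path="configs/models.yaml",                # hypothetical path
    device="cuda" if use_cuda else "cpu",
    dtype=torch.float16 if use_cuda else torch.float32,
)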