danhtran2mind committed
Commit 9295811 · verified · 1 Parent(s): 9ecb3bf

Delete apps/gradio_app/old-image_generator.py

apps/gradio_app/old-image_generator.py DELETED
@@ -1,77 +0,0 @@
- import torch
- from PIL import Image
- import numpy as np
- from transformers import CLIPTextModel, CLIPTokenizer
- from diffusers import (
-     AutoencoderKL, UNet2DConditionModel,
-     PNDMScheduler, StableDiffusionPipeline
- )
-
- from tqdm import tqdm
- from .config_loader import load_model_configs
-
- def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed,
-                    random_seed, use_lora, finetune_model_id, lora_model_id, base_model_id,
-                    lora_scale, config_path, device, dtype):
-     if not prompt or height % 8 != 0 or width % 8 != 0 or num_inference_steps not in range(1, 101) or \
-        guidance_scale < 1.0 or guidance_scale > 20.0 or seed < 0 or seed > 4294967295 or \
-        (use_lora and (lora_scale < 0.0 or lora_scale > 2.0)):
-         return None, "Invalid input parameters."
-
-     model_configs = load_model_configs(config_path)
-     finetune_model_path = model_configs.get(finetune_model_id, {}).get('local_dir', finetune_model_id)
-     lora_model_path = model_configs.get(lora_model_id, {}).get('local_dir', lora_model_id)
-     base_model_path = model_configs.get(base_model_id, {}).get('local_dir', base_model_id)
-
-     generator = torch.Generator(device=device).manual_seed(torch.randint(0, 4294967295, (1,)).item() if random_seed else int(seed))
-
-     try:
-         if use_lora:
-             # Load base pipeline
-             pipe = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=dtype, use_safetensors=True)
-
-             # Add LoRA weights with specified rank and scale
-             pipe.load_lora_weights(lora_model_path, adapter_name="ghibli-lora",
-                                    lora_scale=lora_scale)
-
-             pipe = pipe.to(device)
-             vae, tokenizer, text_encoder, unet, scheduler = pipe.vae, pipe.tokenizer, pipe.text_encoder, pipe.unet, PNDMScheduler.from_config(pipe.scheduler.config)
-         else:
-             vae = AutoencoderKL.from_pretrained(finetune_model_path, subfolder="vae", torch_dtype=dtype).to(device)
-             tokenizer = CLIPTokenizer.from_pretrained(finetune_model_path, subfolder="tokenizer")
-             text_encoder = CLIPTextModel.from_pretrained(finetune_model_path, subfolder="text_encoder", torch_dtype=dtype).to(device)
-             unet = UNet2DConditionModel.from_pretrained(finetune_model_path, subfolder="unet", torch_dtype=dtype).to(device)
-             scheduler = PNDMScheduler.from_pretrained(finetune_model_path, subfolder="scheduler")
-
-         text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
-         text_embeddings = text_encoder(text_input.input_ids.to(device))[0].to(dtype=dtype)
-
-         uncond_input = tokenizer([""] * 1, padding="max_length", max_length=text_input.input_ids.shape[-1], return_tensors="pt")
-         uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0].to(dtype=dtype)
-         text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-
-         latents = torch.randn((1, unet.config.in_channels, height // 8, width // 8), generator=generator, dtype=dtype, device=device)
-         scheduler.set_timesteps(num_inference_steps)
-         latents = latents * scheduler.init_noise_sigma
-
-         for t in tqdm(scheduler.timesteps, desc="Generating image"):
-             latent_model_input = torch.cat([latents] * 2)
-             latent_model_input = scheduler.scale_model_input(latent_model_input, t)
-             noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-             latents = scheduler.step(noise_pred, t, latents).prev_sample
-
-         image = vae.decode(latents / vae.config.scaling_factor).sample
-         image = (image / 2 + 0.5).clamp(0, 1).detach().cpu().permute(0, 2, 3, 1).numpy()
-         pil_image = Image.fromarray((image[0] * 255).round().astype("uint8"))
-
-         if use_lora:
-             del pipe
-         else:
-             del vae, tokenizer, text_encoder, unet, scheduler
-         torch.cuda.empty_cache()
-
-         return pil_image, f"Generated image successfully! Seed used: {seed}"
-     except Exception as e:
-         return None, f"Failed to generate image: {e}"