Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import re | |
| import os | |
| from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL | |
| from transformers import CLIPTokenizer, CLIPTextModel, CLIPModel | |
| from huggingface_hub import snapshot_download | |
| import torch.nn.functional as F | |
| from torchvision.transforms import transforms | |
| from PIL import Image | |
# ── Load models ONCE at startup ──────────────────────────────────────────────
# Everything in this section runs at import time, so every request handled by
# generate_image reuses the same in-memory models.

device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision when CUDA is available, float32 otherwise.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Side length of the square latent tensor sampled in generate_image.
# NOTE(review): presumably 128 px output / 8x VAE downsampling -- confirm.
LATENT_SIZE = 16

print("Downloading model weights...")
model_path = snapshot_download(
    repo_id="vish26/latent-diffusion-model-128x128-batch8-lr-1e-5",
    repo_type="model",
)
# Only the "epoch_7" sub-folder of the downloaded snapshot is used.
model_path = os.path.join(model_path, "epoch_7")

print("Loading models...")
unet = UNet2DConditionModel.from_pretrained(
    os.path.join(model_path, "unet"),
    use_safetensors=True,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
).to(device)
text_encoder = CLIPTextModel.from_pretrained(
    "openai/clip-vit-base-patch32", torch_dtype=torch_dtype
).to(device)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
vae = AutoencoderKL.from_pretrained(
    "stabilityai/sd-vae-ft-mse", torch_dtype=torch_dtype
).to(device)
# NOTE(review): `clip` is only referenced by the disabled CLIP-score code; it
# still occupies device memory while that feature is switched off.
clip = CLIPModel.from_pretrained(
    "openai/clip-vit-base-patch32", torch_dtype=torch_dtype
).to(device)
scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear")

# Load training state.
# The checkpoint is loaded onto the CPU: load_state_dict copies the tensors
# into the UNet's own parameters, so placing the whole training state on the
# accelerator would only waste device memory.
# NOTE(review): torch.load unpickles the file. The weights come from a remote
# repository, so consider weights_only=True if the checkpoint contents allow it.
checkpoint = torch.load(
    os.path.join(model_path, "training_state.pth"), map_location="cpu"
)
unet.load_state_dict(checkpoint['model_state_dict'])
print(f"Model loaded! Last loss: {checkpoint['loss']}")
# Nothing else reads the checkpoint; drop it so its tensors can be freed.
del checkpoint

# Inference only: put every model in eval mode.
unet.eval()
text_encoder.eval()
vae.eval()
clip.eval()
print("All models ready!")
def is_valid_prompt(prompt):
    """Return True when *prompt* is usable as a text prompt.

    A prompt is rejected when it is falsy (``None``, empty string), consists
    only of whitespace, or contains no ASCII letter or digit.
    """
    is_blank = (not prompt) or (not prompt.strip())
    if is_blank:
        return False
    # At least one ASCII letter or digit must be present.
    return re.search(r"[a-zA-Z0-9]", prompt) is not None
# ── Generation function ──────────────────────────────────────────────────────
def _encode_prompt(prompt):
    """Return text-encoder embeddings for classifier-free guidance.

    The result stacks the embedding of the empty prompt first and the
    embedding of *prompt* second, matching the order in which
    ``_denoise`` splits the UNet output. Call under ``torch.no_grad()``.
    """
    def embed(text):
        # Pad/truncate to 77 tokens so both embeddings have the same length.
        tokens = tokenizer(
            text, padding="max_length", max_length=77,
            truncation=True, return_tensors="pt",
        ).to(device)
        return text_encoder(tokens.input_ids)[0]

    uncond_embeddings = embed([""])
    text_embeddings = embed(prompt)
    return torch.cat([uncond_embeddings, text_embeddings])


def _denoise(latents, text_embeddings, guidance_scale):
    """Run the scheduler's denoising loop and return the final latents.

    Expects ``scheduler.set_timesteps`` to have been called already and
    *text_embeddings* to come from ``_encode_prompt``. Call under
    ``torch.no_grad()``.
    """
    for t in scheduler.timesteps:
        # Duplicate the latents so one UNet pass yields both the unconditional
        # and the text-conditioned noise prediction.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)
        noise_pred = unet(
            latent_model_input, t,
            encoder_hidden_states=text_embeddings,
        ).sample
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        # Classifier-free guidance: move away from the unconditional
        # prediction towards the text-conditioned one.
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = scheduler.step(noise_pred, t, latents).prev_sample
    return latents


def _latents_to_image(latents):
    """Decode *latents* with the VAE and return the first sample as a PIL image.

    Call under ``torch.no_grad()``.
    """
    # NOTE(review): 0.18215 is presumably the latent scaling factor applied
    # when the UNet was trained -- confirm against the training code.
    latents = 1 / 0.18215 * latents
    decoded = vae.decode(latents).sample
    # Map from [-1, 1] to [0, 1] before converting to 8-bit pixels.
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    decoded = decoded.detach().cpu().float()  # back to float32 for PIL
    pixels = decoded.permute(0, 2, 3, 1).squeeze(0).numpy()
    pixels = (pixels * 255).round().astype("uint8")
    return Image.fromarray(pixels)


def generate_image(prompt, num_inference_steps=500, guidance_scale=7.5):
    """Generate an image for *prompt* with the module-level diffusion models.

    Args:
        prompt: Text prompt; must pass ``is_valid_prompt``.
        num_inference_steps: Number of scheduler timesteps to run. Coerced to
            ``int`` because UI widgets may deliver a float.
        guidance_scale: Weight of the text-conditioned prediction in
            classifier-free guidance.

    Returns:
        A ``PIL.Image.Image`` decoded from the final latents.

    Raises:
        gr.Error: If the prompt is invalid or fewer than one step is requested.
    """
    if not is_valid_prompt(prompt):
        raise gr.Error("Please enter a valid prompt (not empty or special characters only).")

    num_inference_steps = int(num_inference_steps)
    if num_inference_steps < 1:
        raise gr.Error("Inference steps must be at least 1.")

    # One no_grad scope for the whole request instead of re-entering it on
    # every denoising step.
    with torch.no_grad():
        text_embeddings = _encode_prompt(prompt)

        # Sample the initial noise directly on the target device and dtype.
        latents = torch.randn(
            (1, 4, LATENT_SIZE, LATENT_SIZE), device=device, dtype=torch_dtype
        )
        latents = latents * scheduler.init_noise_sigma

        scheduler.set_timesteps(num_inference_steps)
        latents = _denoise(latents, text_embeddings, guidance_scale)
        image = _latents_to_image(latents)

    # CLIP-score computation (image/text similarity via the `clip` model) is
    # currently disabled; only the generated image is returned.
    return image
# ── Gradio UI ────────────────────────────────────────────────────────────────
# NOTE(review): the nesting of the layout containers below was reconstructed
# (controls in the left column, output in the right one) -- confirm it matches
# the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("# 🖼️ Latent Diffusion Model — Text to Image")
    with gr.Row():
        with gr.Column():
            prompt = gr.Text(label="Prompt", placeholder="e.g. people walking on street")
            steps = gr.Slider(label="Inference Steps", minimum=100, maximum=1000, step=50, value=500)
            guidance = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=15.0, step=0.5, value=7.5)
            generate_button = gr.Button("Generate", variant="primary")
        with gr.Column():
            output_image = gr.Image(label="Generated Image")
            # A "CLIP Score" text output is disabled together with the
            # CLIP-score computation in generate_image.

    # The widgets are passed positionally to
    # generate_image(prompt, num_inference_steps, guidance_scale).
    generate_button.click(
        fn=generate_image,
        inputs=[prompt, steps, guidance],
        outputs=[output_image],
    )

    # Each example row is [prompt, inference steps, guidance scale].
    gr.Examples(
        examples=[
            ["people walking on street", 500, 7.5],
            ["a dog playing in the park", 500, 7.5],
            ["sunset over mountains", 500, 7.5],
        ],
        inputs=[prompt, steps, guidance],
    )

if __name__ == "__main__":
    demo.launch()