major-project / app.py
vish26's picture
import re fix
39f6bab verified
import gradio as gr
import numpy as np
import torch
import re
import os
from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL
from transformers import CLIPTokenizer, CLIPTextModel, CLIPModel
from huggingface_hub import snapshot_download
import torch.nn.functional as F
from torchvision.transforms import transforms
from PIL import Image
# ── Load models ONCE at startup ──────────────────────────────────────────────
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only when a GPU is available; CPU inference stays in float32.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Side length of the square latent grid sampled in generate_image
# (the checkpoint repo name suggests 128x128 output images).
LATENT_SIZE = 16

print("Downloading model weights...")
model_path = snapshot_download(
    repo_id="vish26/latent-diffusion-model-128x128-batch8-lr-1e-5",
    repo_type="model",
)
model_path = os.path.join(model_path, "epoch_7")

print("Loading models...")
unet = UNet2DConditionModel.from_pretrained(
    os.path.join(model_path, "unet"),
    use_safetensors=True,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
).to(device)
text_encoder = CLIPTextModel.from_pretrained(
    "openai/clip-vit-base-patch32", torch_dtype=torch_dtype
).to(device)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
vae = AutoencoderKL.from_pretrained(
    "stabilityai/sd-vae-ft-mse", torch_dtype=torch_dtype
).to(device)
# NOTE(review): `clip` is only needed for CLIP-score evaluation, which is
# currently disabled in generate_image; it still occupies device memory.
clip = CLIPModel.from_pretrained(
    "openai/clip-vit-base-patch32", torch_dtype=torch_dtype
).to(device)
scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear")

# Load the fine-tuned UNet weights from the training checkpoint.
# Mapped to CPU on purpose: load_state_dict copies each tensor onto the
# model's own device/dtype, so there is no need to stage the whole file on
# the GPU first.
# SECURITY NOTE: torch.load can unpickle arbitrary objects (weights_only
# defaults to False before torch 2.6); only load checkpoints from a trusted repo.
checkpoint = torch.load(
    os.path.join(model_path, "training_state.pth"), map_location="cpu"
)
unet.load_state_dict(checkpoint["model_state_dict"])
print(f"Model loaded! Last loss: {checkpoint['loss']}")
# The checkpoint tensors duplicate weights now held by `unet`; release them
# instead of keeping a second copy alive for the app's lifetime.
del checkpoint

# Inference only: switch every model to eval mode.
for _model in (unet, text_encoder, vae, clip):
    _model.eval()
print("All models ready!")
def is_valid_prompt(prompt):
    """Return True if *prompt* contains at least one letter or digit.

    Rejects ``None``, empty, whitespace-only and punctuation/symbol-only
    prompts. Letters and digits are recognised in any script (``str.isalnum``),
    so a prompt such as "猫" or "café" is accepted. The original ASCII-only
    regex rejected such prompts even though they are not "special characters
    only". Underscores do not count as letters.
    """
    if not prompt:
        return False
    return any(ch.isalnum() for ch in prompt)
# ── Generation function ───────────────────────────────────────────────────────
def _coerce_steps(num_inference_steps):
    """Return the step count as an int within the scheduler's supported range.

    The value may arrive as a float from Gradio or API callers, but
    DDPMScheduler.set_timesteps documents an integer count. It also fails on
    0 (division by zero) and on values above num_train_timesteps.

    Raises:
        gr.Error: the value is non-numeric or out of range.
    """
    try:
        steps = int(num_inference_steps)
    except (TypeError, ValueError, OverflowError):
        raise gr.Error("Inference steps must be a number.") from None
    max_steps = scheduler.config.num_train_timesteps
    if not 1 <= steps <= max_steps:
        raise gr.Error(f"Inference steps must be between 1 and {max_steps}.")
    return steps


def _encode_prompt(prompt):
    """Return CLIP text embeddings for ``[""]`` and *prompt* concatenated on dim 0.

    The unconditional (empty-string) embedding comes first; _denoise relies on
    that order when it splits the UNet output with ``chunk(2)``. Callers are
    expected to run this under ``torch.no_grad()``.
    """
    embeddings = []
    for text in ([""], prompt):
        ids = tokenizer(
            text, padding="max_length", max_length=77,
            truncation=True, return_tensors="pt",
        ).to(device).input_ids
        embeddings.append(text_encoder(ids)[0])
    return torch.cat(embeddings)


def _denoise(text_embeddings, num_inference_steps, guidance_scale):
    """Run the classifier-free-guided DDPM sampling loop and return the final latents."""
    # Sample on CPU, then move and cast to the model dtype.
    # The 4 latent channels match the SD VAE's latent space.
    latents = torch.randn((1, 4, LATENT_SIZE, LATENT_SIZE)).to(
        device=device, dtype=torch_dtype
    )
    latents = latents * scheduler.init_noise_sigma

    # NOTE(review): set_timesteps mutates the shared module-level scheduler,
    # so overlapping calls would interfere; this assumes one generation runs
    # at a time.
    scheduler.set_timesteps(num_inference_steps)
    for t in scheduler.timesteps:
        # Batch the unconditional and conditional passes into one UNet call.
        model_input = scheduler.scale_model_input(torch.cat([latents] * 2), t)
        noise_pred = unet(model_input, t, encoder_hidden_states=text_embeddings).sample
        noise_uncond, noise_text = noise_pred.chunk(2)
        noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
        latents = scheduler.step(noise_pred, t, latents).prev_sample
    return latents


def _latents_to_pil(latents):
    """Decode a single latent with the VAE and return it as an 8-bit PIL image."""
    # Undo the SD latent scaling factor (0.18215) before decoding.
    image = vae.decode((1 / 0.18215) * latents).sample
    # Map the VAE output from roughly [-1, 1] to [0, 1].
    image = (image / 2 + 0.5).clamp(0, 1)
    # Cast back to float32 on CPU for NumPy/PIL; NCHW -> HWC for the one image.
    pixels = image.cpu().float().permute(0, 2, 3, 1).squeeze(0).numpy()
    return Image.fromarray((pixels * 255).round().astype("uint8"))


def generate_image(prompt, num_inference_steps=500, guidance_scale=7.5):
    """Generate one image for *prompt* using classifier-free guidance.

    Args:
        prompt: Text prompt; must pass is_valid_prompt.
        num_inference_steps: Number of DDPM denoising steps, coerced to int and
            required to lie in 1..num_train_timesteps.
        guidance_scale: Classifier-free guidance weight; 1.0 uses the
            conditional noise prediction alone.

    Returns:
        A PIL image decoded from a LATENT_SIZE x LATENT_SIZE latent.

    Raises:
        gr.Error: the prompt or the step count is invalid.
    """
    if not is_valid_prompt(prompt):
        raise gr.Error("Please enter a valid prompt (not empty or special characters only).")
    steps = _coerce_steps(num_inference_steps)

    # One no-grad scope for encoding, sampling and decoding.
    with torch.no_grad():
        text_embeddings = _encode_prompt(prompt)
        latents = _denoise(text_embeddings, steps, guidance_scale)
        image = _latents_to_pil(latents)

    # CLIP-score evaluation is disabled: the `clip` model loaded at startup is
    # not used here, and only the image is returned.
    return image
# ── Gradio UI ─────────────────────────────────────────────────────────────────
# Layout: inputs (prompt, step count, guidance) on the left, output image on the right.
with gr.Blocks() as demo:
    # Title text repaired from mis-decoded UTF-8 ("πŸ–ΌοΈ" / "β€"").
    gr.Markdown("# 🖼️ Latent Diffusion Model — Text to Image")
    with gr.Row():
        with gr.Column():
            prompt = gr.Text(label="Prompt", placeholder="e.g. people walking on street")
            steps = gr.Slider(label="Inference Steps", minimum=100, maximum=1000, step=50, value=500)
            guidance = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=15.0, step=0.5, value=7.5)
            generate_button = gr.Button("Generate", variant="primary")
        with gr.Column():
            output_image = gr.Image(label="Generated Image")

    # Inputs are passed positionally to
    # generate_image(prompt, num_inference_steps, guidance_scale).
    generate_button.click(
        fn=generate_image,
        inputs=[prompt, steps, guidance],
        outputs=[output_image],
    )

    # Clicking an example only fills the inputs; it does not run generation.
    gr.Examples(
        examples=[
            ["people walking on street", 500, 7.5],
            ["a dog playing in the park", 500, 7.5],
            ["sunset over mountains", 500, 7.5],
        ],
        inputs=[prompt, steps, guidance],
    )
def _main():
    """Start the Gradio server hosting the demo UI."""
    demo.launch()


if __name__ == "__main__":
    _main()