# RandomDiffusion / app.py
# Author: lazerkat — commit d4f89b8 ("Update app.py", verified)
import gradio as gr
import os
import urllib.request
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import numpy as np
import json
class TextEncoder(nn.Module):
    """Encode a batch of token-id sequences into one fixed-size vector each.

    Tokens are embedded, run through a single-layer bidirectional LSTM, and the
    final hidden states of the two directions are concatenated and projected
    down to ``hidden_dim``.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Width of each token embedding.
        hidden_dim: Hidden size of each LSTM direction and width of the output.
    """

    def __init__(self, vocab_size, embed_dim=256, hidden_dim=512):
        super().__init__()
        # Id 0 is the padding token (tokenize_text pads with 0); padding_idx
        # keeps its embedding out of gradient updates.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(2 * hidden_dim, hidden_dim)

    def forward(self, x):
        """Map token ids ``x`` of shape (batch, seq) to a (batch, hidden_dim) embedding.

        Sequences are not packed, so padded positions are fed through the LSTM too.
        """
        _, (final_hidden, _) = self.lstm(self.embedding(x))
        # For a one-layer bidirectional LSTM the last two rows of h_n are the
        # forward and the backward direction's final states.
        forward_state, backward_state = final_hidden[-2], final_hidden[-1]
        return self.fc(torch.cat((forward_state, backward_state), dim=1))
class DownBlock(nn.Module):
    """Encoder stage: two 3x3 conv/BatchNorm/ReLU layers conditioned on time and text.

    The time and text embeddings are each projected to ``out_channels`` and
    added (broadcast over height and width) right after the first
    normalisation. ``forward`` returns the full-resolution features, used as
    the skip connection, together with a 2x max-pooled copy.
    """

    def __init__(self, in_channels, out_channels, time_emb_dim=256, text_emb_dim=512):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.norm2 = nn.BatchNorm2d(out_channels)
        # Attribute names and the Sequential layout (Linear layers at indices 0
        # and 2) form the checkpoint keys, so they must not change.
        self.time_mlp = self._projection(time_emb_dim, out_channels)
        self.text_mlp = self._projection(text_emb_dim, out_channels)
        self.pool = nn.MaxPool2d(2)

    @staticmethod
    def _projection(in_dim, out_dim):
        """Build the Linear -> SiLU -> Linear MLP that turns an embedding into per-channel offsets."""
        return nn.Sequential(nn.Linear(in_dim, out_dim), nn.SiLU(), nn.Linear(out_dim, out_dim))

    def forward(self, x, t_emb, text_emb):
        """Return ``(features, pooled_features)``.

        ``t_emb``/``text_emb`` are presumably (batch, dim) — TODO confirm.
        """
        time_bias = self.time_mlp(t_emb)[..., None, None]
        text_bias = self.text_mlp(text_emb)[..., None, None]
        hidden = F.relu(self.norm1(self.conv1(x)) + time_bias + text_bias)
        hidden = F.relu(self.norm2(self.conv2(hidden)))
        return hidden, self.pool(hidden)
class UpBlock(nn.Module):
    """Decoder stage: 2x bilinear upsample, concatenate the skip, then two conditioned conv layers.

    Mirrors ``DownBlock``: after the first convolution and normalisation, the
    projected time and text embeddings are added as per-channel offsets.

    Args:
        in_channels: Channels of the incoming (lower-resolution) features.
        skip_channels: Channels of the skip connection concatenated after upsampling.
        out_channels: Channels produced by this stage.
        time_emb_dim: Width of the time embedding.
        text_emb_dim: Width of the text embedding.
    """

    def __init__(self, in_channels, skip_channels, out_channels, time_emb_dim=256, text_emb_dim=512):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv1 = nn.Conv2d(in_channels + skip_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.norm2 = nn.BatchNorm2d(out_channels)
        # Names/layout are part of the checkpoint key schema; keep them stable.
        self.time_mlp = self._projection(time_emb_dim, out_channels)
        self.text_mlp = self._projection(text_emb_dim, out_channels)

    @staticmethod
    def _projection(in_dim, out_dim):
        """Build the Linear -> SiLU -> Linear MLP that turns an embedding into per-channel offsets."""
        return nn.Sequential(nn.Linear(in_dim, out_dim), nn.SiLU(), nn.Linear(out_dim, out_dim))

    def forward(self, x, skip, t_emb, text_emb):
        """Upsample ``x``, merge it with ``skip`` along channels and return the refined features.

        ``skip`` must match the upsampled spatial size of ``x``.
        """
        merged = torch.cat([self.up(x), skip], dim=1)
        time_bias = self.time_mlp(t_emb)[..., None, None]
        text_bias = self.text_mlp(text_emb)[..., None, None]
        hidden = F.relu(self.norm1(self.conv1(merged)) + time_bias + text_bias)
        return F.relu(self.norm2(self.conv2(hidden)))
class DiffusionUNet(nn.Module):
    """Text-conditioned two-level UNet that predicts the noise in ``x``.

    Layout: 3x3 input conv -> two ``DownBlock`` stages -> a conditioned
    bottleneck at 1/4 resolution -> two ``UpBlock`` stages with skip
    connections -> 1x1 output conv back to ``image_channels``. Because of the
    two 2x poolings and 2x upsamplings, the input height and width need to be
    divisible by 4 for the skip concatenations to line up.

    Args:
        vocab_size: Vocabulary size for the internal ``TextEncoder``.
        image_channels: Channels of the input/output image.
        base_channels: Width of the first stage; the deeper stage uses twice this.
        time_emb_dim: Width of the timestep embedding.
        text_emb_dim: Width of the text embedding.
    """

    def __init__(self, vocab_size, image_channels=3, base_channels=64, time_emb_dim=256, text_emb_dim=512):
        super().__init__()
        wide = base_channels * 2
        # Submodules are created in the same order as before, so seeded
        # initialisation and checkpoint keys are unchanged.
        self.text_encoder = TextEncoder(vocab_size, embed_dim=256, hidden_dim=text_emb_dim)
        # The raw timestep is fed as a single scalar feature (no sinusoidal encoding).
        self.time_mlp = nn.Sequential(
            nn.Linear(1, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )
        self.init_conv = nn.Conv2d(image_channels, base_channels, 3, padding=1)
        self.down1 = DownBlock(base_channels, base_channels, time_emb_dim, text_emb_dim)
        self.down2 = DownBlock(base_channels, wide, time_emb_dim, text_emb_dim)
        self.bottleneck_conv1 = nn.Conv2d(wide, wide, 3, padding=1)
        self.bottleneck_conv2 = nn.Conv2d(wide, wide, 3, padding=1)
        self.bottleneck_norm1 = nn.BatchNorm2d(wide)
        self.bottleneck_norm2 = nn.BatchNorm2d(wide)
        self.bottleneck_time_mlp = nn.Sequential(nn.Linear(time_emb_dim, wide), nn.SiLU(), nn.Linear(wide, wide))
        self.bottleneck_text_mlp = nn.Sequential(nn.Linear(text_emb_dim, wide), nn.SiLU(), nn.Linear(wide, wide))
        self.up1 = UpBlock(wide, wide, base_channels, time_emb_dim, text_emb_dim)
        self.up2 = UpBlock(base_channels, base_channels, base_channels, time_emb_dim, text_emb_dim)
        self.out_conv = nn.Conv2d(base_channels, image_channels, 1)

    def _bottleneck(self, h, t_emb, text_emb):
        """Apply the lowest-resolution conditioned conv/BatchNorm/ReLU pair."""
        time_bias = self.bottleneck_time_mlp(t_emb)[..., None, None]
        text_bias = self.bottleneck_text_mlp(text_emb)[..., None, None]
        h = F.relu(self.bottleneck_norm1(self.bottleneck_conv1(h)) + time_bias + text_bias)
        return F.relu(self.bottleneck_norm2(self.bottleneck_conv2(h)))

    def forward(self, x, timesteps, text_tokens):
        """Predict the noise for images ``x`` at ``timesteps`` given ``text_tokens``.

        ``timesteps`` is a 1-D tensor with one entry per batch item; it is
        unsqueezed to (batch, 1) and cast to float before the time MLP.
        """
        text_emb = self.text_encoder(text_tokens)
        t_emb = self.time_mlp(timesteps.unsqueeze(-1).float())

        features = self.init_conv(x)
        skip_high, pooled = self.down1(features, t_emb, text_emb)
        skip_low, pooled = self.down2(pooled, t_emb, text_emb)

        decoded = self._bottleneck(pooled, t_emb, text_emb)
        decoded = self.up1(decoded, skip_low, t_emb, text_emb)
        decoded = self.up2(decoded, skip_high, t_emb, text_emb)
        return self.out_conv(decoded)
class Diffusion:
    """Linear-beta DDPM noise schedule with an ancestral (stochastic) sampler.

    Args:
        timesteps: Length T of the noise schedule.
        beta_start: First beta of the linear schedule.
        beta_end: Last beta of the linear schedule.
        device: Device holding the schedule tensors and the sampled images.
    """

    def __init__(self, timesteps=1000, beta_start=1e-4, beta_end=0.02, device='cuda'):
        self.timesteps = timesteps
        self.device = device
        self.betas = torch.linspace(beta_start, beta_end, timesteps).to(device)
        self.alphas = 1 - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

    @torch.no_grad()
    def sample(self, model, text_tokens, image_size=64, steps=None, progress_callback=None):
        """Generate one image by running the reverse diffusion process.

        The model is switched to eval mode for sampling. Its previous
        train/eval mode is restored afterwards, even if sampling raises,
        instead of being forced into train mode.

        Args:
            model: Noise predictor called as ``model(x, t, text_tokens)``.
            text_tokens: Token ids passed unchanged to ``model``.
            image_size: Height and width of the generated image.
            steps: Number of reverse steps; defaults to ``self.timesteps``.
                The loop visits t = steps-1 ... 0 without re-spacing, so a
                value below ``self.timesteps`` starts from pure noise at a
                lower noise level than the schedule's final step.
            progress_callback: Optional callable receiving the completed
                fraction (0, 1] after each step.

        Returns:
            Tensor of shape (1, 3, image_size, image_size).

        Raises:
            ValueError: If ``steps`` exceeds the schedule length.
        """
        if steps is None:
            steps = self.timesteps
        if steps > self.timesteps:
            raise ValueError(
                f"steps ({steps}) cannot exceed the schedule length ({self.timesteps})"
            )

        was_training = model.training
        model.eval()
        try:
            x = torch.randn(1, 3, image_size, image_size, device=self.device)
            for i, t in enumerate(reversed(range(steps))):
                t_batch = torch.full((x.shape[0],), t, device=self.device, dtype=torch.long)
                predicted_noise = model(x, t_batch, text_tokens)
                alpha = self.alphas[t]
                alpha_bar = self.alpha_bars[t]
                beta = self.betas[t]
                # No fresh noise is injected on the final (t == 0) step.
                noise = torch.randn_like(x) if t > 0 else 0
                # DDPM posterior mean, then add sigma_t * z with sigma_t^2 = beta_t.
                x = (1 / torch.sqrt(alpha)) * (x - ((1 - alpha) / torch.sqrt(1 - alpha_bar)) * predicted_noise)
                x = x + torch.sqrt(beta) * noise
                if progress_callback is not None:
                    progress_callback((i + 1) / steps)
        finally:
            # Restore whatever mode the caller had (the app keeps its model in eval mode).
            model.train(was_training)
        return x
# Shared application state. All three start empty; initialize_model() fills
# them in, and tokenize_text() / generate_image() read them.
model = device = vocab_data = None
def download_file(url, filename):
    """Download ``url`` to ``filename`` unless ``filename`` already exists.

    The data is first written to ``filename + '.part'`` and renamed into place
    with ``os.replace`` only once the transfer completes. An interrupted or
    failed download therefore never leaves a truncated file at ``filename``,
    which later runs would otherwise skip as "already exists".

    Args:
        url: Source URL (any scheme ``urllib.request`` supports).
        filename: Destination path.

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises (e.g. ``URLError``);
        the temporary ``.part`` file is removed first.
    """
    if os.path.exists(filename):
        print(f"{filename} already exists")
        return

    print(f"Downloading {filename}...")
    partial_path = f"{filename}.part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, filename)
    finally:
        # After a successful rename the temp file is gone. Otherwise discard
        # the incomplete download so it cannot be mistaken for a valid file.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print(f"Downloaded {filename}")
def initialize_model(force=False):
    """Download the checkpoint if needed and load the global model and vocabulary.

    Populates the module-level ``model``, ``device`` and ``vocab_data``. The UI
    wires this to ``demo.load``, which runs on every page load. Once a model is
    in memory, later calls return immediately instead of re-reading the
    checkpoint and replacing the shared model under in-flight requests.

    Args:
        force: Reload from disk even if a model is already loaded.

    Returns:
        The status message displayed in the UI.
    """
    global model, device, vocab_data
    ready_message = "✅ Model loaded successfully! You can now generate images."
    if not force and model is not None and vocab_data is not None:
        return ready_message

    new_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_url = "https://huggingface.co/lazerkat/randomdiffusion/resolve/main/newest.pth"
    model_path = "newest.pth"
    download_file(model_url, model_path)
    # NOTE(security): depending on the installed torch version, torch.load may
    # fully unpickle this downloaded file. Only point it at a trusted source.
    checkpoint = torch.load(model_path, map_location=new_device)
    new_vocab = {
        'vocab': checkpoint['vocab'],
        'word_to_idx': checkpoint['word_to_idx'],
        'vocab_size': checkpoint['vocab_size'],
    }
    new_model = DiffusionUNet(
        vocab_size=new_vocab['vocab_size'],
        image_channels=3,
        base_channels=64,
    ).to(new_device)
    new_model.load_state_dict(checkpoint['model_state_dict'])
    new_model.eval()

    # Publish all three globals together, and only after loading fully
    # succeeded, so a failure part-way never leaves half-initialised state.
    model, device, vocab_data = new_model, new_device, new_vocab
    print(f"Model loaded successfully! Vocab size: {vocab_data['vocab_size']}")
    return ready_message
def tokenize_text(text, max_len=20):
    """Turn a prompt into a (1, max_len) tensor of vocabulary ids on ``device``.

    The prompt is lower-cased and split on whitespace. Each word is stripped of
    surrounding ``.,!?"'`` characters and only the first ``max_len`` words are
    kept. Unknown words map to the ``<UNK>`` id, or to 1 if the vocabulary has
    no ``<UNK>`` entry. The id list is right-padded with 0 up to ``max_len``.
    """
    cleaned = [word.strip('.,!?"\'') for word in text.lower().split()][:max_len]

    ids = []
    if cleaned:
        lookup = vocab_data['word_to_idx']
        unknown_id = lookup.get('<UNK>', 1)
        ids = [lookup.get(word, unknown_id) for word in cleaned]

    ids += [0] * (max_len - len(ids))
    return torch.tensor(ids).unsqueeze(0).to(device)
def generate_image(prompt, progress=gr.Progress()):
    """Gradio handler: generate a 64x64 image for ``prompt`` with the loaded model.

    Runs 500 reverse-diffusion steps and reports progress through Gradio's
    tracker. The ``gr.Progress()`` default is Gradio's documented way to
    request progress injection, not an accidental shared mutable default.

    Args:
        prompt: Free-text prompt from the UI textbox.
        progress: Progress tracker injected by Gradio.

    Returns:
        The generated image as a ``PIL.Image``.

    Raises:
        gr.Error: If the model has not been loaded yet. Gradio shows the
            message to the user instead of a silently empty output.
    """
    if model is None or vocab_data is None:
        raise gr.Error("Model is not loaded yet. Please wait for loading to finish, then try again.")

    num_steps = 500
    progress(0, desc="Starting generation...")
    diffusion = Diffusion(timesteps=num_steps, device=device)

    def update_progress(pct):
        progress(pct, desc=f"Generating... {pct*100:.1f}%")

    with torch.no_grad():
        text_tokens = tokenize_text(prompt)
        generated = diffusion.sample(
            model,
            text_tokens,
            image_size=64,
            steps=num_steps,
            progress_callback=update_progress,
        )
    progress(1.0, desc="Converting to image...")
    return _tensor_to_pil(generated)


def _tensor_to_pil(generated):
    """Convert a (1, 3, H, W) sample to an 8-bit RGB PIL image.

    Values are mapped with (x + 1) / 2 and clamped to [0, 1], which assumes the
    model was trained on images normalised to [-1, 1] — TODO confirm against
    the training pipeline.
    """
    image = generated.cpu().squeeze(0)
    image = ((image + 1) / 2).clamp(0, 1)
    image = (image.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    return Image.fromarray(image)
# Gradio UI: a status line, a prompt box, a generate button and the output image.
with gr.Blocks(title="RandomDiffusion Text-to-Image") as demo:
    gr.Markdown("# 🎨 RandomDiffusion")
    gr.Markdown("Text-to-Image generation using diffusion model")

    status = gr.Textbox(
        label="Status",
        value="Loading model...",
        interactive=False,
    )

    with gr.Row():
        prompt_input = gr.Textbox(
            label="Prompt",
            value="a beautiful landscape",
            placeholder="Enter your text prompt here...",
        )

    with gr.Row():
        generate_btn = gr.Button("Generate Image", variant="primary")

    output_image = gr.Image(label="Generated Image", type="pil")

    # When the page loads, load the checkpoint; the returned message
    # replaces the "Loading model..." status text.
    demo.load(
        fn=lambda: initialize_model(),
        inputs=None,
        outputs=[status],
    )

    # The prompt text goes in, and a PIL image comes back.
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt_input],
        outputs=[output_image],
    )
def _launch_app():
    """Start the Gradio server for the module-level ``demo`` interface."""
    demo.launch()


if __name__ == "__main__":
    _launch_app()