Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import urllib.request | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| import numpy as np | |
| import json | |
class TextEncoder(nn.Module):
    """Encode a batch of token-id sequences into fixed-size text embeddings.

    A bidirectional LSTM reads the embedded tokens; the final forward and
    backward hidden states are concatenated and projected down to hidden_dim.
    """

    def __init__(self, vocab_size, embed_dim=256, hidden_dim=512):
        super().__init__()
        # padding_idx=0 pins the pad token's embedding to zero.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, x):
        """x: (batch, seq_len) int token ids -> (batch, hidden_dim) embedding."""
        _, (final_hidden, _) = self.lstm(self.embedding(x))
        # The last two entries of final_hidden are the forward and backward
        # directions of the (single-layer) bidirectional LSTM.
        fused = torch.cat([final_hidden[-2], final_hidden[-1]], dim=1)
        return self.fc(fused)
class DownBlock(nn.Module):
    """U-Net encoder stage: two convolutions conditioned on time and text
    embeddings. Returns both the pre-pool features (skip connection) and the
    2x max-pooled features."""

    def __init__(self, in_channels, out_channels, time_emb_dim=256, text_emb_dim=512):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.norm2 = nn.BatchNorm2d(out_channels)
        # Small MLPs projecting the conditioning vectors to per-channel biases.
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, out_channels),
            nn.SiLU(),
            nn.Linear(out_channels, out_channels),
        )
        self.text_mlp = nn.Sequential(
            nn.Linear(text_emb_dim, out_channels),
            nn.SiLU(),
            nn.Linear(out_channels, out_channels),
        )
        self.pool = nn.MaxPool2d(2)

    def forward(self, x, t_emb, text_emb):
        """Returns (features, pooled) where pooled is features at half resolution."""
        out = self.norm1(self.conv1(x))
        # Broadcast the (B, C) conditioning vectors over the spatial dims.
        out = out + self.time_mlp(t_emb)[:, :, None, None]
        out = out + self.text_mlp(text_emb)[:, :, None, None]
        out = F.relu(out)
        out = F.relu(self.norm2(self.conv2(out)))
        return out, self.pool(out)
class UpBlock(nn.Module):
    """U-Net decoder stage: bilinear 2x upsample, concatenate the encoder
    skip connection, then two convolutions conditioned on time and text
    embeddings."""

    def __init__(self, in_channels, skip_channels, out_channels, time_emb_dim=256, text_emb_dim=512):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv1 = nn.Conv2d(in_channels + skip_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.norm2 = nn.BatchNorm2d(out_channels)
        # Small MLPs projecting the conditioning vectors to per-channel biases.
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, out_channels),
            nn.SiLU(),
            nn.Linear(out_channels, out_channels),
        )
        self.text_mlp = nn.Sequential(
            nn.Linear(text_emb_dim, out_channels),
            nn.SiLU(),
            nn.Linear(out_channels, out_channels),
        )

    def forward(self, x, skip, t_emb, text_emb):
        """x: low-res features; skip: matching encoder features at 2x resolution."""
        merged = torch.cat([self.up(x), skip], dim=1)
        out = self.norm1(self.conv1(merged))
        # Broadcast the (B, C) conditioning vectors over the spatial dims.
        out = out + self.time_mlp(t_emb)[:, :, None, None]
        out = out + self.text_mlp(text_emb)[:, :, None, None]
        out = F.relu(out)
        return F.relu(self.norm2(self.conv2(out)))
class DiffusionUNet(nn.Module):
    """Small conditional U-Net that predicts the noise added to an image.

    Conditioning: a text embedding (from TextEncoder) and a timestep
    embedding are injected as per-channel biases at every stage, including
    the bottleneck.
    """

    def __init__(self, vocab_size, image_channels=3, base_channels=64, time_emb_dim=256, text_emb_dim=512):
        super().__init__()
        self.text_encoder = TextEncoder(vocab_size, embed_dim=256, hidden_dim=text_emb_dim)
        # Embed the raw scalar timestep with a small MLP.
        self.time_mlp = nn.Sequential(
            nn.Linear(1, time_emb_dim), nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim), nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )
        self.init_conv = nn.Conv2d(image_channels, base_channels, 3, padding=1)
        # Encoder: two downsampling stages.
        self.down1 = DownBlock(base_channels, base_channels, time_emb_dim, text_emb_dim)
        self.down2 = DownBlock(base_channels, base_channels * 2, time_emb_dim, text_emb_dim)
        # Bottleneck at 1/4 resolution.
        mid_ch = base_channels * 2
        self.bottleneck_conv1 = nn.Conv2d(mid_ch, mid_ch, 3, padding=1)
        self.bottleneck_conv2 = nn.Conv2d(mid_ch, mid_ch, 3, padding=1)
        self.bottleneck_norm1 = nn.BatchNorm2d(mid_ch)
        self.bottleneck_norm2 = nn.BatchNorm2d(mid_ch)
        self.bottleneck_time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, mid_ch), nn.SiLU(),
            nn.Linear(mid_ch, mid_ch),
        )
        self.bottleneck_text_mlp = nn.Sequential(
            nn.Linear(text_emb_dim, mid_ch), nn.SiLU(),
            nn.Linear(mid_ch, mid_ch),
        )
        # Decoder: mirror of the encoder, consuming the skip connections.
        self.up1 = UpBlock(mid_ch, mid_ch, base_channels, time_emb_dim, text_emb_dim)
        self.up2 = UpBlock(base_channels, base_channels, base_channels, time_emb_dim, text_emb_dim)
        self.out_conv = nn.Conv2d(base_channels, image_channels, 1)

    def forward(self, x, timesteps, text_tokens):
        """Predict the noise in `x` for the given integer `timesteps` batch."""
        text_emb = self.text_encoder(text_tokens)
        t_emb = self.time_mlp(timesteps.unsqueeze(-1).float())
        feat = self.init_conv(x)
        skip1, down = self.down1(feat, t_emb, text_emb)
        skip2, down = self.down2(down, t_emb, text_emb)
        # Bottleneck uses the same conditioning scheme as the stages.
        mid = self.bottleneck_norm1(self.bottleneck_conv1(down))
        mid = mid + self.bottleneck_time_mlp(t_emb)[:, :, None, None]
        mid = mid + self.bottleneck_text_mlp(text_emb)[:, :, None, None]
        mid = F.relu(mid)
        mid = F.relu(self.bottleneck_norm2(self.bottleneck_conv2(mid)))
        out = self.up1(mid, skip2, t_emb, text_emb)
        out = self.up2(out, skip1, t_emb, text_emb)
        return self.out_conv(out)
class Diffusion:
    """Linear-schedule DDPM noise process with an ancestral sampler."""

    def __init__(self, timesteps=1000, beta_start=1e-4, beta_end=0.02, device='cuda'):
        self.timesteps = timesteps
        self.device = device
        self.betas = torch.linspace(beta_start, beta_end, timesteps).to(device)
        self.alphas = 1 - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

    def sample(self, model, text_tokens, image_size=64, steps=None, progress_callback=None):
        """Run the reverse diffusion loop starting from pure Gaussian noise.

        Returns a (1, 3, image_size, image_size) tensor. If given,
        progress_callback receives a float fraction in (0, 1] per step.
        """
        model.eval()
        steps = self.timesteps if steps is None else steps
        x = torch.randn(1, 3, image_size, image_size).to(self.device)
        for done, t in enumerate(range(steps - 1, -1, -1), start=1):
            t_batch = torch.full((x.shape[0],), t, device=self.device, dtype=torch.long)
            eps = model(x, t_batch, text_tokens)
            alpha = self.alphas[t]
            alpha_bar = self.alpha_bars[t]
            beta = self.betas[t]
            # No fresh noise is injected on the final step (t == 0).
            noise = torch.randn_like(x) if t > 0 else 0
            x = (x - ((1 - alpha) / torch.sqrt(1 - alpha_bar)) * eps) / torch.sqrt(alpha)
            x = x + torch.sqrt(beta) * noise
            if progress_callback is not None:
                progress_callback(done / steps)
        model.train()
        return x
# Module-level state, populated by initialize_model() when the app starts.
model = None       # the loaded DiffusionUNet, or None until initialized
device = None      # torch.device chosen at load time
vocab_data = None  # dict with 'vocab', 'word_to_idx' and 'vocab_size'
def download_file(url, filename):
    """Download `url` to local path `filename`, skipping if it already exists.

    Fix: the progress messages were f-strings with no placeholder (the
    `{filename}` interpolation had been lost), so they printed a meaningless
    constant instead of the file being downloaded.
    """
    if not os.path.exists(filename):
        print(f"Downloading {filename}...")
        urllib.request.urlretrieve(url, filename)
        print(f"Downloaded {filename}")
    else:
        print(f"{filename} already exists")
def initialize_model():
    """Fetch the checkpoint (if needed), build the network and load weights.

    Populates the module-level `model`, `device` and `vocab_data` globals and
    returns a status string for the UI.
    """
    global model, device, vocab_data
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_path = "newest.pth"
    download_file(
        "https://huggingface.co/lazerkat/randomdiffusion/resolve/main/newest.pth",
        model_path,
    )
    # Checkpoint bundles the weights together with the tokenizer vocabulary.
    checkpoint = torch.load(model_path, map_location=device)
    vocab_data = {key: checkpoint[key] for key in ('vocab', 'word_to_idx', 'vocab_size')}
    net = DiffusionUNet(
        vocab_size=vocab_data['vocab_size'],
        image_channels=3,
        base_channels=64,
    )
    net.load_state_dict(checkpoint['model_state_dict'])
    model = net.to(device)
    model.eval()
    print(f"Model loaded successfully! Vocab size: {vocab_data['vocab_size']}")
    return "✅ Model loaded successfully! You can now generate images."
def tokenize_text(text, max_len=20):
    """Convert a prompt string to a (1, max_len) tensor of vocabulary indices.

    Words are lowercased and stripped of surrounding punctuation; unknown
    words map to the '<UNK>' index (falling back to 1), and the sequence is
    truncated/zero-padded to exactly max_len.
    """
    lookup = vocab_data['word_to_idx']
    unk = lookup.get('<UNK>', 1)
    cleaned = [word.strip('.,!?"\'') for word in text.lower().split()]
    ids = [lookup.get(word, unk) for word in cleaned[:max_len]]
    ids.extend([0] * (max_len - len(ids)))
    return torch.tensor(ids).unsqueeze(0).to(device)
def generate_image(prompt, progress=gr.Progress()):
    """Gradio click handler: sample an image for `prompt` with the diffusion model.

    Returns a 64x64 PIL image, or None when the model has not finished loading.
    """
    global model, device, vocab_data
    if model is None or vocab_data is None:
        return None
    progress(0, desc="Starting generation...")
    diffusion = Diffusion(timesteps=500, device=device)

    def report(pct):
        progress(pct, desc=f"Generating... {pct*100:.1f}%")

    with torch.no_grad():
        tokens = tokenize_text(prompt)
        sample = diffusion.sample(
            model,
            tokens,
            image_size=64,
            steps=500,
            progress_callback=report,
        )
    progress(1.0, desc="Converting to image...")
    # Map the sample from roughly [-1, 1] to uint8 HWC for PIL.
    pixels = ((sample.cpu().squeeze(0) + 1) / 2).clamp(0, 1)
    pixels = (pixels.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    return Image.fromarray(pixels)
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks(title="RandomDiffusion Text-to-Image") as demo:
    gr.Markdown("# 🎨 RandomDiffusion")
    gr.Markdown("Text-to-Image generation using diffusion model")
    # Status box; overwritten by initialize_model()'s return value on load.
    status = gr.Textbox(label="Status", value="Loading model...", interactive=False)
    with gr.Row():
        prompt_input = gr.Textbox(
            label="Prompt",
            value="a beautiful landscape",
            placeholder="Enter your text prompt here..."
        )
    with gr.Row():
        generate_btn = gr.Button("Generate Image", variant="primary")
    output_image = gr.Image(label="Generated Image", type="pil")
    # Download and load the checkpoint once, when the page first opens.
    demo.load(
        lambda: initialize_model(),
        outputs=[status]
    )
    generate_btn.click(
        generate_image,
        inputs=[prompt_input],
        outputs=[output_image]
    )
if __name__ == "__main__":
    demo.launch()