"""Gradio demo for the RandomDiffusion text-to-image model.

When the page loads, the checkpoint is downloaded from Hugging Face (if it is
not already on disk) and loaded into memory; the "Generate Image" button then
samples a 64x64 image conditioned on the text prompt.
"""

import json
import os
import traceback
import urllib.request

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image


class TextEncoder(nn.Module):
    """Encode a batch of token-id sequences into one vector per sequence.

    A single-layer bidirectional LSTM runs over the embedded tokens; the final
    hidden states of the forward and backward directions are concatenated and
    projected to ``hidden_dim``.
    """

    def __init__(self, vocab_size, embed_dim=256, hidden_dim=512):
        super().__init__()
        # Index 0 is the padding id (tokenize_text pads sequences with 0).
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, x):
        """Map token ids of shape ``(batch, seq)`` to embeddings of shape ``(batch, hidden_dim)``."""
        embedded = self.embedding(x)
        _, (hidden, _) = self.lstm(embedded)
        # One bidirectional layer => hidden is (2, batch, hidden_dim):
        # index -2 is the forward direction, index -1 the backward direction.
        hidden_forward = hidden[-2, :, :]
        hidden_backward = hidden[-1, :, :]
        combined = torch.cat([hidden_forward, hidden_backward], dim=1)
        return self.fc(combined)


class DownBlock(nn.Module):
    """Two 3x3 convs conditioned on time and text, followed by 2x max-pooling.

    The time and text embeddings are each projected to ``out_channels`` and
    added (broadcast over height and width) after the first conv + BatchNorm.
    """

    def __init__(self, in_channels, out_channels, time_emb_dim=256, text_emb_dim=512):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.norm2 = nn.BatchNorm2d(out_channels)
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, out_channels),
            nn.SiLU(),
            nn.Linear(out_channels, out_channels),
        )
        self.text_mlp = nn.Sequential(
            nn.Linear(text_emb_dim, out_channels),
            nn.SiLU(),
            nn.Linear(out_channels, out_channels),
        )
        self.pool = nn.MaxPool2d(2)

    def forward(self, x, t_emb, text_emb):
        """Return ``(features, pooled_features)``; ``features`` is the U-Net skip connection."""
        h = self.conv1(x)
        h = self.norm1(h)
        # (batch, C) -> (batch, C, 1, 1) so the conditioning broadcasts over H and W.
        t = self.time_mlp(t_emb).unsqueeze(-1).unsqueeze(-1)
        txt = self.text_mlp(text_emb).unsqueeze(-1).unsqueeze(-1)
        h = h + t + txt
        h = F.relu(h)
        h = self.conv2(h)
        h = self.norm2(h)
        h = F.relu(h)
        return h, self.pool(h)


class UpBlock(nn.Module):
    """2x bilinear upsample, concatenate a skip connection, then two conditioned 3x3 convs.

    Conditioning is injected the same way as in ``DownBlock``: projected time
    and text embeddings are added after the first conv + BatchNorm.
    """

    def __init__(self, in_channels, skip_channels, out_channels, time_emb_dim=256, text_emb_dim=512):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # conv1 sees the upsampled input and the skip tensor stacked on the channel axis.
        self.conv1 = nn.Conv2d(in_channels + skip_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.norm2 = nn.BatchNorm2d(out_channels)
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, out_channels),
            nn.SiLU(),
            nn.Linear(out_channels, out_channels),
        )
        self.text_mlp = nn.Sequential(
            nn.Linear(text_emb_dim, out_channels),
            nn.SiLU(),
            nn.Linear(out_channels, out_channels),
        )

    def forward(self, x, skip, t_emb, text_emb):
        """Upsample ``x``, fuse it with ``skip`` (same spatial size after upsampling), and convolve."""
        x = self.up(x)
        x = torch.cat([x, skip], dim=1)
        h = self.conv1(x)
        h = self.norm1(h)
        t = self.time_mlp(t_emb).unsqueeze(-1).unsqueeze(-1)
        txt = self.text_mlp(text_emb).unsqueeze(-1).unsqueeze(-1)
        h = h + t + txt
        h = F.relu(h)
        h = self.conv2(h)
        h = self.norm2(h)
        return F.relu(h)


class DiffusionUNet(nn.Module):
    """Small two-level U-Net that predicts noise, conditioned on timestep and text.

    Input images are ``(batch, image_channels, H, W)``; H and W must be
    divisible by 4 (two 2x poolings, mirrored by two 2x upsamplings whose
    outputs are concatenated with the skip tensors).

    NOTE: the attribute names below are the checkpoint's state_dict keys;
    renaming any of them breaks ``load_state_dict`` in ``initialize_model``.
    """

    def __init__(self, vocab_size, image_channels=3, base_channels=64, time_emb_dim=256, text_emb_dim=512):
        super().__init__()
        self.text_encoder = TextEncoder(vocab_size, embed_dim=256, hidden_dim=text_emb_dim)
        # The raw integer timestep (as a float, shape (batch, 1)) is embedded by an MLP.
        self.time_mlp = nn.Sequential(
            nn.Linear(1, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )
        self.init_conv = nn.Conv2d(image_channels, base_channels, 3, padding=1)
        self.down1 = DownBlock(base_channels, base_channels, time_emb_dim, text_emb_dim)
        self.down2 = DownBlock(base_channels, base_channels * 2, time_emb_dim, text_emb_dim)
        self.bottleneck_conv1 = nn.Conv2d(base_channels * 2, base_channels * 2, 3, padding=1)
        self.bottleneck_conv2 = nn.Conv2d(base_channels * 2, base_channels * 2, 3, padding=1)
        self.bottleneck_norm1 = nn.BatchNorm2d(base_channels * 2)
        self.bottleneck_norm2 = nn.BatchNorm2d(base_channels * 2)
        self.bottleneck_time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, base_channels * 2),
            nn.SiLU(),
            nn.Linear(base_channels * 2, base_channels * 2),
        )
        self.bottleneck_text_mlp = nn.Sequential(
            nn.Linear(text_emb_dim, base_channels * 2),
            nn.SiLU(),
            nn.Linear(base_channels * 2, base_channels * 2),
        )
        self.up1 = UpBlock(base_channels * 2, base_channels * 2, base_channels, time_emb_dim, text_emb_dim)
        self.up2 = UpBlock(base_channels, base_channels, base_channels, time_emb_dim, text_emb_dim)
        self.out_conv = nn.Conv2d(base_channels, image_channels, 1)

    def forward(self, x, timesteps, text_tokens):
        """Predict the noise in ``x`` at integer ``timesteps`` (shape ``(batch,)``) given ``text_tokens``."""
        text_emb = self.text_encoder(text_tokens)
        t_emb = self.time_mlp(timesteps.unsqueeze(-1).float())

        x1 = self.init_conv(x)
        x2, x2_pooled = self.down1(x1, t_emb, text_emb)        # x2: full res, pooled: 1/2
        x3, x3_pooled = self.down2(x2_pooled, t_emb, text_emb)  # x3: 1/2 res, pooled: 1/4

        h = self.bottleneck_conv1(x3_pooled)
        h = self.bottleneck_norm1(h)
        t = self.bottleneck_time_mlp(t_emb).unsqueeze(-1).unsqueeze(-1)
        txt = self.bottleneck_text_mlp(text_emb).unsqueeze(-1).unsqueeze(-1)
        h = h + t + txt
        h = F.relu(h)
        h = self.bottleneck_conv2(h)
        h = self.bottleneck_norm2(h)
        bottleneck = F.relu(h)

        d1 = self.up1(bottleneck, x3, t_emb, text_emb)  # back to 1/2 res
        d2 = self.up2(d1, x2, t_emb, text_emb)          # back to full res
        return self.out_conv(d2)


class Diffusion:
    """Linear beta schedule plus a DDPM-style ancestral sampler."""

    def __init__(self, timesteps=1000, beta_start=1e-4, beta_end=0.02, device='cuda'):
        self.timesteps = timesteps
        self.device = device
        self.betas = torch.linspace(beta_start, beta_end, timesteps).to(device)
        self.alphas = 1 - self.betas
        # alpha_bars[t] is the cumulative product alphas[0] * ... * alphas[t].
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

    @torch.no_grad()
    def sample(self, model, text_tokens, image_size=64, steps=None, progress_callback=None):
        """Sample images from pure Gaussian noise by running the reverse process.

        The loop walks t = steps-1, ..., 0 over the first ``steps`` entries of
        the schedule (it does not stride the schedule), applying
        x <- (x - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t) + sqrt(beta_t) * z,
        with no noise added on the final step.

        Args:
            model: Noise predictor called as ``model(x, t_batch, text_tokens)``.
            text_tokens: Token ids; one image is generated per row (batch dim 0).
            image_size: Height and width of the generated 3-channel images.
            steps: Number of reverse steps, in ``[1, self.timesteps]``;
                defaults to ``self.timesteps``.
            progress_callback: Optional callable receiving the completed fraction
                in ``(0, 1]`` after each step.

        Returns:
            Tensor of shape ``(batch, 3, image_size, image_size)``.

        Raises:
            ValueError: if ``steps`` is outside ``[1, self.timesteps]`` (larger
                values would index past the end of the schedule).
        """
        if steps is None:
            steps = self.timesteps
        if not 1 <= steps <= self.timesteps:
            raise ValueError(f"steps must be in [1, {self.timesteps}], got {steps}")

        # Restore the caller's train/eval mode afterwards instead of forcing
        # train mode, which would leave BatchNorm in training mode on a shared model.
        was_training = model.training
        model.eval()
        try:
            batch_size = text_tokens.shape[0]
            x = torch.randn(batch_size, 3, image_size, image_size).to(self.device)
            for i, t in enumerate(reversed(range(steps))):
                t_batch = torch.full((batch_size,), t, device=self.device, dtype=torch.long)
                predicted_noise = model(x, t_batch, text_tokens)

                alpha = self.alphas[t]
                alpha_bar = self.alpha_bars[t]
                beta = self.betas[t]

                noise = torch.randn_like(x) if t > 0 else 0
                # Note 1 - alpha_t == beta_t, so this is the DDPM mean with sigma_t^2 = beta_t.
                x = (1 / torch.sqrt(alpha)) * (x - ((1 - alpha) / torch.sqrt(1 - alpha_bar)) * predicted_noise)
                x = x + torch.sqrt(beta) * noise

                if progress_callback is not None:
                    progress_callback((i + 1) / steps)
        finally:
            model.train(was_training)
        return x


# Module-level state, populated by initialize_model() and read by
# tokenize_text() and generate_image().
model = None
device = None
vocab_data = None

_READY_MESSAGE = "✅ Model loaded successfully! You can now generate images."


def download_file(url, filename):
    """Download ``url`` to ``filename`` unless ``filename`` already exists.

    Data is written to ``filename + '.part'`` and renamed into place only after
    the download completes, so an interrupted download never leaves a truncated
    file that a later run would mistake for a finished one.
    """
    if os.path.exists(filename):
        print(f"{filename} already exists")
        return

    print(f"Downloading {filename}...")
    tmp_path = f"{filename}.part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, filename)
    except BaseException:
        # Drop the partial file so the next attempt starts clean, then propagate.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print(f"Downloaded {filename}")


def initialize_model():
    """Download (if needed) and load the checkpoint into the module globals.

    Idempotent: Gradio fires ``demo.load`` on every page load, so once the
    model is in memory this returns immediately instead of re-reading the
    checkpoint for each visitor. The globals are assigned only after the
    checkpoint loads successfully, so a failure never leaves them half-populated.

    Returns:
        A status message for the UI.
    """
    global model, device, vocab_data
    if model is not None and vocab_data is not None:
        return _READY_MESSAGE

    new_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model_url = "https://huggingface.co/lazerkat/randomdiffusion/resolve/main/newest.pth"
    model_path = "newest.pth"
    download_file(model_url, model_path)

    # NOTE(review): torch.load unpickles a file downloaded from the internet.
    # On PyTorch versions where weights_only defaults to False (before 2.6) this
    # can execute arbitrary code embedded in the checkpoint. Consider passing
    # weights_only=True once the minimum supported torch version allows it
    # (presumably the checkpoint holds only tensors, dicts, strings and ints; verify).
    checkpoint = torch.load(model_path, map_location=new_device)

    new_vocab_data = {
        'vocab': checkpoint['vocab'],
        'word_to_idx': checkpoint['word_to_idx'],
        'vocab_size': checkpoint['vocab_size'],
    }

    new_model = DiffusionUNet(
        vocab_size=new_vocab_data['vocab_size'],
        image_channels=3,
        base_channels=64,
    ).to(new_device)
    new_model.load_state_dict(checkpoint['model_state_dict'])
    new_model.eval()

    device, vocab_data, model = new_device, new_vocab_data, new_model
    print(f"Model loaded successfully! Vocab size: {vocab_data['vocab_size']}")
    return _READY_MESSAGE


def tokenize_text(text, max_len=20):
    """Convert ``text`` to a ``(1, max_len)`` tensor of token ids on ``device``.

    Words are lower-cased, split on whitespace, stripped of surrounding
    punctuation, truncated to ``max_len``, and right-padded with 0.
    """
    words = [w.strip('.,!?"\'') for w in text.lower().split()]
    word_to_idx = vocab_data['word_to_idx']
    # NOTE(review): the fallback key '' looks like it may once have been an
    # angle-bracket unknown-word token (e.g. '<unk>') stripped somewhere
    # upstream; confirm against the training vocabulary. As written, unknown
    # words map to word_to_idx[''] if present, otherwise to id 1.
    unk_idx = word_to_idx.get('', 1)
    indices = [word_to_idx.get(token, unk_idx) for token in words[:max_len]]
    indices += [0] * (max_len - len(indices))
    return torch.tensor(indices).unsqueeze(0).to(device)


def generate_image(prompt, progress=gr.Progress()):
    """Generate one 64x64 PIL image for ``prompt`` using 500 reverse diffusion steps.

    ``progress=gr.Progress()`` is Gradio's idiom for injecting a progress
    tracker into an event handler.

    Raises:
        gr.Error: if the model has not finished loading, so the UI shows a
            message instead of silently producing no image.
    """
    if model is None or vocab_data is None:
        raise gr.Error("Model is not loaded yet. Please wait for loading to finish and try again.")

    progress(0, desc="Starting generation...")
    diffusion = Diffusion(timesteps=500, device=device)

    def update_progress(pct):
        progress(pct, desc=f"Generating... {pct*100:.1f}%")

    # Diffusion.sample is already decorated with @torch.no_grad().
    text_tokens = tokenize_text(prompt)
    generated = diffusion.sample(
        model,
        text_tokens,
        image_size=64,
        steps=500,
        progress_callback=update_progress,
    )

    progress(1.0, desc="Converting to image...")
    image = generated.cpu().squeeze(0)
    # Assumes the model's sample range is [-1, 1] (TODO confirm against training);
    # map to [0, 1] and clamp anything outside.
    image = ((image + 1) / 2).clamp(0, 1)
    image = image.permute(1, 2, 0).numpy()  # CHW -> HWC for PIL
    image = (image * 255).astype(np.uint8)
    return Image.fromarray(image)


def _load_model_for_ui():
    """Run initialize_model for the status box, reporting failures instead of raising."""
    try:
        return initialize_model()
    except Exception as err:  # top-level UI boundary: log and surface in the status box
        traceback.print_exc()
        return f"❌ Failed to load model: {err}"


with gr.Blocks(title="RandomDiffusion Text-to-Image") as demo:
    gr.Markdown("# 🎨 RandomDiffusion")
    gr.Markdown("Text-to-Image generation using diffusion model")

    status = gr.Textbox(label="Status", value="Loading model...", interactive=False)

    with gr.Row():
        prompt_input = gr.Textbox(
            label="Prompt",
            value="a beautiful landscape",
            placeholder="Enter your text prompt here...",
        )

    with gr.Row():
        generate_btn = gr.Button("Generate Image", variant="primary")

    output_image = gr.Image(label="Generated Image", type="pil")

    demo.load(_load_model_for_ui, outputs=[status])

    generate_btn.click(
        generate_image,
        inputs=[prompt_input],
        outputs=[output_image],
    )


if __name__ == "__main__":
    demo.launch()