Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -8,10 +8,6 @@ from PIL import Image
|
|
| 8 |
import numpy as np
|
| 9 |
import json
|
| 10 |
|
| 11 |
-
# ============================================================================
|
| 12 |
-
# DIFFUSION Model Architecture (from your training code)
|
| 13 |
-
# ============================================================================
|
| 14 |
-
|
| 15 |
class TextEncoder(nn.Module):
|
| 16 |
def __init__(self, vocab_size, embed_dim=256, hidden_dim=512):
|
| 17 |
super().__init__()
|
|
@@ -169,7 +165,6 @@ class Diffusion:
|
|
| 169 |
x = (1 / torch.sqrt(alpha)) * (x - ((1 - alpha) / torch.sqrt(1 - alpha_bar)) * predicted_noise)
|
| 170 |
x = x + torch.sqrt(beta) * noise
|
| 171 |
|
| 172 |
-
# Report progress
|
| 173 |
if progress_callback is not None:
|
| 174 |
progress = (i + 1) / steps
|
| 175 |
progress_callback(progress)
|
|
@@ -178,13 +173,11 @@ class Diffusion:
|
|
| 178 |
return x
|
| 179 |
|
| 180 |
|
| 181 |
-
# Global variables
|
| 182 |
model = None
|
| 183 |
device = None
|
| 184 |
vocab_data = None
|
| 185 |
|
| 186 |
def download_file(url, filename):
|
| 187 |
-
"""Download with progress tracking"""
|
| 188 |
if not os.path.exists(filename):
|
| 189 |
print(f"Downloading {filename}...")
|
| 190 |
urllib.request.urlretrieve(url, filename)
|
|
@@ -192,36 +185,30 @@ def download_file(url, filename):
|
|
| 192 |
else:
|
| 193 |
print(f"{filename} already exists")
|
| 194 |
|
| 195 |
-
# Download and load model
|
| 196 |
def initialize_model():
|
| 197 |
global model, device, vocab_data
|
| 198 |
|
| 199 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 200 |
|
| 201 |
-
# Download model and vocab
|
| 202 |
model_url = "https://huggingface.co/lazerkat/randomdiffusion/resolve/main/newest.pth"
|
| 203 |
model_path = "newest.pth"
|
| 204 |
|
| 205 |
download_file(model_url, model_path)
|
| 206 |
|
| 207 |
-
# Load checkpoint
|
| 208 |
checkpoint = torch.load(model_path, map_location=device)
|
| 209 |
|
| 210 |
-
# Get vocab info from checkpoint
|
| 211 |
vocab_data = {
|
| 212 |
'vocab': checkpoint['vocab'],
|
| 213 |
'word_to_idx': checkpoint['word_to_idx'],
|
| 214 |
'vocab_size': checkpoint['vocab_size']
|
| 215 |
}
|
| 216 |
|
| 217 |
-
# Create model with correct vocab size
|
| 218 |
model = DiffusionUNet(
|
| 219 |
vocab_size=vocab_data['vocab_size'],
|
| 220 |
image_channels=3,
|
| 221 |
base_channels=64
|
| 222 |
).to(device)
|
| 223 |
|
| 224 |
-
# Load state dict
|
| 225 |
model.load_state_dict(checkpoint['model_state_dict'])
|
| 226 |
model.eval()
|
| 227 |
|
|
@@ -229,15 +216,13 @@ def initialize_model():
|
|
| 229 |
return "✅ Model loaded successfully! You can now generate images."
|
| 230 |
|
| 231 |
def tokenize_text(text, max_len=20):
|
| 232 |
-
"""Tokenize text input for the model"""
|
| 233 |
words = [w.strip('.,!?"\'') for w in text.lower().split()]
|
| 234 |
tokens = words[:max_len]
|
| 235 |
indices = [vocab_data['word_to_idx'].get(token, vocab_data['word_to_idx'].get('<UNK>', 1)) for token in tokens]
|
| 236 |
while len(indices) < max_len:
|
| 237 |
-
indices.append(0)
|
| 238 |
return torch.tensor(indices).unsqueeze(0).to(device)
|
| 239 |
|
| 240 |
-
# Generate image with progress
|
| 241 |
def generate_image(prompt, progress=gr.Progress()):
|
| 242 |
global model, device, vocab_data
|
| 243 |
|
|
@@ -246,7 +231,7 @@ def generate_image(prompt, progress=gr.Progress()):
|
|
| 246 |
|
| 247 |
progress(0, desc="Starting generation...")
|
| 248 |
|
| 249 |
-
diffusion = Diffusion(timesteps=500, device=device)
|
| 250 |
|
| 251 |
def update_progress(pct):
|
| 252 |
progress(pct, desc=f"Generating... {pct*100:.1f}%")
|
|
@@ -263,7 +248,6 @@ def generate_image(prompt, progress=gr.Progress()):
|
|
| 263 |
|
| 264 |
progress(1.0, desc="Converting to image...")
|
| 265 |
|
| 266 |
-
# Convert to image
|
| 267 |
image = generated.cpu().squeeze(0)
|
| 268 |
image = (image + 1) / 2
|
| 269 |
image = image.clamp(0, 1)
|
|
@@ -272,7 +256,6 @@ def generate_image(prompt, progress=gr.Progress()):
|
|
| 272 |
|
| 273 |
return Image.fromarray(image)
|
| 274 |
|
| 275 |
-
# Create interface
|
| 276 |
with gr.Blocks(title="RandomDiffusion Text-to-Image") as demo:
|
| 277 |
gr.Markdown("# 🎨 RandomDiffusion")
|
| 278 |
gr.Markdown("Text-to-Image generation using diffusion model")
|
|
@@ -291,13 +274,11 @@ with gr.Blocks(title="RandomDiffusion Text-to-Image") as demo:
|
|
| 291 |
|
| 292 |
output_image = gr.Image(label="Generated Image", type="pil")
|
| 293 |
|
| 294 |
-
# Load model on startup
|
| 295 |
demo.load(
|
| 296 |
lambda: initialize_model(),
|
| 297 |
outputs=[status]
|
| 298 |
)
|
| 299 |
|
| 300 |
-
# Generate on button click
|
| 301 |
generate_btn.click(
|
| 302 |
generate_image,
|
| 303 |
inputs=[prompt_input],
|
|
@@ -305,4 +286,5 @@ with gr.Blocks(title="RandomDiffusion Text-to-Image") as demo:
|
|
| 305 |
)
|
| 306 |
|
| 307 |
if __name__ == "__main__":
|
| 308 |
-
demo.launch()
|
|
|
|
|
|
| 8 |
import numpy as np
|
| 9 |
import json
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
class TextEncoder(nn.Module):
|
| 12 |
def __init__(self, vocab_size, embed_dim=256, hidden_dim=512):
|
| 13 |
super().__init__()
|
|
|
|
| 165 |
x = (1 / torch.sqrt(alpha)) * (x - ((1 - alpha) / torch.sqrt(1 - alpha_bar)) * predicted_noise)
|
| 166 |
x = x + torch.sqrt(beta) * noise
|
| 167 |
|
|
|
|
| 168 |
if progress_callback is not None:
|
| 169 |
progress = (i + 1) / steps
|
| 170 |
progress_callback(progress)
|
|
|
|
| 173 |
return x
|
| 174 |
|
| 175 |
|
|
|
|
| 176 |
model = None
|
| 177 |
device = None
|
| 178 |
vocab_data = None
|
| 179 |
|
| 180 |
def download_file(url, filename):
|
|
|
|
| 181 |
if not os.path.exists(filename):
|
| 182 |
print(f"Downloading {filename}...")
|
| 183 |
urllib.request.urlretrieve(url, filename)
|
|
|
|
| 185 |
else:
|
| 186 |
print(f"{filename} already exists")
|
| 187 |
|
|
|
|
| 188 |
def initialize_model():
|
| 189 |
global model, device, vocab_data
|
| 190 |
|
| 191 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 192 |
|
|
|
|
| 193 |
model_url = "https://huggingface.co/lazerkat/randomdiffusion/resolve/main/newest.pth"
|
| 194 |
model_path = "newest.pth"
|
| 195 |
|
| 196 |
download_file(model_url, model_path)
|
| 197 |
|
|
|
|
| 198 |
checkpoint = torch.load(model_path, map_location=device)
|
| 199 |
|
|
|
|
| 200 |
vocab_data = {
|
| 201 |
'vocab': checkpoint['vocab'],
|
| 202 |
'word_to_idx': checkpoint['word_to_idx'],
|
| 203 |
'vocab_size': checkpoint['vocab_size']
|
| 204 |
}
|
| 205 |
|
|
|
|
| 206 |
model = DiffusionUNet(
|
| 207 |
vocab_size=vocab_data['vocab_size'],
|
| 208 |
image_channels=3,
|
| 209 |
base_channels=64
|
| 210 |
).to(device)
|
| 211 |
|
|
|
|
| 212 |
model.load_state_dict(checkpoint['model_state_dict'])
|
| 213 |
model.eval()
|
| 214 |
|
|
|
|
| 216 |
return "✅ Model loaded successfully! You can now generate images."
|
| 217 |
|
| 218 |
def tokenize_text(text, max_len=20):
|
|
|
|
| 219 |
words = [w.strip('.,!?"\'') for w in text.lower().split()]
|
| 220 |
tokens = words[:max_len]
|
| 221 |
indices = [vocab_data['word_to_idx'].get(token, vocab_data['word_to_idx'].get('<UNK>', 1)) for token in tokens]
|
| 222 |
while len(indices) < max_len:
|
| 223 |
+
indices.append(0)
|
| 224 |
return torch.tensor(indices).unsqueeze(0).to(device)
|
| 225 |
|
|
|
|
| 226 |
def generate_image(prompt, progress=gr.Progress()):
|
| 227 |
global model, device, vocab_data
|
| 228 |
|
|
|
|
| 231 |
|
| 232 |
progress(0, desc="Starting generation...")
|
| 233 |
|
| 234 |
+
diffusion = Diffusion(timesteps=500, device=device)
|
| 235 |
|
| 236 |
def update_progress(pct):
|
| 237 |
progress(pct, desc=f"Generating... {pct*100:.1f}%")
|
|
|
|
| 248 |
|
| 249 |
progress(1.0, desc="Converting to image...")
|
| 250 |
|
|
|
|
| 251 |
image = generated.cpu().squeeze(0)
|
| 252 |
image = (image + 1) / 2
|
| 253 |
image = image.clamp(0, 1)
|
|
|
|
| 256 |
|
| 257 |
return Image.fromarray(image)
|
| 258 |
|
|
|
|
| 259 |
with gr.Blocks(title="RandomDiffusion Text-to-Image") as demo:
|
| 260 |
gr.Markdown("# 🎨 RandomDiffusion")
|
| 261 |
gr.Markdown("Text-to-Image generation using diffusion model")
|
|
|
|
| 274 |
|
| 275 |
output_image = gr.Image(label="Generated Image", type="pil")
|
| 276 |
|
|
|
|
| 277 |
demo.load(
|
| 278 |
lambda: initialize_model(),
|
| 279 |
outputs=[status]
|
| 280 |
)
|
| 281 |
|
|
|
|
| 282 |
generate_btn.click(
|
| 283 |
generate_image,
|
| 284 |
inputs=[prompt_input],
|
|
|
|
| 286 |
)
|
| 287 |
|
| 288 |
if __name__ == "__main__":
|
| 289 |
+
demo.launch()
|
| 290 |
+
|