Update app.py
app.py CHANGED

@@ -146,14 +146,14 @@ class Diffusion:
         self.alpha_bars = torch.cumprod(self.alphas, dim=0)
 
     @torch.no_grad()
-    def sample(self, model, text_tokens, image_size=64, steps=None):
+    def sample(self, model, text_tokens, image_size=64, steps=None, progress_callback=None):
         model.eval()
         if steps is None:
             steps = self.timesteps
 
         x = torch.randn(1, 3, image_size, image_size).to(self.device)
 
-        for t in reversed(range(steps)):
+        for i, t in enumerate(reversed(range(steps))):
             t_batch = torch.full((x.shape[0],), t, device=self.device, dtype=torch.long)
             predicted_noise = model(x, t_batch, text_tokens)
 
@@ -169,6 +169,11 @@ class Diffusion:
             x = (1 / torch.sqrt(alpha)) * (x - ((1 - alpha) / torch.sqrt(1 - alpha_bar)) * predicted_noise)
             x = x + torch.sqrt(beta) * noise
 
+            # Report progress
+            if progress_callback is not None:
+                progress = (i + 1) / steps
+                progress_callback(progress)
+
         model.train()
         return x
 
@@ -232,18 +237,31 @@ def tokenize_text(text, max_len=20):
         indices.append(0)  # PAD token
     return torch.tensor(indices).unsqueeze(0).to(device)
 
-# Generate image
-def generate_image(prompt):
+# Generate image with progress
+def generate_image(prompt, progress=gr.Progress()):
     global model, device, vocab_data
 
     if model is None or vocab_data is None:
         return None
 
+    progress(0, desc="Starting generation...")
+
     diffusion = Diffusion(timesteps=500, device=device)  # Use 500 timesteps like training
 
+    def update_progress(pct):
+        progress(pct, desc=f"Generating... {pct*100:.1f}%")
+
     with torch.no_grad():
         text_tokens = tokenize_text(prompt)
-        generated = diffusion.sample(
+        generated = diffusion.sample(
+            model,
+            text_tokens,
+            image_size=64,
+            steps=500,
+            progress_callback=update_progress
+        )
+
+        progress(1.0, desc="Converting to image...")
 
     # Convert to image
     image = generated.cpu().squeeze(0)