lazerkat commited on
Commit
d4f89b8
·
verified ·
1 Parent(s): 4ad9a53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -22
app.py CHANGED
@@ -8,10 +8,6 @@ from PIL import Image
8
  import numpy as np
9
  import json
10
 
11
- # ============================================================================
12
- # DIFFUSION Model Architecture (from your training code)
13
- # ============================================================================
14
-
15
  class TextEncoder(nn.Module):
16
  def __init__(self, vocab_size, embed_dim=256, hidden_dim=512):
17
  super().__init__()
@@ -169,7 +165,6 @@ class Diffusion:
169
  x = (1 / torch.sqrt(alpha)) * (x - ((1 - alpha) / torch.sqrt(1 - alpha_bar)) * predicted_noise)
170
  x = x + torch.sqrt(beta) * noise
171
 
172
- # Report progress
173
  if progress_callback is not None:
174
  progress = (i + 1) / steps
175
  progress_callback(progress)
@@ -178,13 +173,11 @@ class Diffusion:
178
  return x
179
 
180
 
181
- # Global variables
182
  model = None
183
  device = None
184
  vocab_data = None
185
 
186
  def download_file(url, filename):
187
- """Download with progress tracking"""
188
  if not os.path.exists(filename):
189
  print(f"Downloading {filename}...")
190
  urllib.request.urlretrieve(url, filename)
@@ -192,36 +185,30 @@ def download_file(url, filename):
192
  else:
193
  print(f"{filename} already exists")
194
 
195
- # Download and load model
196
  def initialize_model():
197
  global model, device, vocab_data
198
 
199
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
200
 
201
- # Download model and vocab
202
  model_url = "https://huggingface.co/lazerkat/randomdiffusion/resolve/main/newest.pth"
203
  model_path = "newest.pth"
204
 
205
  download_file(model_url, model_path)
206
 
207
- # Load checkpoint
208
  checkpoint = torch.load(model_path, map_location=device)
209
 
210
- # Get vocab info from checkpoint
211
  vocab_data = {
212
  'vocab': checkpoint['vocab'],
213
  'word_to_idx': checkpoint['word_to_idx'],
214
  'vocab_size': checkpoint['vocab_size']
215
  }
216
 
217
- # Create model with correct vocab size
218
  model = DiffusionUNet(
219
  vocab_size=vocab_data['vocab_size'],
220
  image_channels=3,
221
  base_channels=64
222
  ).to(device)
223
 
224
- # Load state dict
225
  model.load_state_dict(checkpoint['model_state_dict'])
226
  model.eval()
227
 
@@ -229,15 +216,13 @@ def initialize_model():
229
  return "✅ Model loaded successfully! You can now generate images."
230
 
231
  def tokenize_text(text, max_len=20):
232
- """Tokenize text input for the model"""
233
  words = [w.strip('.,!?"\'') for w in text.lower().split()]
234
  tokens = words[:max_len]
235
  indices = [vocab_data['word_to_idx'].get(token, vocab_data['word_to_idx'].get('<UNK>', 1)) for token in tokens]
236
  while len(indices) < max_len:
237
- indices.append(0) # PAD token
238
  return torch.tensor(indices).unsqueeze(0).to(device)
239
 
240
- # Generate image with progress
241
  def generate_image(prompt, progress=gr.Progress()):
242
  global model, device, vocab_data
243
 
@@ -246,7 +231,7 @@ def generate_image(prompt, progress=gr.Progress()):
246
 
247
  progress(0, desc="Starting generation...")
248
 
249
- diffusion = Diffusion(timesteps=500, device=device) # Use 500 timesteps like training
250
 
251
  def update_progress(pct):
252
  progress(pct, desc=f"Generating... {pct*100:.1f}%")
@@ -263,7 +248,6 @@ def generate_image(prompt, progress=gr.Progress()):
263
 
264
  progress(1.0, desc="Converting to image...")
265
 
266
- # Convert to image
267
  image = generated.cpu().squeeze(0)
268
  image = (image + 1) / 2
269
  image = image.clamp(0, 1)
@@ -272,7 +256,6 @@ def generate_image(prompt, progress=gr.Progress()):
272
 
273
  return Image.fromarray(image)
274
 
275
- # Create interface
276
  with gr.Blocks(title="RandomDiffusion Text-to-Image") as demo:
277
  gr.Markdown("# 🎨 RandomDiffusion")
278
  gr.Markdown("Text-to-Image generation using diffusion model")
@@ -291,13 +274,11 @@ with gr.Blocks(title="RandomDiffusion Text-to-Image") as demo:
291
 
292
  output_image = gr.Image(label="Generated Image", type="pil")
293
 
294
- # Load model on startup
295
  demo.load(
296
  lambda: initialize_model(),
297
  outputs=[status]
298
  )
299
 
300
- # Generate on button click
301
  generate_btn.click(
302
  generate_image,
303
  inputs=[prompt_input],
@@ -305,4 +286,5 @@ with gr.Blocks(title="RandomDiffusion Text-to-Image") as demo:
305
  )
306
 
307
  if __name__ == "__main__":
308
- demo.launch()
 
 
8
  import numpy as np
9
  import json
10
 
 
 
 
 
11
  class TextEncoder(nn.Module):
12
  def __init__(self, vocab_size, embed_dim=256, hidden_dim=512):
13
  super().__init__()
 
165
  x = (1 / torch.sqrt(alpha)) * (x - ((1 - alpha) / torch.sqrt(1 - alpha_bar)) * predicted_noise)
166
  x = x + torch.sqrt(beta) * noise
167
 
 
168
  if progress_callback is not None:
169
  progress = (i + 1) / steps
170
  progress_callback(progress)
 
173
  return x
174
 
175
 
 
176
  model = None
177
  device = None
178
  vocab_data = None
179
 
180
  def download_file(url, filename):
 
181
  if not os.path.exists(filename):
182
  print(f"Downloading {filename}...")
183
  urllib.request.urlretrieve(url, filename)
 
185
  else:
186
  print(f"{filename} already exists")
187
 
 
188
  def initialize_model():
189
  global model, device, vocab_data
190
 
191
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
192
 
 
193
  model_url = "https://huggingface.co/lazerkat/randomdiffusion/resolve/main/newest.pth"
194
  model_path = "newest.pth"
195
 
196
  download_file(model_url, model_path)
197
 
 
198
  checkpoint = torch.load(model_path, map_location=device)
199
 
 
200
  vocab_data = {
201
  'vocab': checkpoint['vocab'],
202
  'word_to_idx': checkpoint['word_to_idx'],
203
  'vocab_size': checkpoint['vocab_size']
204
  }
205
 
 
206
  model = DiffusionUNet(
207
  vocab_size=vocab_data['vocab_size'],
208
  image_channels=3,
209
  base_channels=64
210
  ).to(device)
211
 
 
212
  model.load_state_dict(checkpoint['model_state_dict'])
213
  model.eval()
214
 
 
216
  return "✅ Model loaded successfully! You can now generate images."
217
 
218
  def tokenize_text(text, max_len=20):
 
219
  words = [w.strip('.,!?"\'') for w in text.lower().split()]
220
  tokens = words[:max_len]
221
  indices = [vocab_data['word_to_idx'].get(token, vocab_data['word_to_idx'].get('<UNK>', 1)) for token in tokens]
222
  while len(indices) < max_len:
223
+ indices.append(0)
224
  return torch.tensor(indices).unsqueeze(0).to(device)
225
 
 
226
  def generate_image(prompt, progress=gr.Progress()):
227
  global model, device, vocab_data
228
 
 
231
 
232
  progress(0, desc="Starting generation...")
233
 
234
+ diffusion = Diffusion(timesteps=500, device=device)
235
 
236
  def update_progress(pct):
237
  progress(pct, desc=f"Generating... {pct*100:.1f}%")
 
248
 
249
  progress(1.0, desc="Converting to image...")
250
 
 
251
  image = generated.cpu().squeeze(0)
252
  image = (image + 1) / 2
253
  image = image.clamp(0, 1)
 
256
 
257
  return Image.fromarray(image)
258
 
 
259
  with gr.Blocks(title="RandomDiffusion Text-to-Image") as demo:
260
  gr.Markdown("# 🎨 RandomDiffusion")
261
  gr.Markdown("Text-to-Image generation using diffusion model")
 
274
 
275
  output_image = gr.Image(label="Generated Image", type="pil")
276
 
 
277
  demo.load(
278
  lambda: initialize_model(),
279
  outputs=[status]
280
  )
281
 
 
282
  generate_btn.click(
283
  generate_image,
284
  inputs=[prompt_input],
 
286
  )
287
 
288
  if __name__ == "__main__":
289
+ demo.launch()
290
+