Shilpaj committed on
Commit
e7f5c3d
·
verified ·
1 Parent(s): 2cb4e09

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +23 -20
  2. requirements.txt +6 -6
app.py CHANGED
@@ -31,13 +31,16 @@ def load_model():
31
  "runwayml/stable-diffusion-v1-5",
32
  torch_dtype=torch.float16,
33
  safety_checker=None
34
- ).to(device)
35
 
36
- # Create pipeline instance
37
- pipe = load_model()
 
 
 
38
 
39
  # Load concept library
40
- concept_embeds, concept_tokens = load_concept_library(pipe)
41
 
42
  # Define art style concepts
43
  art_concepts = {
@@ -84,13 +87,13 @@ def generate_latents(prompt, seed, num_inference_steps, guidance_scale,
84
  elif concept_style in art_concepts:
85
  # Generate concept embedding from text description
86
  concept_text = art_concepts[concept_style]
87
- concept_embedding = get_concept_embedding(concept_text, pipe.tokenizer, pipe.text_encoder, device)
88
 
89
  # Prep text
90
- text_input = pipe.tokenizer([prompt], padding="max_length", max_length=pipe.tokenizer.model_max_length,
91
  truncation=True, return_tensors="pt")
92
  with torch.inference_mode():
93
- text_embeddings = pipe.text_encoder(text_input.input_ids.to(device))[0]
94
 
95
  # Apply concept embedding influence if provided
96
  if concept_embedding is not None and concept_strength > 0:
@@ -104,34 +107,34 @@ def generate_latents(prompt, seed, num_inference_steps, guidance_scale,
104
 
105
  # Unconditional embedding for classifier-free guidance
106
  max_length = text_input.input_ids.shape[-1]
107
- uncond_input = pipe.tokenizer(
108
  [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
109
  )
110
  with torch.inference_mode():
111
- uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(device))[0]
112
  text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
113
 
114
  # Prep Scheduler
115
- set_timesteps(pipe.scheduler, num_inference_steps)
116
 
117
  # Prep latents
118
  latents = torch.randn(
119
- (batch_size, pipe.unet.in_channels, height // 8, width // 8),
120
  generator=generator,
121
  )
122
  latents = latents.to(device)
123
- latents = latents * pipe.scheduler.init_noise_sigma
124
 
125
  # Loop through diffusion process
126
- for i, t in tqdm(enumerate(pipe.scheduler.timesteps), total=len(pipe.scheduler.timesteps)):
127
  # Expand latents for classifier-free guidance
128
  latent_model_input = torch.cat([latents] * 2)
129
- sigma = pipe.scheduler.sigmas[i]
130
- latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
131
 
132
  # Predict the noise residual
133
  with torch.inference_mode():
134
- noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
135
 
136
  # Perform classifier-free guidance
137
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
@@ -146,7 +149,7 @@ def generate_latents(prompt, seed, num_inference_steps, guidance_scale,
146
  latents_x0 = latents - sigma * noise_pred
147
 
148
  # Decode to image space
149
- denoised_images = pipe.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)
150
 
151
  # Calculate loss
152
  loss = vignette_loss(denoised_images) * vignette_loss_scale
@@ -158,7 +161,7 @@ def generate_latents(prompt, seed, num_inference_steps, guidance_scale,
158
  latents = latents.detach() - cond_grad * sigma**2
159
 
160
  # Step with scheduler
161
- latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample
162
 
163
  return latents
164
 
@@ -201,7 +204,7 @@ def generate_image(prompt, seed=42, num_inference_steps=30, guidance_scale=7.5,
201
  )
202
 
203
  # Convert latents to image
204
- images = latents_to_pil(latents, pipe.vae)
205
 
206
  return images[0]
207
 
@@ -241,7 +244,7 @@ def generate_style_grid(prompt, seed=42, num_inference_steps=30, guidance_scale=
241
  )
242
 
243
  # Convert latents to image
244
- style_images = latents_to_pil(latents, pipe.vae)
245
  images.append(style_images[0])
246
  labels.append(style)
247
 
 
31
  "runwayml/stable-diffusion-v1-5",
32
  torch_dtype=torch.float16,
33
  safety_checker=None
34
+ )
35
 
36
+ @spaces.GPU
37
+ @gr.Cache()
38
+ def get_pipeline():
39
+ pipe = load_model()
40
+ return pipe.to("cuda")
41
 
42
  # Load concept library
43
+ concept_embeds, concept_tokens = load_concept_library(get_pipeline())
44
 
45
  # Define art style concepts
46
  art_concepts = {
 
87
  elif concept_style in art_concepts:
88
  # Generate concept embedding from text description
89
  concept_text = art_concepts[concept_style]
90
+ concept_embedding = get_concept_embedding(concept_text, get_pipeline().tokenizer, get_pipeline().text_encoder, device)
91
 
92
  # Prep text
93
+ text_input = get_pipeline().tokenizer([prompt], padding="max_length", max_length=get_pipeline().tokenizer.model_max_length,
94
  truncation=True, return_tensors="pt")
95
  with torch.inference_mode():
96
+ text_embeddings = get_pipeline().text_encoder(text_input.input_ids.to(device))[0]
97
 
98
  # Apply concept embedding influence if provided
99
  if concept_embedding is not None and concept_strength > 0:
 
107
 
108
  # Unconditional embedding for classifier-free guidance
109
  max_length = text_input.input_ids.shape[-1]
110
+ uncond_input = get_pipeline().tokenizer(
111
  [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
112
  )
113
  with torch.inference_mode():
114
+ uncond_embeddings = get_pipeline().text_encoder(uncond_input.input_ids.to(device))[0]
115
  text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
116
 
117
  # Prep Scheduler
118
+ set_timesteps(get_pipeline().scheduler, num_inference_steps)
119
 
120
  # Prep latents
121
  latents = torch.randn(
122
+ (batch_size, get_pipeline().unet.in_channels, height // 8, width // 8),
123
  generator=generator,
124
  )
125
  latents = latents.to(device)
126
+ latents = latents * get_pipeline().scheduler.init_noise_sigma
127
 
128
  # Loop through diffusion process
129
+ for i, t in tqdm(enumerate(get_pipeline().scheduler.timesteps), total=len(get_pipeline().scheduler.timesteps)):
130
  # Expand latents for classifier-free guidance
131
  latent_model_input = torch.cat([latents] * 2)
132
+ sigma = get_pipeline().scheduler.sigmas[i]
133
+ latent_model_input = get_pipeline().scheduler.scale_model_input(latent_model_input, t)
134
 
135
  # Predict the noise residual
136
  with torch.inference_mode():
137
+ noise_pred = get_pipeline().unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
138
 
139
  # Perform classifier-free guidance
140
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
 
149
  latents_x0 = latents - sigma * noise_pred
150
 
151
  # Decode to image space
152
+ denoised_images = get_pipeline().vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)
153
 
154
  # Calculate loss
155
  loss = vignette_loss(denoised_images) * vignette_loss_scale
 
161
  latents = latents.detach() - cond_grad * sigma**2
162
 
163
  # Step with scheduler
164
+ latents = get_pipeline().scheduler.step(noise_pred, t, latents).prev_sample
165
 
166
  return latents
167
 
 
204
  )
205
 
206
  # Convert latents to image
207
+ images = latents_to_pil(latents, get_pipeline().vae)
208
 
209
  return images[0]
210
 
 
244
  )
245
 
246
  # Convert latents to image
247
+ style_images = latents_to_pil(latents, get_pipeline().vae)
248
  images.append(style_images[0])
249
  labels.append(style)
250
 
requirements.txt CHANGED
@@ -1,11 +1,11 @@
1
  # Core dependencies with pinned versions for compatibility
2
- torch>=2.0.1
3
- torchvision>=0.15.2
4
- diffusers>=0.28.0
5
  transformers>=4.38.0
6
- accelerate>=0.28.0
7
  ftfy>=6.1.1
8
- gradio>=4.25.0
9
  numpy>=1.22.0
10
  Pillow>=10.0.0
11
  tqdm>=4.64.0
@@ -13,7 +13,7 @@ huggingface-hub>=0.22.2
13
 
14
  # HF Spaces specific
15
  gradio-client>=0.15.0
16
- spaces>=0.32.0
17
 
18
  # Optional dependencies for better performance
19
  scipy>=1.9.0
 
1
  # Core dependencies with pinned versions for compatibility
2
+ torch==2.2.1
3
+ torchvision==0.17.1
4
+ diffusers==0.28.0
5
  transformers>=4.38.0
6
+ accelerate==0.28.0
7
  ftfy>=6.1.1
8
+ gradio==4.25.0
9
  numpy>=1.22.0
10
  Pillow>=10.0.0
11
  tqdm>=4.64.0
 
13
 
14
  # HF Spaces specific
15
  gradio-client>=0.15.0
16
+ spaces==0.32.0
17
 
18
  # Optional dependencies for better performance
19
  scipy>=1.9.0