sidharthg committed on
Commit
e4a95ff
·
verified ·
1 Parent(s): 76498b5

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -13
app.py CHANGED
@@ -185,26 +185,37 @@ def initialize_pipeline():
185
  return pipe
186
 
187
def load_style(style_name):
    """Load a textual inversion style into the global pipeline.

    Idempotent: if the style's token is already registered in the
    pipeline tokenizer, the load is skipped, so repeated calls (or
    switching back to a previously used style) do not raise.

    Args:
        style_name: Key into the module-level STYLES dict; the entry is
            expected to provide "repo" (embedding source) and "token"
            (the placeholder token) — TODO confirm against STYLES.
    """
    global current_style, pipe

    if pipe is None:
        initialize_pipeline()

    style_config = STYLES[style_name]
    token = style_config["token"]

    # load_textual_inversion raises ValueError when the token is already
    # in the tokenizer vocabulary, so guard on the vocabulary itself
    # rather than on `style_name != current_style` — the latter re-loads
    # (and crashes) when switching back to an earlier style.
    if token not in pipe.tokenizer.get_vocab():
        print(f"Loading style: {style_name}")

        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load the inversion
        pipe.load_textual_inversion(style_config["repo"])

        # Crucial: move back to device as load_textual_inversion
        # can sometimes mess with device placement of embeddings
        pipe.to(device)

        print(f"Style loaded and verified on {device}")

    current_style = style_name
 
 
 
 
 
 
 
 
208
 
209
  def generate_image(prompt, style_name, seed, num_inference_steps, guidance_scale, contrast_scale, complexity_scale, vibrancy_scale, num_images=3):
210
  """Generate multiple images with the selected style"""
 
185
  return pipe
186
 
187
def load_style(style_name):
    """Load a textual inversion style idempotently.

    Skips the load when the style's token is already present in the
    pipeline tokenizer, and tolerates a concurrent/duplicate load by
    swallowing only the "already in tokenizer vocabulary" error.

    Args:
        style_name: Key into the module-level STYLES dict; the entry
            provides "repo" (embedding source) and "token" (placeholder
            token registered in the tokenizer).
    """
    global current_style, pipe

    if pipe is None:
        initialize_pipeline()

    style_config = STYLES[style_name]
    token = style_config["token"]

    # Check if the token is already in the tokenizer to avoid ValueError
    if token not in pipe.tokenizer.get_vocab():
        print(f"Loading style: {style_name} with token {token}")
        device = "cuda" if torch.cuda.is_available() else "cpu"

        try:
            # Load the inversion
            pipe.load_textual_inversion(style_config["repo"])
            # Crucial: move back to device as load_textual_inversion
            # can sometimes mess with device placement of embeddings
            pipe.to(device)
            print(f"Style {style_name} loaded successfully")
        except Exception as e:
            print(f"Error loading style {style_name}: {e}")
            if "already in tokenizer vocabulary" in str(e):
                print(f"Token {token} already exists, skipping load.")
            else:
                # Bare raise preserves the original traceback;
                # `raise e` would re-raise from this line instead.
                raise
    else:
        print(f"Style {style_name} (token {token}) already in tokenizer, skipping load.")

    current_style = style_name
219
 
220
  def generate_image(prompt, style_name, seed, num_inference_steps, guidance_scale, contrast_scale, complexity_scale, vibrancy_scale, num_images=3):
221
  """Generate multiple images with the selected style"""