Asko Relas commited on
Commit
4e3d59f
Β·
1 Parent(s): 1fcbe69

model switching

Browse files
Files changed (1) hide show
  1. app.py +29 -9
app.py CHANGED
@@ -33,28 +33,48 @@ def callback_cfg_cutoff(pipeline, step_index, timestep, callback_kwargs):
33
 
34
 
35
  MODELS = {
 
36
  "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
 
 
 
37
  }
38
 
 
 
39
  controlnet_model = ControlNetUnionModel.from_pretrained(
40
  "OzzyGT/controlnet-union-promax-sdxl-1.0", variant="fp16", torch_dtype=torch.float16
41
  )
42
  controlnet_model.to(device="cuda", dtype=torch.float16)
43
  vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to("cuda")
44
 
45
- pipe = DiffusionPipeline.from_pretrained(
46
- "SG161222/RealVisXL_V5.0_Lightning",
47
- torch_dtype=torch.float16,
48
- vae=vae,
49
- controlnet=controlnet_model,
50
- custom_pipeline="OzzyGT/custom_sdxl_cnet_union",
51
- ).to("cuda")
52
 
53
- pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
 
56
  @spaces.GPU(duration=24)
57
  def fill_image(prompt, negative_prompt, image, model_selection, paste_back):
 
 
 
 
 
 
58
  (
59
  prompt_embeds,
60
  negative_prompt_embeds,
@@ -146,7 +166,7 @@ with gr.Blocks() as demo:
146
 
147
  use_as_input_button = gr.Button("Use as Input Image", visible=False)
148
 
149
- model_selection = gr.Dropdown(choices=list(MODELS.keys()), value="RealVisXL V5.0 Lightning", label="Model")
150
 
151
  def use_output_as_input(output_image):
152
  return gr.update(value=output_image[1])
 
33
 
34
 
35
  MODELS = {
36
+ "DreamShaper XL Turbo": "Lykon/dreamshaper-xl-v2-turbo",
37
  "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
38
+ "Playground v2.5": "playgroundai/playground-v2.5-1024px-aesthetic",
39
+ "Juggernaut XL Lightning": "RunDiffusion/Juggernaut-XL-Lightning",
40
+ "LEOSAM HelloWorld XL": "Leosam/HelloWorld-SDXL-Lightning",
41
  }
42
 
43
+ DEFAULT_MODEL = "DreamShaper XL Turbo"
44
+
45
  controlnet_model = ControlNetUnionModel.from_pretrained(
46
  "OzzyGT/controlnet-union-promax-sdxl-1.0", variant="fp16", torch_dtype=torch.float16
47
  )
48
  controlnet_model.to(device="cuda", dtype=torch.float16)
49
  vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to("cuda")
50
 
 
 
 
 
 
 
 
51
 
52
+ def load_pipeline(model_name):
53
+ """Load a pipeline for the given model name."""
54
+ model_id = MODELS[model_name]
55
+ pipeline = DiffusionPipeline.from_pretrained(
56
+ model_id,
57
+ torch_dtype=torch.float16,
58
+ vae=vae,
59
+ controlnet=controlnet_model,
60
+ custom_pipeline="OzzyGT/custom_sdxl_cnet_union",
61
+ ).to("cuda")
62
+ pipeline.scheduler = TCDScheduler.from_config(pipeline.scheduler.config)
63
+ return pipeline
64
+
65
+
66
+ current_model = DEFAULT_MODEL
67
+ pipe = load_pipeline(current_model)
68
 
69
 
70
  @spaces.GPU(duration=24)
71
  def fill_image(prompt, negative_prompt, image, model_selection, paste_back):
72
+ global pipe, current_model
73
+
74
+ if model_selection != current_model:
75
+ pipe = load_pipeline(model_selection)
76
+ current_model = model_selection
77
+
78
  (
79
  prompt_embeds,
80
  negative_prompt_embeds,
 
166
 
167
  use_as_input_button = gr.Button("Use as Input Image", visible=False)
168
 
169
+ model_selection = gr.Dropdown(choices=list(MODELS.keys()), value=DEFAULT_MODEL, label="Model")
170
 
171
  def use_output_as_input(output_image):
172
  return gr.update(value=output_image[1])