Yinhong Liu commited on
Commit
56a4b2a
·
1 Parent(s): 71383c2

model selection

Browse files
Files changed (1) hide show
  1. app.py +13 -5
app.py CHANGED
@@ -3,7 +3,7 @@ import numpy as np
3
  import random
4
 
5
  # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
7
  import torch
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -14,14 +14,22 @@ else:
14
  torch_dtype = torch.float32
15
 
16
  MODEL_OPTIONS = {
17
- "Sana": "sana-model-repo-id",
18
- "SD3": "sd3-model-repo-id",
19
- "Flux": "flux-model-repo-id"
20
  }
21
 
22
  def load_model(model_choice):
23
  model_repo_id = MODEL_OPTIONS[model_choice]
24
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
 
 
 
 
 
 
 
 
25
  pipe = pipe.to(device)
26
  return pipe
27
 
 
3
  import random
4
 
5
  # import spaces #[uncomment to use ZeroGPU]
6
+ from diffusers import SanaPipeline, StableDiffusion3Pipeline, FluxPipeline
7
  import torch
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
14
  torch_dtype = torch.float32
15
 
16
  MODEL_OPTIONS = {
17
+ "Sana": "Efficient-Large-Model/Sana_1600M_1024px",
18
+ "SD3": "stabilityai/stable-diffusion-3-medium",
19
+ "Flux": "black-forest-labs/FLUX.1-dev"
20
  }
21
 
22
  def load_model(model_choice):
23
  model_repo_id = MODEL_OPTIONS[model_choice]
24
+ # pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
25
+ if model_choice == 'Sana':
26
+ pipe = SanaPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
27
+ elif model_choice == 'SD3':
28
+ pipe = StableDiffusion3Pipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
29
+ else:
30
+ pipe = FluxPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
31
+
32
+
33
  pipe = pipe.to(device)
34
  return pipe
35