YGu1998 commited on
Commit
2cc7f2c
·
verified ·
1 Parent(s): 503ef67

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -4
app.py CHANGED
@@ -33,9 +33,13 @@ MODEL_OPTIONS = {
33
  }
34
 
35
 
36
- def load_model(model_choice):
 
 
 
37
  model_repo_id = MODEL_OPTIONS[model_choice]
38
  time_scale = 1000.0
 
39
  if "Sana" in model_choice:
40
  pipe = SiDSanaPipeline.from_pretrained(model_repo_id, torch_dtype=torch.bfloat16)
41
  if "Sprint" in model_choice:
@@ -46,6 +50,10 @@ def load_model(model_choice):
46
  pipe = SiDFluxPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
47
  else:
48
  raise ValueError(f"Unknown model type for: {model_choice}")
 
 
 
 
49
  pipe = pipe.to(device)
50
  return pipe, time_scale
51
 
@@ -69,9 +77,11 @@ def infer(
69
  seed = random.randint(0, MAX_SEED)
70
 
71
  generator = torch.Generator().manual_seed(seed)
72
-
73
  pipe, time_scale = load_model(model_choice)
74
 
 
 
75
  image = pipe(
76
  prompt=prompt,
77
  guidance_scale=1,
@@ -81,7 +91,10 @@ def infer(
81
  generator=generator,
82
  time_scale=time_scale,
83
  ).images[0]
84
-
 
 
 
85
  pipe.maybe_free_model_hooks()
86
  del pipe
87
  torch.cuda.empty_cache()
@@ -168,7 +181,7 @@ with gr.Blocks(css=css) as demo:
168
  minimum=1,
169
  maximum=4,
170
  step=1,
171
- value=2, # Replace with defaults that work for your model
172
  )
173
 
174
  gr.Examples(examples=examples, inputs=[prompt])
 
33
  }
34
 
35
 
36
+ def load_model(model_choice, progress=None):
37
+ if progress is not None:
38
+ progress(0.1, desc=f"Loading {model_choice}...")
39
+
40
  model_repo_id = MODEL_OPTIONS[model_choice]
41
  time_scale = 1000.0
42
+
43
  if "Sana" in model_choice:
44
  pipe = SiDSanaPipeline.from_pretrained(model_repo_id, torch_dtype=torch.bfloat16)
45
  if "Sprint" in model_choice:
 
50
  pipe = SiDFluxPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
51
  else:
52
  raise ValueError(f"Unknown model type for: {model_choice}")
53
+
54
+ if progress is not None:
55
+ progress(0.5, desc=f"{model_choice} loaded")
56
+
57
  pipe = pipe.to(device)
58
  return pipe, time_scale
59
 
 
77
  seed = random.randint(0, MAX_SEED)
78
 
79
  generator = torch.Generator().manual_seed(seed)
80
+ progress(0.0, desc="Preparing model...")
81
  pipe, time_scale = load_model(model_choice)
82
 
83
+
84
+ progress(0.7, desc="Running inference...")
85
  image = pipe(
86
  prompt=prompt,
87
  guidance_scale=1,
 
91
  generator=generator,
92
  time_scale=time_scale,
93
  ).images[0]
94
+
95
+
96
+ progress(1.0, desc="Done")
97
+
98
  pipe.maybe_free_model_hooks()
99
  del pipe
100
  torch.cuda.empty_cache()
 
181
  minimum=1,
182
  maximum=4,
183
  step=1,
184
+ value=4, # Replace with defaults that work for your model
185
  )
186
 
187
  gr.Examples(examples=examples, inputs=[prompt])