multimodalart HF Staff commited on
Commit
60f16ec
·
verified ·
1 Parent(s): 5246183

Rebuild guidance_schedule to match step count

Browse files
Files changed (1) hide show
  1. app.py +6 -1
app.py CHANGED
@@ -37,16 +37,21 @@ def generate(
37
  seed = random.randint(0, MAX_SEED)
38
  generator = torch.Generator(device="cuda").manual_seed(int(seed))
39
 
 
40
  kwargs = dict(
41
  prompt=prompt,
42
  width=int(width),
43
  height=int(height),
44
- num_inference_steps=int(num_inference_steps),
45
  generator=generator,
46
  )
47
  if guidance_scale > 0:
48
  kwargs["guidance_scale"] = float(guidance_scale)
49
  kwargs["guidance_schedule"] = None
 
 
 
 
50
 
51
  image = pipe(**kwargs).images[0]
52
  return image, seed
 
37
  seed = random.randint(0, MAX_SEED)
38
  generator = torch.Generator(device="cuda").manual_seed(int(seed))
39
 
40
+ steps = int(num_inference_steps)
41
  kwargs = dict(
42
  prompt=prompt,
43
  width=int(width),
44
  height=int(height),
45
+ num_inference_steps=steps,
46
  generator=generator,
47
  )
48
  if guidance_scale > 0:
49
  kwargs["guidance_scale"] = float(guidance_scale)
50
  kwargs["guidance_schedule"] = None
51
+ else:
52
+ # PR default is len 48 (7.0 x45 + 3.0 x3); rebuild it for any step count.
53
+ tail = min(3, max(0, steps - 1))
54
+ kwargs["guidance_schedule"] = (7.0,) * (steps - tail) + (3.0,) * tail
55
 
56
  image = pipe(**kwargs).images[0]
57
  return image, seed