janannfndnd commited on
Commit
5befe50
·
verified ·
1 Parent(s): dd5caf9

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +21 -0
main.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import FluxPipeline
3
+
4
+ pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) # can replace schnell with dev
5
+ # to run on low vram GPUs (i.e. between 4 and 32 GB VRAM)
6
+ pipe.enable_sequential_cpu_offload()
7
+ pipe.vae.enable_slicing()
8
+ pipe.vae.enable_tiling()
9
+
10
+ pipe.to(torch.float16) # casting here instead of in the pipeline constructor because doing so in the constructor loads all models into CPU memory at once
11
+
12
+ prompt = "A cat holding a sign that says hello world"
13
+ out = pipe(
14
+ prompt=prompt,
15
+ guidance_scale=4,
16
+ height=768,
17
+ width=1024,
18
+ num_inference_steps=4,
19
+ max_sequence_length=256,
20
+ ).images[0]
21
+ out.save("image.png")