Mohansai2004 committed on
Commit
7d035ab
·
1 Parent(s): 104c977

feat: optimize image generator for CPU performance

Browse files

- Switch to RunwayML SD v1.5 model
- Optimize CPU settings and memory usage
- Simplify user interface
- Reduce generation time to 1-2 minutes
- Fix CPU offloading issues

Files changed (1) hide show
  1. app.py +23 -15
app.py CHANGED
@@ -5,6 +5,9 @@ import gc
5
 
6
  @st.cache_resource
7
  def load_model():
 
 
 
8
  # Use RunwayML's smaller model
9
  model_id = "runwayml/stable-diffusion-v1-5"
10
  pipe = StableDiffusionPipeline.from_pretrained(
@@ -13,13 +16,16 @@ def load_model():
13
  safety_checker=None,
14
  requires_safety_checker=False,
15
  use_safetensors=True
16
- ).to("cpu")
17
 
18
- # Optimized settings
 
19
  pipe.enable_attention_slicing(slice_size=1)
20
- pipe.enable_vae_slicing()
21
- pipe.enable_sequential_cpu_offload()
22
- torch.set_num_threads(6) # Optimal thread count for most systems
 
 
23
 
24
  return pipe
25
 
@@ -36,17 +42,19 @@ prompt = st.text_area(
36
  )
37
 
38
  if st.button("Generate Image"):
39
- with st.spinner("Creating your image... (about 1 minute)"):
40
  try:
41
- # Pre-optimized settings
42
- image = pipeline(
43
- prompt=prompt,
44
- negative_prompt="ugly, blurry, low quality, distorted, disfigured",
45
- num_inference_steps=20,
46
- guidance_scale=7.0,
47
- width=320,
48
- height=320,
49
- ).images[0]
 
 
50
 
51
  # Display and download
52
  st.image(image, caption=prompt, use_column_width=True)
 
5
 
6
  @st.cache_resource
7
  def load_model():
8
+ # Explicitly disable CUDA
9
+ torch.cuda.is_available = lambda : False
10
+
11
  # Use RunwayML's smaller model
12
  model_id = "runwayml/stable-diffusion-v1-5"
13
  pipe = StableDiffusionPipeline.from_pretrained(
 
16
  safety_checker=None,
17
  requires_safety_checker=False,
18
  use_safetensors=True
19
+ )
20
 
21
+ # Basic CPU optimizations
22
+ pipe = pipe.to("cpu")
23
  pipe.enable_attention_slicing(slice_size=1)
24
+ pipe.enable_vae_tiling()
25
+ torch.set_num_threads(4)
26
+
27
+ # Clear memory
28
+ gc.collect()
29
 
30
  return pipe
31
 
 
42
  )
43
 
44
  if st.button("Generate Image"):
45
+ with st.spinner("Creating your image... (1-2 minutes)"):
46
  try:
47
+ gc.collect()
48
+
49
+ with torch.inference_mode():
50
+ image = pipeline(
51
+ prompt=prompt,
52
+ negative_prompt="ugly, blurry, low quality, distorted, disfigured",
53
+ num_inference_steps=15, # Reduced for speed
54
+ guidance_scale=7.0,
55
+ width=256, # Reduced for memory
56
+ height=256, # Reduced for memory
57
+ ).images[0]
58
 
59
  # Display and download
60
  st.image(image, caption=prompt, use_column_width=True)