Oranblock commited on
Commit
5024e57
·
verified ·
1 Parent(s): 20b41d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -12
app.py CHANGED
@@ -12,8 +12,9 @@ import torch
12
  from diffusers import DiffusionPipeline
13
  from typing import Tuple
14
 
15
- # Check if GPU is available; fallback to CPU if needed
16
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
17
 
18
  # Setup rules for bad words (ensure the prompts are kid-friendly)
19
  bad_words = json.loads(os.getenv('BAD_WORDS', '["violence", "blood", "scary", "death", "ghost"]'))
@@ -66,18 +67,17 @@ def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str
66
  DESCRIPTION = """## Children's Sticker Generator
67
 
68
  Generate fun and playful stickers for children using AI.
69
- """
70
 
71
- if not torch.cuda.is_available():
72
- DESCRIPTION += "\n<p>⚠️Running on CPU, This may be slower.</p>"
73
 
74
  MAX_SEED = np.iinfo(np.int32).max
75
- CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"
76
 
77
  # Initialize the DiffusionPipeline
78
  pipe = DiffusionPipeline.from_pretrained(
79
- "SG161222/RealVisXL_V3.0_Turbo",
80
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
81
  use_safetensors=True,
82
  ).to(device)
83
 
@@ -142,7 +142,7 @@ def generate(
142
  # Apply style
143
  prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
144
  seed = int(randomize_seed_fn(seed, randomize_seed))
145
- generator = torch.Generator(device=device).manual_seed(seed)
146
 
147
  width, height = size_map.get(size, (512, 512))
148
 
@@ -157,7 +157,7 @@ def generate(
157
  "guidance_scale": guidance_scale,
158
  "num_inference_steps": 20,
159
  "generator": generator,
160
- "num_images_per_prompt": 2,
161
  "output_type": "pil",
162
  }
163
 
@@ -196,7 +196,7 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
196
  container=False,
197
  )
198
  run_button = gr.Button("Run")
199
- result = gr.Gallery(label="Generated Stickers", columns=2, preview=True)
200
  error_output = gr.Textbox(label="Error", visible=False)
201
  with gr.Accordion("Advanced options", open=False):
202
  use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True, visible=True)
@@ -236,7 +236,7 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
236
  minimum=0.1,
237
  maximum=20.0,
238
  step=0.1,
239
- value=15.7,
240
  )
241
 
242
  gr.Examples(
 
12
  from diffusers import DiffusionPipeline
13
  from typing import Tuple
14
 
15
+ # Force CPU usage
16
+ device = torch.device("cpu")
17
+ torch.cuda.is_available = lambda: False
18
 
19
  # Setup rules for bad words (ensure the prompts are kid-friendly)
20
  bad_words = json.loads(os.getenv('BAD_WORDS', '["violence", "blood", "scary", "death", "ghost"]'))
 
67
  DESCRIPTION = """## Children's Sticker Generator
68
 
69
  Generate fun and playful stickers for children using AI.
 
70
 
71
+ ⚠️ Running on CPU. This may be slower.
72
+ """
73
 
74
  MAX_SEED = np.iinfo(np.int32).max
75
+ CACHE_EXAMPLES = False
76
 
77
  # Initialize the DiffusionPipeline
78
  pipe = DiffusionPipeline.from_pretrained(
79
+ "runwayml/stable-diffusion-v1-5", # Using a smaller model for CPU
80
+ torch_dtype=torch.float32,
81
  use_safetensors=True,
82
  ).to(device)
83
 
 
142
  # Apply style
143
  prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
144
  seed = int(randomize_seed_fn(seed, randomize_seed))
145
+ generator = torch.manual_seed(seed)
146
 
147
  width, height = size_map.get(size, (512, 512))
148
 
 
157
  "guidance_scale": guidance_scale,
158
  "num_inference_steps": 20,
159
  "generator": generator,
160
+ "num_images_per_prompt": 1, # Reduced to 1 for CPU
161
  "output_type": "pil",
162
  }
163
 
 
196
  container=False,
197
  )
198
  run_button = gr.Button("Run")
199
+ result = gr.Gallery(label="Generated Stickers", columns=1, preview=True)
200
  error_output = gr.Textbox(label="Error", visible=False)
201
  with gr.Accordion("Advanced options", open=False):
202
  use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True, visible=True)
 
236
  minimum=0.1,
237
  maximum=20.0,
238
  step=0.1,
239
+ value=7.5,
240
  )
241
 
242
  gr.Examples(