aroffe commited on
Commit
86775a6
·
verified ·
1 Parent(s): f9110f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -2
app.py CHANGED
@@ -1,16 +1,24 @@
1
  import gradio as gr
2
  import os
3
 
4
- from huggingface_hub import HfApi, ModelFilter, list_liked_repos
5
  from diffusers import DiffusionPipeline
6
  import torch
7
 
 
 
 
 
 
 
 
8
  def image_mod(prompt: str, model: str, image_0: gr.Image, image_1: gr.Image) -> list[gr.Image]:
 
9
  images = [image_0, image_1]
10
  for i, diffusion_model in enumerate(model):
11
  pipeline = DiffusionPipeline.from_pretrained(
12
  pretrained_model_name_or_path=diffusion_model,
13
- torch_dtype=torch.float16,
14
  use_safetensors=True,
15
  device_map="auto"
16
  )
 
1
  import gradio as gr
2
  import os
3
 
4
+ from huggingface_hub import HfApi, ModelFilter, list_liked_repos, SpaceHardware
5
  from diffusers import DiffusionPipeline
6
  import torch
7
 
8
+ def gpu_enabled() -> bool:
9
+ # If cloned, fill in SPACE_ID with your own space
10
+ SPACE_ID = "aroffe/comparing-diffusion-models"
11
+ runtime = api.get_space_runtime(repo_id=SPACE_ID)
12
+ return runtime.hardware != (SpaceHardware.CPU_BASIC || SpaceHardware.CPU_UPGRADE)
13
+
14
+
15
  def image_mod(prompt: str, model: str, image_0: gr.Image, image_1: gr.Image) -> list[gr.Image]:
16
+ gpu_enabled()
17
  images = [image_0, image_1]
18
  for i, diffusion_model in enumerate(model):
19
  pipeline = DiffusionPipeline.from_pretrained(
20
  pretrained_model_name_or_path=diffusion_model,
21
+ torch_dtype=torch.float16 if gpu_enabled() else torch.float32,
22
  use_safetensors=True,
23
  device_map="auto"
24
  )