1inkusFace commited on
Commit
e011c4f
·
verified ·
1 Parent(s): 5c1ebd0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -1
app.py CHANGED
@@ -8,9 +8,9 @@ import paramiko
8
  from image_gen_aux import UpscaleWithModel
9
  import cyper
10
  from PIL import Image
 
11
 
12
  os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
13
- os.environ['KERAS_BACKEND'] = 'jax'
14
  import keras
15
  import keras_hub
16
  import torch
@@ -73,6 +73,8 @@ def infer_30(
73
  num_inference_steps,
74
  progress=gr.Progress(track_tqdm=True),
75
  ):
 
 
76
  seed = random.randint(0, MAX_SEED)
77
  sd_image = text_to_image.generate(
78
  prompt=prompt,
@@ -102,6 +104,8 @@ def infer_60(
102
  num_inference_steps,
103
  progress=gr.Progress(track_tqdm=True),
104
  ):
 
 
105
  seed = random.randint(0, MAX_SEED)
106
  sd_image = text_to_image.generate(
107
  prompt=prompt,
@@ -131,6 +135,8 @@ def infer_90(
131
  num_inference_steps,
132
  progress=gr.Progress(track_tqdm=True),
133
  ):
 
 
134
  seed = random.randint(0, MAX_SEED)
135
  sd_image = text_to_image.generate(
136
  prompt=prompt,
 
8
  from image_gen_aux import UpscaleWithModel
9
  import cyper
10
  from PIL import Image
11
+ os.environ['JAX_PLATFORMS'] = 'cpu'
12
 
13
  os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
 
14
  import keras
15
  import keras_hub
16
  import torch
 
73
  num_inference_steps,
74
  progress=gr.Progress(track_tqdm=True),
75
  ):
76
+ os.environ['JAX_PLATFORMS'] = 'gpu'
77
+ os.environ['KERAS_BACKEND'] = 'jax'
78
  seed = random.randint(0, MAX_SEED)
79
  sd_image = text_to_image.generate(
80
  prompt=prompt,
 
104
  num_inference_steps,
105
  progress=gr.Progress(track_tqdm=True),
106
  ):
107
+ os.environ['JAX_PLATFORMS'] = 'gpu'
108
+ os.environ['KERAS_BACKEND'] = 'jax'
109
  seed = random.randint(0, MAX_SEED)
110
  sd_image = text_to_image.generate(
111
  prompt=prompt,
 
135
  num_inference_steps,
136
  progress=gr.Progress(track_tqdm=True),
137
  ):
138
+ os.environ['JAX_PLATFORMS'] = 'gpu'
139
+ os.environ['KERAS_BACKEND'] = 'jax'
140
  seed = random.randint(0, MAX_SEED)
141
  sd_image = text_to_image.generate(
142
  prompt=prompt,