Update app.py
Browse files
app.py
CHANGED
|
@@ -8,9 +8,9 @@ import paramiko
|
|
| 8 |
from image_gen_aux import UpscaleWithModel
|
| 9 |
import cyper
|
| 10 |
from PIL import Image
|
|
|
|
| 11 |
|
| 12 |
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
|
| 13 |
-
os.environ['KERAS_BACKEND'] = 'jax'
|
| 14 |
import keras
|
| 15 |
import keras_hub
|
| 16 |
import torch
|
|
@@ -73,6 +73,8 @@ def infer_30(
|
|
| 73 |
num_inference_steps,
|
| 74 |
progress=gr.Progress(track_tqdm=True),
|
| 75 |
):
|
|
|
|
|
|
|
| 76 |
seed = random.randint(0, MAX_SEED)
|
| 77 |
sd_image = text_to_image.generate(
|
| 78 |
prompt=prompt,
|
|
@@ -102,6 +104,8 @@ def infer_60(
|
|
| 102 |
num_inference_steps,
|
| 103 |
progress=gr.Progress(track_tqdm=True),
|
| 104 |
):
|
|
|
|
|
|
|
| 105 |
seed = random.randint(0, MAX_SEED)
|
| 106 |
sd_image = text_to_image.generate(
|
| 107 |
prompt=prompt,
|
|
@@ -131,6 +135,8 @@ def infer_90(
|
|
| 131 |
num_inference_steps,
|
| 132 |
progress=gr.Progress(track_tqdm=True),
|
| 133 |
):
|
|
|
|
|
|
|
| 134 |
seed = random.randint(0, MAX_SEED)
|
| 135 |
sd_image = text_to_image.generate(
|
| 136 |
prompt=prompt,
|
|
|
|
| 8 |
from image_gen_aux import UpscaleWithModel
|
| 9 |
import cyper
|
| 10 |
from PIL import Image
|
| 11 |
+
os.environ['JAX_PLATFORMS'] = 'cpu'
|
| 12 |
|
| 13 |
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
|
|
|
|
| 14 |
import keras
|
| 15 |
import keras_hub
|
| 16 |
import torch
|
|
|
|
| 73 |
num_inference_steps,
|
| 74 |
progress=gr.Progress(track_tqdm=True),
|
| 75 |
):
|
| 76 |
+
os.environ['JAX_PLATFORMS'] = 'gpu'
|
| 77 |
+
os.environ['KERAS_BACKEND'] = 'jax'
|
| 78 |
seed = random.randint(0, MAX_SEED)
|
| 79 |
sd_image = text_to_image.generate(
|
| 80 |
prompt=prompt,
|
|
|
|
| 104 |
num_inference_steps,
|
| 105 |
progress=gr.Progress(track_tqdm=True),
|
| 106 |
):
|
| 107 |
+
os.environ['JAX_PLATFORMS'] = 'gpu'
|
| 108 |
+
os.environ['KERAS_BACKEND'] = 'jax'
|
| 109 |
seed = random.randint(0, MAX_SEED)
|
| 110 |
sd_image = text_to_image.generate(
|
| 111 |
prompt=prompt,
|
|
|
|
| 135 |
num_inference_steps,
|
| 136 |
progress=gr.Progress(track_tqdm=True),
|
| 137 |
):
|
| 138 |
+
os.environ['JAX_PLATFORMS'] = 'gpu'
|
| 139 |
+
os.environ['KERAS_BACKEND'] = 'jax'
|
| 140 |
seed = random.randint(0, MAX_SEED)
|
| 141 |
sd_image = text_to_image.generate(
|
| 142 |
prompt=prompt,
|