1inkusFace commited on
Commit
503fc45
·
verified ·
1 Parent(s): 2e8fcf9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -6
app.py CHANGED
@@ -9,7 +9,7 @@ from image_gen_aux import UpscaleWithModel
9
  import cyper
10
  from PIL import Image
11
  os.environ['JAX_PLATFORMS'] = 'cpu'
12
-
13
  os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
14
  import keras
15
  import keras_hub
@@ -28,14 +28,15 @@ torch.backends.cudnn.benchmark = False
28
  torch.set_float32_matmul_precision("highest")
29
 
30
  upscaler_2 = UpscaleWithModel.from_pretrained("Kim2091/ClearRealityV1").to(device)
 
31
 
32
  def load_model():
 
33
  text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
34
  "stable_diffusion_3_medium", width=768, height=768, dtype="bfloat16"
35
  )
36
  return text_to_image
37
 
38
-
39
  code = r'''
40
  import paramiko
41
  import os
@@ -57,7 +58,6 @@ def upload_to_ftp(filename):
57
  print(f"FTP upload error: {e}")
58
  '''
59
 
60
-
61
  pyx = cyper.inline(code, fast_indexing=True, directives=dict(boundscheck=False, wraparound=False, language_level=3))
62
 
63
  MAX_SEED = np.iinfo(np.int32).max
@@ -72,7 +72,9 @@ def infer_30(
72
  num_inference_steps,
73
  progress=gr.Progress(track_tqdm=True),
74
  ):
75
- text_to_image = load_model()
 
 
76
  os.environ['JAX_PLATFORMS'] = 'gpu'
77
  os.environ['KERAS_BACKEND'] = 'jax'
78
  seed = random.randint(0, MAX_SEED)
@@ -104,7 +106,9 @@ def infer_60(
104
  num_inference_steps,
105
  progress=gr.Progress(track_tqdm=True),
106
  ):
107
- text_to_image = load_model()
 
 
108
  os.environ['JAX_PLATFORMS'] = 'gpu'
109
  os.environ['KERAS_BACKEND'] = 'jax'
110
  seed = random.randint(0, MAX_SEED)
@@ -136,7 +140,9 @@ def infer_90(
136
  num_inference_steps,
137
  progress=gr.Progress(track_tqdm=True),
138
  ):
139
- text_to_image = load_model()
 
 
140
  os.environ['JAX_PLATFORMS'] = 'gpu'
141
  os.environ['KERAS_BACKEND'] = 'jax'
142
  seed = random.randint(0, MAX_SEED)
@@ -178,6 +184,7 @@ with gr.Blocks(theme=gr.themes.Origin(),css=css) as demo:
178
  placeholder="Enter your prompt",
179
  container=False,
180
  )
 
181
  run_button_30 = gr.Button("Run 30", scale=0, variant="primary")
182
  run_button_60 = gr.Button("Run 60", scale=0, variant="primary")
183
  run_button_90 = gr.Button("Run 90", scale=0, variant="primary")
@@ -206,6 +213,13 @@ with gr.Blocks(theme=gr.themes.Origin(),css=css) as demo:
206
  value=50,
207
  )
208
 
 
 
 
 
 
 
 
209
  gr.on(
210
  triggers=[run_button_30.click, prompt.submit],
211
  fn=infer_30,
 
9
  import cyper
10
  from PIL import Image
11
  os.environ['JAX_PLATFORMS'] = 'cpu'
12
+ import random
13
  os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
14
  import keras
15
  import keras_hub
 
28
  torch.set_float32_matmul_precision("highest")
29
 
30
  upscaler_2 = UpscaleWithModel.from_pretrained("Kim2091/ClearRealityV1").to(device)
31
+ text_to_image = None
32
 
33
  def load_model():
34
+ global text_to_image
35
  text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
36
  "stable_diffusion_3_medium", width=768, height=768, dtype="bfloat16"
37
  )
38
  return text_to_image
39
 
 
40
  code = r'''
41
  import paramiko
42
  import os
 
58
  print(f"FTP upload error: {e}")
59
  '''
60
 
 
61
  pyx = cyper.inline(code, fast_indexing=True, directives=dict(boundscheck=False, wraparound=False, language_level=3))
62
 
63
  MAX_SEED = np.iinfo(np.int32).max
 
72
  num_inference_steps,
73
  progress=gr.Progress(track_tqdm=True),
74
  ):
75
+ global text_to_image
76
+ if text_to_image is None:
77
+ text_to_image = load_model()
78
  os.environ['JAX_PLATFORMS'] = 'gpu'
79
  os.environ['KERAS_BACKEND'] = 'jax'
80
  seed = random.randint(0, MAX_SEED)
 
106
  num_inference_steps,
107
  progress=gr.Progress(track_tqdm=True),
108
  ):
109
+ global text_to_image
110
+ if text_to_image is None:
111
+ text_to_image = load_model()
112
  os.environ['JAX_PLATFORMS'] = 'gpu'
113
  os.environ['KERAS_BACKEND'] = 'jax'
114
  seed = random.randint(0, MAX_SEED)
 
140
  num_inference_steps,
141
  progress=gr.Progress(track_tqdm=True),
142
  ):
143
+ global text_to_image
144
+ if text_to_image is None:
145
+ text_to_image = load_model()
146
  os.environ['JAX_PLATFORMS'] = 'gpu'
147
  os.environ['KERAS_BACKEND'] = 'jax'
148
  seed = random.randint(0, MAX_SEED)
 
184
  placeholder="Enter your prompt",
185
  container=False,
186
  )
187
+ load_button = gr.Button("Load model", scale=0, variant="primary")
188
  run_button_30 = gr.Button("Run 30", scale=0, variant="primary")
189
  run_button_60 = gr.Button("Run 60", scale=0, variant="primary")
190
  run_button_90 = gr.Button("Run 90", scale=0, variant="primary")
 
213
  value=50,
214
  )
215
 
216
+ gr.on(
217
+ triggers=[load_button.click],
218
+ fn=load_model,
219
+ inputs=[],
220
+ outputs=[],
221
+ )
222
+
223
  gr.on(
224
  triggers=[run_button_30.click, prompt.submit],
225
  fn=infer_30,