JUNGU commited on
Commit
5ee3e5b
·
1 Parent(s): 18fd32d
Files changed (1) hide show
  1. app.py +38 -0
app.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ !pip install --upgrade keras-cv
2
+ !pip install translate
3
+
4
+ import gradio as gr
5
+ import time
6
+ import keras_cv
7
+ from tensorflow import keras
8
+ import matplotlib.pyplot as plt
9
+ from translate import Translator
10
+
11
+ keras.mixed_precision.set_global_policy("mixed_float16")
12
+ model = keras_cv.models.StableDiffusion(img_width=512, img_height=512, jit_compile=True)
13
+
14
+ def plot_images(images):
15
+ plt.figure(figsize=(10, 10))
16
+ for i in range(len(images)):
17
+ ax = plt.subplot(1, len(images), i + 1)
18
+ plt.imshow(images[i])
19
+ plt.axis("off")
20
+ plt.tight_layout()
21
+
22
+ images = model.text_to_image("photograph of an astronaut riding a horse", batch_size=3)
23
+ translator = Translator(from_lang="ko", to_lang="en")
24
+
25
+ def generate_images(text, n=3):
26
+ print(text)
27
+ translation = translator.translate(text)
28
+ print(translation)
29
+ images = model.text_to_image(translation, batch_size=n)
30
+ return images
31
+
32
+ def inference(text):
33
+ image = generate_images(text, 1).squeeze()
34
+ return image
35
+
36
+ demo = gr.Interface(fn=inference, inputs="text", outputs="image")
37
+
38
+ demo.launch(share=True)