Commit ·
db58866
1
Parent(s): c34bc8e
Update app.py
Browse files
app.py
CHANGED
|
@@ -21,7 +21,7 @@ word_list = word_list_dataset["train"]['text']
|
|
| 21 |
|
| 22 |
is_gpu_busy = False
|
| 23 |
def infer(prompt):
|
| 24 |
-
|
| 25 |
samples = 4
|
| 26 |
steps = 50
|
| 27 |
scale = 7.5
|
|
@@ -30,31 +30,31 @@ def infer(prompt):
|
|
| 30 |
if re.search(rf"\b{filter}\b", prompt):
|
| 31 |
raise gr.Error("Unsafe content found. Please try again with different prompts.")
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
images = []
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
#generator=generator,
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
|
| 59 |
return images
|
| 60 |
|
|
|
|
| 21 |
|
| 22 |
is_gpu_busy = False
|
| 23 |
def infer(prompt):
|
| 24 |
+
global is_gpu_busy
|
| 25 |
samples = 4
|
| 26 |
steps = 50
|
| 27 |
scale = 7.5
|
|
|
|
| 30 |
if re.search(rf"\b{filter}\b", prompt):
|
| 31 |
raise gr.Error("Unsafe content found. Please try again with different prompts.")
|
| 32 |
|
| 33 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
| 34 |
+
print("Is GPU busy? ", is_gpu_busy)
|
| 35 |
images = []
|
| 36 |
+
if(not is_gpu_busy):
|
| 37 |
+
is_gpu_busy = True
|
| 38 |
+
images_list = pipe(
|
| 39 |
+
[prompt] * samples,
|
| 40 |
+
num_inference_steps=steps,
|
| 41 |
+
guidance_scale=scale,
|
| 42 |
#generator=generator,
|
| 43 |
+
)
|
| 44 |
+
is_gpu_busy = False
|
| 45 |
+
safe_image = Image.open(r"unsafe.png")
|
| 46 |
+
for i, image in enumerate(images_list["sample"]):
|
| 47 |
+
if(images_list["nsfw_content_detected"][i]):
|
| 48 |
+
images.append(safe_image)
|
| 49 |
+
else:
|
| 50 |
+
images.append(image)
|
| 51 |
+
else:
|
| 52 |
+
url = os.getenv('JAX_BACKEND_URL')
|
| 53 |
+
payload = {'prompt': prompt}
|
| 54 |
+
images_request = requests.post(url, json = payload)
|
| 55 |
+
for image in images_request.json()["images"]:
|
| 56 |
+
image_b64 = (f"data:image/jpeg;base64,{image}")
|
| 57 |
+
images.append(image_b64)
|
| 58 |
|
| 59 |
return images
|
| 60 |
|