Eeman Majumder commited on
Commit
29b2cf7
·
1 Parent(s): a89aa16
Files changed (1) hide show
  1. app.py +41 -11
app.py CHANGED
@@ -1,5 +1,8 @@
1
  from click import command
2
  import streamlit as st
 
 
 
3
  import os
4
  from authtoken import auth_token
5
  from datasets import load_dataset
@@ -13,30 +16,57 @@ device = "cuda"
13
  # word_list_dataset = load_dataset("stabilityai/word-list", data_files="list.txt", use_auth_token=True)
14
  # word_list = word_list_dataset["train"]['text']
15
 
 
 
 
 
16
  is_gpu_busy = False
 
17
  def infer(prompt):
18
  global is_gpu_busy
19
  samples = 4
20
  steps = 50
21
  scale = 7.5
22
- # #When running locally you can also remove this filter
23
  # for filter in word_list:
24
  # if re.search(rf"\b{filter}\b", prompt):
25
- # raise st.Error("Unsafe content found. Please try again with different prompts.")
 
 
 
26
  images = []
27
- url = os.getenv('JAX_BACKEND_URL')
28
- payload = {'prompt': prompt}
29
- images_request = requests.post(url, json = payload)
30
- for image in images_request.json()["images"]:
31
- image_b64 = (f"data:image/jpeg;base64,{image}")
32
- images.append(image_b64)
33
- st.image(images)
34
- return images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  prompt=st.text_input('Enter your prompt')
37
  true=st.button('Generate')
38
  if true==True:
39
- infer(prompt)
 
 
40
 
41
 
42
 
 
1
  from click import command
2
  import streamlit as st
3
+ import torch
4
+ from torch import autocast
5
+ from diffusers import StableDiffusionPipeline
6
  import os
7
  from authtoken import auth_token
8
  from datasets import load_dataset
 
16
  # word_list_dataset = load_dataset("stabilityai/word-list", data_files="list.txt", use_auth_token=True)
17
  # word_list = word_list_dataset["train"]['text']
18
 
19
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=auth_token, revision="fp16", torch_dtype=torch.float16)
20
+ pipe = pipe.to(device)
21
+ torch.backends.cudnn.benchmark = True
22
+
23
  is_gpu_busy = False
24
+
25
  def infer(prompt):
26
  global is_gpu_busy
27
  samples = 4
28
  steps = 50
29
  scale = 7.5
30
+ #When running locally you can also remove this filter
31
  # for filter in word_list:
32
  # if re.search(rf"\b{filter}\b", prompt):
33
+ # raise gr.Error("Unsafe content found. Please try again with different prompts.")
34
+
35
+ generator = torch.Generator(device=device).manual_seed(seed)
36
+ #print("Is GPU busy? ", is_gpu_busy)
37
  images = []
38
+ if(not is_gpu_busy):
39
+ is_gpu_busy = True
40
+ images_list = pipe(
41
+ [prompt] * samples,
42
+ num_inference_steps=steps,
43
+ guidance_scale=scale,
44
+ generator=generator,
45
+ )
46
+ is_gpu_busy = False
47
+ safe_image = Image.open(r"unsafe.png")
48
+ for i, image in enumerate(images_list["sample"]):
49
+ if(images_list["nsfw_content_detected"][i]):
50
+ images.append(safe_image)
51
+ else:
52
+ images.append(image)
53
+
54
+ # else:
55
+ # url = os.getenv('JAX_BACKEND_URL')
56
+ # payload = {'prompt': prompt}
57
+ # images_request = requests.post(url, json = payload)
58
+ # for image in images_request.json()["images"]:
59
+ # image_b64 = (f"data:image/jpeg;base64,{image}")
60
+ # images.append(image_b64)
61
+
62
+ return images
63
 
64
  prompt=st.text_input('Enter your prompt')
65
  true=st.button('Generate')
66
  if true==True:
67
+ a=infer(prompt)
68
+ for i in a:
69
+ st.image(a)
70
 
71
 
72