Eeman Majumder commited on
Commit
fae8c2a
·
1 Parent(s): e3ebf48
Files changed (1) hide show
  1. app.py +79 -19
app.py CHANGED
@@ -2,30 +2,90 @@ from click import command
2
  import streamlit as st
3
  import os
4
  from authtoken import auth_token
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
 
7
- from diffusers import StableDiffusionPipeline
8
- st.session_state.ac=True
9
- if st.session_state.ac==True:
10
- os.system('pip3 install torch torchvision torchaudio')
11
- import torch
12
- from torch import autocast
13
- st.session_state.ac=False
14
 
15
- st.title('Stable Diffusion (AI that can imagine)')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
 
18
 
19
- prompt=st.text_input('Enter the prompt')
20
- modelid = "CompVis/stable-diffusion-v1-4"
21
- device = "cuda"
22
- pipe = StableDiffusionPipeline.from_pretrained(modelid, revision="fp16", torch_dtype=torch.float16, use_auth_token=auth_token)
23
- pipe.to(device)
24
- def generate():
25
- with autocast(device):
26
- image = pipe(prompt.get(), guidance_scale=8.5)["sample"][0]
27
 
28
- image.save(prompt+'.png')
29
- img = st.image(prompt+'.png')
30
 
31
- st.button('Generate',command=generate)
 
2
  import streamlit as st
3
  import os
4
  from authtoken import auth_token
5
+ from datasets import load_dataset
6
+ from PIL import Image
7
+ import re
8
+ import os
9
+ import requests
10
+
11
+ model_id = "CompVis/stable-diffusion-v1-4"
12
+ device = "cuda"
13
+ word_list_dataset = load_dataset("stabilityai/word-list", data_files="list.txt", use_auth_token=True)
14
+ word_list = word_list_dataset["train"]['text']
15
+
16
+ is_gpu_busy = False
17
+ def infer(prompt):
18
+ global is_gpu_busy
19
+ samples = 4
20
+ steps = 50
21
+ scale = 7.5
22
+ #When running locally you can also remove this filter
23
+ for filter in word_list:
24
+ if re.search(rf"\b{filter}\b", prompt):
25
+ raise st.Error("Unsafe content found. Please try again with different prompts.")
26
+ images = []
27
+ url = os.getenv('JAX_BACKEND_URL')
28
+ payload = {'prompt': prompt}
29
+ images_request = requests.post(url, json = payload)
30
+ for image in images_request.json()["images"]:
31
+ image_b64 = (f"data:image/jpeg;base64,{image}")
32
+ images.append(image_b64)
33
+ st.image(images)
34
+ return images
35
+
36
+ prompt=st.text_input('Enter your prompt')
37
+ true=st.button('Generate')
38
+ if true==True:
39
+ infer(prompt)
40
+
41
+
42
+
43
+
44
+
45
+
46
+
47
+
48
+
49
+
50
+
51
 
52
 
 
 
 
 
 
 
 
53
 
54
+
55
+
56
+
57
+
58
+
59
+
60
+
61
+
62
+
63
+
64
+
65
+
66
+
67
+ # from diffusers import StableDiffusionPipeline
68
+ # st.session_state.ac=True
69
+ # if st.session_state.ac==True:
70
+ # os.system('pip3 install torch torchvision torchaudio')
71
+ # import torch
72
+ # from torch import autocast
73
+ # st.session_state.ac=False
74
+
75
+ # st.title('Stable Diffusion (AI that can imagine)')
76
 
77
 
78
 
79
+ # prompt=st.text_input('Enter the prompt')
80
+ # modelid = "CompVis/stable-diffusion-v1-4"
81
+ # device = "cuda"
82
+ # pipe = StableDiffusionPipeline.from_pretrained(modelid, revision="fp16", torch_dtype=torch.float16, use_auth_token=auth_token)
83
+ # pipe.to(device)
84
+ # def generate():
85
+ # with autocast(device):
86
+ # image = pipe(prompt.get(), guidance_scale=8.5)["sample"][0]
87
 
88
+ # image.save(prompt+'.png')
89
+ # img = st.image(prompt+'.png')
90
 
91
+ # st.button('Generate',command=generate)