Spaces:
Runtime error
Runtime error
Update utils.py
Browse files
utils.py
CHANGED
|
@@ -66,7 +66,6 @@ def layout(*args):
|
|
| 66 |
st.markdown(str(foot), unsafe_allow_html=True)
|
| 67 |
|
| 68 |
|
| 69 |
-
|
| 70 |
def footer():
|
| 71 |
myargs = [
|
| 72 |
"Created by ",
|
|
@@ -96,7 +95,6 @@ def footer():
|
|
| 96 |
height=600,
|
| 97 |
)
|
| 98 |
|
| 99 |
-
|
| 100 |
model = False
|
| 101 |
def generate(prompt,crazy,k):
|
| 102 |
global model
|
|
@@ -113,7 +111,11 @@ def generate(prompt,crazy,k):
|
|
| 113 |
set_seed(np.random.randint(0,10000))
|
| 114 |
|
| 115 |
# Sampling
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
top_k=2048,
|
| 118 |
top_p=None,
|
| 119 |
softmax_temperature=crazy,
|
|
@@ -124,7 +126,7 @@ def generate(prompt,crazy,k):
|
|
| 124 |
# CLIP Re-ranking
|
| 125 |
model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
|
| 126 |
model_clip.to(device=device)
|
| 127 |
-
rank = clip_score(prompt=
|
| 128 |
images=images,
|
| 129 |
model_clip=model_clip,
|
| 130 |
preprocess_clip=preprocess_clip,
|
|
@@ -143,35 +145,37 @@ def generate(prompt,crazy,k):
|
|
| 143 |
|
| 144 |
def drawGrid():
|
| 145 |
master = {}
|
| 146 |
-
order = 0
|
| 147 |
-
|
| 148 |
-
#print(st.session_state.results)
|
| 149 |
|
| 150 |
for r in st.session_state.results[::-1]:
|
| 151 |
_txt = r['prompt']+" "+str(r['crazy'])+" "+str(r['k'])
|
| 152 |
-
|
| 153 |
if(_txt not in master):
|
| 154 |
master[_txt] = [r]
|
| 155 |
-
order += 1
|
| 156 |
else:
|
| 157 |
master[_txt].append(r)
|
| 158 |
|
| 159 |
-
|
| 160 |
-
for
|
| 161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
|
|
|
| 177 |
|
|
|
|
| 66 |
st.markdown(str(foot), unsafe_allow_html=True)
|
| 67 |
|
| 68 |
|
|
|
|
| 69 |
def footer():
|
| 70 |
myargs = [
|
| 71 |
"Created by ",
|
|
|
|
| 95 |
height=600,
|
| 96 |
)
|
| 97 |
|
|
|
|
| 98 |
model = False
|
| 99 |
def generate(prompt,crazy,k):
|
| 100 |
global model
|
|
|
|
| 111 |
set_seed(np.random.randint(0,10000))
|
| 112 |
|
| 113 |
# Sampling
|
| 114 |
+
newPrompt = prompt
|
| 115 |
+
if("architecture" not in prompt.lower() ):
|
| 116 |
+
newPrompt += " architecture"
|
| 117 |
+
|
| 118 |
+
images = model.sampling(prompt=newPrompt,
|
| 119 |
top_k=2048,
|
| 120 |
top_p=None,
|
| 121 |
softmax_temperature=crazy,
|
|
|
|
| 126 |
# CLIP Re-ranking
|
| 127 |
model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
|
| 128 |
model_clip.to(device=device)
|
| 129 |
+
rank = clip_score(prompt=newPrompt,
|
| 130 |
images=images,
|
| 131 |
model_clip=model_clip,
|
| 132 |
preprocess_clip=preprocess_clip,
|
|
|
|
| 145 |
|
| 146 |
def drawGrid():
|
| 147 |
master = {}
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
for r in st.session_state.results[::-1]:
|
| 150 |
_txt = r['prompt']+" "+str(r['crazy'])+" "+str(r['k'])
|
|
|
|
| 151 |
if(_txt not in master):
|
| 152 |
master[_txt] = [r]
|
|
|
|
| 153 |
else:
|
| 154 |
master[_txt].append(r)
|
| 155 |
|
| 156 |
+
|
| 157 |
+
for i in st.session_state.images:
|
| 158 |
+
im = st.empty()
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
placeholder = st.empty()
|
| 162 |
+
with placeholder.container():
|
| 163 |
+
|
| 164 |
+
for m in master:
|
| 165 |
|
| 166 |
+
txt = master[m][0]['prompt']+" (temperature:"+ str(master[m][0]['crazy']) + ", top k:" + str(master[m][0]['k']) + ")"
|
| 167 |
+
st.subheader(txt)
|
| 168 |
+
col1, col2, col3 = st.columns(3)
|
| 169 |
+
|
| 170 |
+
for ix, item in enumerate(master[m]):
|
| 171 |
+
if ix % 3 == 0:
|
| 172 |
+
with col1:
|
| 173 |
+
st.session_state.images.append(st.image(item["image"]))
|
| 174 |
+
if ix % 3 == 1:
|
| 175 |
+
with col2:
|
| 176 |
+
st.session_state.images.append(st.image(item["image"]))
|
| 177 |
+
if ix % 3 == 2:
|
| 178 |
+
with col3:
|
| 179 |
+
st.session_state.images.append(st.image(item["image"]))
|
| 180 |
+
|
| 181 |
|