Spaces:
Runtime error
Runtime error
cwkuo
committed on
Commit
·
6b2ffd3
1
Parent(s):
ef2dc13
tune some default params
Browse files
- app.py +33 -41
- examples/diamond_head.jpg +0 -3
app.py
CHANGED
|
@@ -159,7 +159,7 @@ def retrieve_knowledge(image):
|
|
| 159 |
|
| 160 |
|
| 161 |
@torch.inference_mode()
|
| 162 |
-
def generate(state: Conversation, temperature, top_p, max_new_tokens
|
| 163 |
if state.skip_next: # This generate call is skipped due to invalid inputs
|
| 164 |
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 3 + knwl_unchange
|
| 165 |
return
|
|
@@ -172,37 +172,33 @@ def generate(state: Conversation, temperature, top_p, max_new_tokens, add_knwl,
|
|
| 172 |
|
| 173 |
# retrieve and visualize knowledge
|
| 174 |
image = state.get_images(return_pil=True)[0]
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
knwl_vis = tuple(knwl_img + knwl_txt)
|
| 203 |
-
else:
|
| 204 |
-
knwl_embd = None
|
| 205 |
-
knwl_vis = knwl_none
|
| 206 |
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 3 + knwl_vis
|
| 207 |
|
| 208 |
# generate output
|
|
@@ -217,7 +213,7 @@ def generate(state: Conversation, temperature, top_p, max_new_tokens, add_knwl,
|
|
| 217 |
target=gptk_model.generate,
|
| 218 |
kwargs=dict(
|
| 219 |
samples=samples,
|
| 220 |
-
use_nucleus_sampling=
|
| 221 |
max_length=min(int(max_new_tokens), 1024),
|
| 222 |
top_p=float(top_p),
|
| 223 |
temperature=float(temperature),
|
|
@@ -270,7 +266,6 @@ def build_demo():
|
|
| 270 |
gr.Examples(examples=[
|
| 271 |
["examples/mona_lisa.jpg", "Discuss the historical impact and the significance of this painting in the art world."],
|
| 272 |
["examples/mona_lisa_dog.jpg", "Describe this photo in detail."],
|
| 273 |
-
["examples/diamond_head.jpg", "What is the name of this famous sight in the photo?"],
|
| 274 |
["examples/horseshoe_bend.jpg", "What are the possible reasons of the formation of this sight?"],
|
| 275 |
], inputs=[imagebox, textbox])
|
| 276 |
|
|
@@ -286,10 +281,7 @@ def build_demo():
|
|
| 286 |
clear_btn = gr.Button(value="🗑️ Clear", interactive=False, scale=1)
|
| 287 |
|
| 288 |
with gr.Accordion("Parameters", open=True):
|
| 289 |
-
|
| 290 |
-
add_knwl = gr.Checkbox(value=True, interactive=True, label="Knowledge")
|
| 291 |
-
do_sampling = gr.Checkbox(value=False, interactive=True, label="Sampling")
|
| 292 |
-
temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, interactive=True, label="Temperature",)
|
| 293 |
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
|
| 294 |
max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
|
| 295 |
|
|
@@ -318,7 +310,7 @@ def build_demo():
|
|
| 318 |
regenerate, [state], [state, chatbot, textbox, imagebox] + btn_list
|
| 319 |
).then(
|
| 320 |
generate,
|
| 321 |
-
[state, temperature, top_p, max_output_tokens
|
| 322 |
[state, chatbot] + btn_list + knwl_vis
|
| 323 |
)
|
| 324 |
|
|
@@ -330,7 +322,7 @@ def build_demo():
|
|
| 330 |
add_text, [state, textbox, imagebox], [state, chatbot, textbox, imagebox] + btn_list
|
| 331 |
).then(
|
| 332 |
generate,
|
| 333 |
-
[state, temperature, top_p, max_output_tokens
|
| 334 |
[state, chatbot] + btn_list + knwl_vis
|
| 335 |
)
|
| 336 |
|
|
@@ -338,7 +330,7 @@ def build_demo():
|
|
| 338 |
add_text, [state, textbox, imagebox], [state, chatbot, textbox, imagebox] + btn_list
|
| 339 |
).then(
|
| 340 |
generate,
|
| 341 |
-
[state, temperature, top_p, max_output_tokens
|
| 342 |
[state, chatbot] + btn_list + knwl_vis
|
| 343 |
)
|
| 344 |
|
|
|
|
| 159 |
|
| 160 |
|
| 161 |
@torch.inference_mode()
|
| 162 |
+
def generate(state: Conversation, temperature, top_p, max_new_tokens):
|
| 163 |
if state.skip_next: # This generate call is skipped due to invalid inputs
|
| 164 |
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 3 + knwl_unchange
|
| 165 |
return
|
|
|
|
| 172 |
|
| 173 |
# retrieve and visualize knowledge
|
| 174 |
image = state.get_images(return_pil=True)[0]
|
| 175 |
+
knwl_embd, knwl = retrieve_knowledge(image)
|
| 176 |
+
knwl_img, knwl_txt, idx = [None, ] * 15, ["", ] * 15, 0
|
| 177 |
+
for query_type, knwl_pos in (("whole", 1), ("five", 5), ("nine", 9)):
|
| 178 |
+
if query_type == "whole":
|
| 179 |
+
images = [image, ]
|
| 180 |
+
elif query_type == "five":
|
| 181 |
+
images = five_crop(image)
|
| 182 |
+
elif query_type == "nine":
|
| 183 |
+
images = nine_crop(image)
|
| 184 |
+
|
| 185 |
+
for pos in range(knwl_pos):
|
| 186 |
+
try:
|
| 187 |
+
txt = ""
|
| 188 |
+
for k, v in knwl[query_type][pos].items():
|
| 189 |
+
v = ", ".join([vi.replace("_", " ") for vi in v])
|
| 190 |
+
txt += f"**[{k.upper()}]:** {v}\n\n"
|
| 191 |
+
knwl_txt[idx] += txt
|
| 192 |
+
|
| 193 |
+
img = images[pos]
|
| 194 |
+
img = query_trans.transforms[0](img)
|
| 195 |
+
img = query_trans.transforms[1](img)
|
| 196 |
+
img = query_trans.transforms[2](img)
|
| 197 |
+
knwl_img[idx] = img
|
| 198 |
+
except KeyError:
|
| 199 |
+
pass
|
| 200 |
+
idx += 1
|
| 201 |
+
knwl_vis = tuple(knwl_img + knwl_txt)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 3 + knwl_vis
|
| 203 |
|
| 204 |
# generate output
|
|
|
|
| 213 |
target=gptk_model.generate,
|
| 214 |
kwargs=dict(
|
| 215 |
samples=samples,
|
| 216 |
+
use_nucleus_sampling=(temperature > 0.001),
|
| 217 |
max_length=min(int(max_new_tokens), 1024),
|
| 218 |
top_p=float(top_p),
|
| 219 |
temperature=float(temperature),
|
|
|
|
| 266 |
gr.Examples(examples=[
|
| 267 |
["examples/mona_lisa.jpg", "Discuss the historical impact and the significance of this painting in the art world."],
|
| 268 |
["examples/mona_lisa_dog.jpg", "Describe this photo in detail."],
|
|
|
|
| 269 |
["examples/horseshoe_bend.jpg", "What are the possible reasons of the formation of this sight?"],
|
| 270 |
], inputs=[imagebox, textbox])
|
| 271 |
|
|
|
|
| 281 |
clear_btn = gr.Button(value="🗑️ Clear", interactive=False, scale=1)
|
| 282 |
|
| 283 |
with gr.Accordion("Parameters", open=True):
|
| 284 |
+
temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True, label="Temperature",)
|
|
|
|
|
|
|
|
|
|
| 285 |
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
|
| 286 |
max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
|
| 287 |
|
|
|
|
| 310 |
regenerate, [state], [state, chatbot, textbox, imagebox] + btn_list
|
| 311 |
).then(
|
| 312 |
generate,
|
| 313 |
+
[state, temperature, top_p, max_output_tokens],
|
| 314 |
[state, chatbot] + btn_list + knwl_vis
|
| 315 |
)
|
| 316 |
|
|
|
|
| 322 |
add_text, [state, textbox, imagebox], [state, chatbot, textbox, imagebox] + btn_list
|
| 323 |
).then(
|
| 324 |
generate,
|
| 325 |
+
[state, temperature, top_p, max_output_tokens],
|
| 326 |
[state, chatbot] + btn_list + knwl_vis
|
| 327 |
)
|
| 328 |
|
|
|
|
| 330 |
add_text, [state, textbox, imagebox], [state, chatbot, textbox, imagebox] + btn_list
|
| 331 |
).then(
|
| 332 |
generate,
|
| 333 |
+
[state, temperature, top_p, max_output_tokens],
|
| 334 |
[state, chatbot] + btn_list + knwl_vis
|
| 335 |
)
|
| 336 |
|
examples/diamond_head.jpg
DELETED
Git LFS Details
|