sneha commited on
Commit ·
8946de5
1
Parent(s): 70e8264
default radio option
Browse files
app.py
CHANGED
|
@@ -43,7 +43,7 @@ def download_bin():
|
|
| 43 |
os.rename(model_bin, bin_path)
|
| 44 |
|
| 45 |
|
| 46 |
-
def run_attn(input_img,fusion
|
| 47 |
download_bin()
|
| 48 |
model, embedding_dim, transform, metadata = get_model()
|
| 49 |
if input_img.shape[0] != 3:
|
|
@@ -69,7 +69,7 @@ def run_attn(input_img,fusion="min"):
|
|
| 69 |
return attn_img, fig
|
| 70 |
|
| 71 |
input_img = gr.Image(shape=(250,250))
|
| 72 |
-
input_button = gr.Radio(["min", "max", "mean"], label="Attention Head Fusion", info="How to combine the last layer attention across all 12 heads of the transformer.")
|
| 73 |
output_img = gr.Image(shape=(250,250))
|
| 74 |
output_plot = gr.Plot()
|
| 75 |
|
|
|
|
| 43 |
os.rename(model_bin, bin_path)
|
| 44 |
|
| 45 |
|
| 46 |
+
def run_attn(input_img,fusion):
|
| 47 |
download_bin()
|
| 48 |
model, embedding_dim, transform, metadata = get_model()
|
| 49 |
if input_img.shape[0] != 3:
|
|
|
|
| 69 |
return attn_img, fig
|
| 70 |
|
| 71 |
input_img = gr.Image(shape=(250,250))
|
| 72 |
+
input_button = gr.Radio(["min", "max", "mean"], value="min",label="Attention Head Fusion", info="How to combine the last layer attention across all 12 heads of the transformer.")
|
| 73 |
output_img = gr.Image(shape=(250,250))
|
| 74 |
output_plot = gr.Plot()
|
| 75 |
|