Update app.py
Browse files
app.py
CHANGED
|
@@ -56,6 +56,7 @@ def inference(video_file, dataset_type, mask_ratio):
|
|
| 56 |
frames[None, ...], bool_masked_pos_tf, outputs_pt
|
| 57 |
)
|
| 58 |
|
|
|
|
| 59 |
input_frame = denormalize(frames)
|
| 60 |
input_mask = denormalize(mask[0] * frames)
|
| 61 |
output_frame = denormalize(reconstruct_output)
|
|
@@ -81,11 +82,10 @@ def main():
|
|
| 81 |
'./TFVideoMAE_S_K400_16x224_PT'
|
| 82 |
],
|
| 83 |
'UCF' : [
|
| 84 |
-
'
|
| 85 |
'./TFVideoMAE_S_K400_16x224_PT'
|
| 86 |
]
|
| 87 |
}
|
| 88 |
-
|
| 89 |
BENCHMARK_DATASETS = ['K400', 'SSv2', 'UCF']
|
| 90 |
SAMPLE_EXAMPLES = [
|
| 91 |
["examples/k400.mp4", 'Kintetics-400'],
|
|
@@ -103,11 +103,11 @@ def main():
|
|
| 103 |
default=BENCHMARK_DATASETS[0],
|
| 104 |
label='Dataset',
|
| 105 |
),
|
| 106 |
-
gr.
|
| 107 |
-
0
|
| 108 |
-
1
|
| 109 |
-
step=0.
|
| 110 |
-
default=0.
|
| 111 |
label='Mask Ratio'
|
| 112 |
)
|
| 113 |
],
|
|
|
|
| 56 |
frames[None, ...], bool_masked_pos_tf, outputs_pt
|
| 57 |
)
|
| 58 |
|
| 59 |
+
# post process
|
| 60 |
input_frame = denormalize(frames)
|
| 61 |
input_mask = denormalize(mask[0] * frames)
|
| 62 |
output_frame = denormalize(reconstruct_output)
|
|
|
|
| 82 |
'./TFVideoMAE_S_K400_16x224_PT'
|
| 83 |
],
|
| 84 |
'UCF' : [
|
| 85 |
+
'./TFVideoMAE_S_K400_16x224_FT',
|
| 86 |
'./TFVideoMAE_S_K400_16x224_PT'
|
| 87 |
]
|
| 88 |
}
|
|
|
|
| 89 |
BENCHMARK_DATASETS = ['K400', 'SSv2', 'UCF']
|
| 90 |
SAMPLE_EXAMPLES = [
|
| 91 |
["examples/k400.mp4", 'Kintetics-400'],
|
|
|
|
| 103 |
default=BENCHMARK_DATASETS[0],
|
| 104 |
label='Dataset',
|
| 105 |
),
|
| 106 |
+
gr.Slider(
|
| 107 |
+
0,
|
| 108 |
+
1,
|
| 109 |
+
step=0.05,
|
| 110 |
+
default=0.5,
|
| 111 |
label='Mask Ratio'
|
| 112 |
)
|
| 113 |
],
|