# Spaces: Sleeping
import functools
import zipfile

import gradio as gr
import imageio
import numpy as np
import tensorflow as tf
from tensorflow import keras

from labels import K400_label_map
from utils import read_video, frame_sampling
from utils import num_frames, patch_size, input_size
# Dataset tag -> label dictionary. get_model() reads the 'K400' entry and
# inverts it, so the stored mapping is presumably name -> class index
# (NOTE(review): confirm against labels.K400_label_map).
LABEL_MAPS = {
    'K400': K400_label_map,
}
# Model names offered in the UI dropdown. Each name is also the stem of the
# release archive ('<name>.zip') that get_model() downloads.
ALL_MODELS = [
    'TFVideoFocalNetB_K400_8x224',
]

# Pre-filled demo rows for the interface, one [video path, model name] per row.
sample_example = [
    ["examples/k400.mp4", ALL_MODELS[0]],
]
# Loaded models are large, so only the most recently used one is kept in
# memory; repeated requests for the same model skip the unzip + load step.
@functools.lru_cache(maxsize=1)
def get_model(model_type):
    """Download, unpack and load a pretrained model plus its label lookup.

    The archive ``<model_type>.zip`` is fetched from the Video-FocalNets
    v1.1 GitHub release, extracted into the current working directory and
    loaded with ``keras.models.load_model``. Results are memoized, so a
    second call with the same ``model_type`` returns the same objects.

    Args:
        model_type: Name of the model; used both as the archive stem in the
            download URL and as the path passed to ``load_model``.

    Returns:
        A ``(model, label_map)`` tuple, where ``label_map`` is the inverse
        of ``LABEL_MAPS['K400']`` (values become keys).

    Raises:
        ValueError: If ``model_type`` is empty or ``None``.
    """
    if not model_type:
        # Without this guard the URL below would be built as ".../None.zip".
        raise ValueError("model_type must be a non-empty model name")

    archive_path = keras.utils.get_file(
        origin=f'https://github.com/innat/Video-FocalNets/releases/download/v1.1/{model_type}.zip',
    )
    with zipfile.ZipFile(archive_path, 'r') as zip_ref:
        zip_ref.extractall('./')
    model = keras.models.load_model(model_type)

    label_map = LABEL_MAPS.get('K400')
    # Invert the mapping so it can be indexed by the model's output position.
    label_map = {v: k for k, v in label_map.items()}
    return model, label_map
def inference(video_file, model_type):
    """Classify a video and return per-label confidence scores.

    Args:
        video_file: The uploaded video as handed over by the Gradio video
            input; forwarded unchanged to ``read_video``.
        model_type: Name of the model to run, forwarded to ``get_model``.

    Returns:
        A dict mapping each label to its softmax probability as a Python
        float, inserted in descending order of probability.

    Raises:
        ValueError: If no video was supplied.
    """
    if video_file is None:
        # Submitting the form without a video would otherwise fail inside
        # read_video with a less helpful error.
        raise ValueError("No input video was provided.")

    # get sample data
    container = read_video(video_file)
    frames = frame_sampling(container, num_frames=num_frames)

    # get models
    model, label_map = get_model(model_type)
    model.trainable = False

    # inference on model
    # frames[None, ...] adds a leading axis; squeeze(0) removes it again.
    outputs = model(frames[None, ...], training=False)
    probabilities = tf.nn.softmax(outputs).numpy().squeeze(0)

    confidences = {
        label_map[i]: float(probabilities[i]) for i in np.argsort(probabilities)[::-1]
    }
    return confidences
def main():
    """Build the Gradio demo interface around ``inference`` and launch it."""
    iface = gr.Interface(
        fn=inference,
        inputs=[
            # No `type` keyword here: the Video component does not define
            # one, and newer Gradio releases reject unknown keywords.
            gr.Video(label="Input Video"),
            gr.Dropdown(
                choices=ALL_MODELS,
                # Preselect a model so inference never receives None.
                value=ALL_MODELS[0],
                label="Model"
            )
        ],
        outputs=gr.Label(num_top_classes=3, label='scores'),
        examples=sample_example,
        title="Video-FocalNets: Spatio-Temporal Focal Modulation.",
        description="Keras reimplementation of <a href='https://github.com/innat/Video-FocalNets'>Video-FocalNets</a> is presented here."
    )
    iface.launch()


if __name__ == '__main__':
    main()