# Gradio demo app: Vietnamese Sign Language recognition (VideoMAE skeleton model, ONNX).
import json

import gradio as gr
from time import time

import onnxruntime as ort
from mediapipe.python.solutions import holistic
from torchvision.transforms.v2 import Compose, Lambda, Normalize

from utils import get_predictions, preprocess

title = '''
'''

cite_markdown = '''
'''

description = '''
'''

examples = [
    ['000_con_cho.mp4'],
]

# Load the ONNX model and its configuration once at startup so every
# inference call reuses the same session.
ort_session = ort.InferenceSession('videomae_skeleton_v2.3.onnx')
with open('config.json') as f:  # close handles explicitly (json.load(open(...)) leaks them)
    model_config = json.load(f)
with open('preprocessor_config.json') as f:
    preprocessor_config = json.load(f)

mean = preprocessor_config['image_mean']
std = preprocessor_config['image_std']

# The preprocessor config stores either a single 'shortest_edge' (square
# input) or explicit 'height'/'width' entries.
if 'shortest_edge' in preprocessor_config['size']:
    model_input_height = model_input_width = preprocessor_config['size']['shortest_edge']
else:
    model_input_height = preprocessor_config['size']['height']
    model_input_width = preprocessor_config['size']['width']

# Per-frame normalization: scale pixel values to [0, 1], then standardize
# with the model's mean/std.
transform = Compose(
    [
        Lambda(lambda x: x / 255.0),
        Normalize(mean=mean, std=std),
    ]
)
def inference(
    video: str,
    progress: gr.Progress = gr.Progress(),
) -> str:
    '''
    Video-based inference for Vietnamese Sign Language recognition.

    Parameters
    ----------
    video : str
        The path to the video.
    progress : gr.Progress, optional
        The progress bar, by default gr.Progress()

    Returns
    -------
    str
        The top-3 predictions, plus preprocessing/inference timings,
        or an apology message when no sign language was detected.
    '''
    progress(0, desc='Preprocessing video')
    # A fresh detector per request: Holistic keeps internal graph state,
    # so sharing one instance across concurrent requests is unsafe.
    keypoints_detector = holistic.Holistic(
        static_image_mode=False,
        model_complexity=2,
        enable_segmentation=True,
        refine_face_landmarks=True,
    )
    try:
        start_time = time()
        inputs = preprocess(
            model_num_frames=model_config['num_frames'],
            keypoints_detector=keypoints_detector,
            source=video,
            model_input_height=model_input_height,
            model_input_width=model_input_width,
            transform=transform,
        )
        end_time = time()
        data_time = end_time - start_time

        progress(1/2, desc='Getting predictions')
        start_time = time()
        predictions = get_predictions(
            inputs=inputs,
            ort_session=ort_session,
            id2gloss=model_config['id2label'],
            k=3,
        )
        end_time = time()
        model_time = end_time - start_time
    finally:
        # Release the MediaPipe graph's native resources even on failure
        # (the original leaked one detector per call).
        keypoints_detector.close()

    if not predictions:
        output_message = 'No sign language detected in the video. Please try again.'
    else:
        output_message = 'The top-3 predictions are:\n'
        for i, prediction in enumerate(predictions):
            # Bug fix: ':2f' (missing dot) printed the raw float; ':.2f'
            # rounds the score to two decimals as intended.
            output_message += f'\t{i+1}. {prediction["label"]} ({prediction["score"]:.2f})\n'
        output_message += f'Data processing time: {data_time:.2f} seconds\n'
        output_message += f'Model inference time: {model_time:.2f} seconds\n'
        output_message += f'Total time: {data_time + model_time:.2f} seconds'

    # Bug fix: was progress(1/2, ...) — the bar never reached 100%.
    progress(1, desc='Completed')
    return output_message
# Wire the inference function into a simple video-in / text-out UI.
iface = gr.Interface(
    fn=inference,
    inputs='video',
    outputs='text',
    examples=examples,
    title=title,
    description=description,
)

# Launch only when run as a script, so the module stays importable
# (e.g. for tests) without starting a server.
if __name__ == '__main__':
    iface.launch()