| | import os |
| | import io |
| | import pickle |
| | import cv2 |
| | import gradio as gr |
# Log the installed Gradio version at startup (handy when debugging UI/API mismatches).
print(gr.__version__)
| | from tempSegAndAllErrorsForAllFrames import getAllErrorsAndSegmentation |
| | from models.detectron2.platform_detector_setup import get_platform_detector |
| | from models.pose_estimator.pose_estimator_model_setup import get_pose_estimation |
| | from models.detectron2.diver_detector_setup import get_diver_detector |
| | from models.pose_estimator.pose_estimator_model_setup import get_pose_model |
| | from models.detectron2.splash_detector_setup import get_splash_detector |
| | from scoring_functions import * |
| | from generate_reports import * |
| | from tempSegAndAllErrorsForAllFrames_newVids import getAllErrorsAndSegmentation_newVids, abstractSymbols |
| |
|
| | from jinja2 import Environment, FileSystemLoader |
| | from PIL import Image, ImageDraw |
| | from io import BytesIO |
| | import base64 |
| |
|
| | |
| | |
| | |
| | |
# Jinja2/HTML template used to render the final score report.
template_path = 'report_template_tables.html'
# Module-level scratch dict holding the symbols/errors of the current dive;
# rebound/mutated by the abstraction and scoring functions below.
dive_data = {}
| |
|
class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that remaps serialized torch storages onto the CPU.

    Pickles saved on a GPU machine embed torch storage payloads that would
    otherwise try to deserialize onto CUDA; intercepting
    ``torch.storage._load_from_bytes`` forces ``map_location='cpu'`` so the
    data loads on CPU-only hosts.
    """

    def find_class(self, module, name):
        if module == 'torch.storage' and name == '_load_from_bytes':
            # Imported lazily: torch is not imported at module level in this
            # file (the original raised NameError here), and plain pickles
            # that never hit this branch load without torch installed.
            import torch
            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
        # Anything else resolves through the normal pickle machinery.
        return super().find_class(module, name)
| |
|
# Precomputed segmentation/error data for the bundled example dives, keyed by
# (folder, clip_number) tuples. Safe only because the pickle ships with the
# app — never point this at untrusted input.
dive_data_precomputed = CPU_Unpickler(open('./segmentation_error_data.pkl', 'rb')).load()
| |
|
| | |
| | |
| |
|
| | import sys |
| | import csv |
| |
|
# Raise the CSV field-size cap to the platform maximum so oversized
# annotation fields don't raise "field larger than field limit".
csv.field_size_limit(sys.maxsize)
| |
|
# FineDiving fine-grained annotations, keyed by (folder, clip_number);
# element [0] of each entry is used below as the dive number (diveNum).
with open('FineDiving/fine-grained_annotation_aqa.pkl', 'rb') as f:
    dive_annotation_data = pickle.load(f)
| |
|
def extract_frames(video_path, frame_skip=1, size=(455, 256)):
    """Decode a video file into a list of resized BGR frames.

    Keeps one frame out of every ``frame_skip + 1`` frames read (the default
    of 1 keeps every other frame, matching the original behavior), resizing
    each kept frame to ``size``.

    Args:
        video_path: Path to a video file readable by OpenCV.
        frame_skip: Number of consecutive frames dropped between kept frames.
        size: (width, height) each kept frame is resized to.

    Returns:
        List of numpy BGR frames of shape (height, width, 3).

    Raises:
        gr.Error: If OpenCV cannot open the video file. (Previously this
            called exit(), which would kill the whole server process.)
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise gr.Error("Error: Couldn't open video file.")

    frame_count = 0
    frames = []
    skipped = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if skipped > frame_skip - 1:
            # Enough frames skipped: keep this one and restart the skip count.
            frame_count += 1
            frames.append(cv2.resize(frame, size))
            skipped = 0
        else:
            skipped += 1
    cap.release()
    print("frame_count", frame_count)
    return frames
| |
|
def get_key_from_videopath(video):
    """Parse a FineDiving (folder, clip_number) key out of a video path.

    Expects filenames of the form ``<prefix>_<folder>_<clip>.<ext>``; returns
    ``(folder, int(clip))`` on success, or None when the name does not match
    (callers use None to fall back to the from-scratch pipeline).

    Args:
        video: Video file path, or None.

    Returns:
        (str, int) key tuple, or None if the path cannot be parsed.
    """
    try:
        video_name = video.split('/')[-1]
        first_folder = video_name.split('_')[1]
        second_folder = video_name.split('_')[2].split('.')[0]
        return (first_folder, int(second_folder))
    # Narrowed from a bare except: AttributeError (video is None),
    # IndexError (too few '_' parts), ValueError (clip not an int).
    except (AttributeError, IndexError, ValueError):
        return None
| |
|
def get_abstracted_symbols_precomputed(video, key, progress=gr.Progress()):
    """Render the symbol-abstraction report for a dive with precomputed data.

    Args:
        video: Selected example video path (only validated for None here).
        key: (folder, clip_number) tuple indexing ``dive_data_precomputed``.
        progress: Gradio progress tracker updated while the report is built.

    Returns:
        HTML string from ``generate_symbols_report_precomputed``.

    Raises:
        gr.Error: If no video has been selected.
    """
    progress(0, desc="Abstracting Symbols")
    if video is None:
        raise gr.Error("input a video!!")
    local_directory = "FineDiving/datasets/FINADiving_MTL_256s/{}/{}/".format(key[0], key[1])
    # Removed an unused local ('directory') that held a hard-coded,
    # machine-specific file:// URL — it was never read.
    global dive_data_precomputed
    dive_data = dive_data_precomputed[key]
    html_intermediate = generate_symbols_report_precomputed("intermediate_steps.html", dive_data, local_directory, progress=progress)
    progress(0.95, desc="Abstracting Symbols")
    return html_intermediate
| |
|
def get_abstracted_symbols_calculated(video, progress=gr.Progress()):
    """Run the neural symbol-abstraction pipeline on a new (non-example) video.

    Rebinds the module-level ``dive_data`` dict with the abstracted symbols
    (later consumed by get_score_report_calculated) and returns the
    intermediate HTML report.
    """
    progress(0, desc="Abstracting Symbols")
    frames = extract_frames(video)
    global dive_data
    # NOTE(review): platform_detector / splash_detector / diver_detector /
    # pose_model are not defined anywhere visible in this file — presumably
    # built at import time elsewhere (e.g. get_platform_detector()); confirm
    # they exist before this runs.
    dive_data = abstractSymbols(frames, progress=progress, platform_detector=platform_detector, splash_detector=splash_detector, diver_detector=diver_detector, pose_model=pose_model)
    # Frames are kept so the report generators can embed per-frame visuals.
    dive_data['frames'] = frames
    html_intermediate = generate_symbols_report("intermediate_steps.html", dive_data, frames)
    return html_intermediate
| |
|
def get_abstracted_symbols(video, progress=gr.Progress()):
    """Dispatch symbol abstraction to the precomputed or live pipeline.

    Example dives whose path parses to a FineDiving key use the cached
    results; anything else is processed from scratch.
    """
    if video is None:
        raise gr.Error("Click on an example diving video first!")
    key = get_key_from_videopath(video)
    if key is not None:
        return get_abstracted_symbols_precomputed(video, key, progress=progress)
    return get_abstracted_symbols_calculated(video, progress=progress)
| |
|
def get_score_report_precomputed(video, key, progress=gr.Progress(), diveNum=""):
    """Build the HTML score report for a dive that has precomputed data."""
    progress(0, desc="Calculating Dive Errors")
    if video is None:
        raise gr.Error("input a video!!")
    global dive_data_precomputed
    data = dive_data_precomputed[key]
    local_directory = "FineDiving/datasets/FINADiving_MTL_256s/{}/{}/".format(key[0], key[1])
    directory = "file:///Users/lokamoto/Comprehensive_AQA/FineDiving/datasets/FINADiving_MTL_256s/{}/{}".format(key[0], key[1])

    report_scores = get_all_report_scores(data)
    progress(0.75, desc="Generating Score Report")
    print('getting html...')
    report_body = generate_report(template_path, report_scores, directory, local_directory, progress=progress)
    progress(0.9, desc="Generating Score Report")
    # Wrap the report in a scrollable container so long reports don't
    # blow up the page layout.
    wrapped = "".join([
        "<div style='max-width:100%; max-height:360px; overflow:auto'>",
        report_body,
        "</div>",
    ])
    print("returning...")
    return wrapped
| |
|
def get_score_report_calculated(video, progress=gr.Progress(), diveNum=""):
    """Compute dive errors for a new video and render the score-report HTML.

    Reads and rebinds the module-level ``dive_data`` populated earlier by
    get_abstracted_symbols_calculated.
    """
    progress(0, desc="Calculating Dive Errors")
    global dive_data
    frames = extract_frames(video)
    dive_data = getAllErrorsAndSegmentation_newVids(frames, dive_data, progress=progress, diveNum=diveNum, platform_detector=platform_detector, splash_detector=splash_detector, diver_detector=diver_detector, pose_model=pose_model)
    report_scores = get_all_report_scores(dive_data)
    progress(0.75, desc="Generating Score Report")
    print('getting html...')
    report_body = generate_report_from_frames(template_path, report_scores, frames)
    # Scrollable container keeps long reports from dominating the page.
    wrapped = "".join([
        "<div style='max-width:100%; max-height:360px; overflow:auto'>",
        report_body,
        "</div>",
    ])
    print("returning...")
    progress(8/8, desc="Generating Score Report")
    return wrapped
| |
|
def get_score_report(video, progress=gr.Progress(), diveNum=""):
    """Dispatch score-report generation to the precomputed or live pipeline.

    Fix: ``diveNum`` was previously accepted but silently dropped; it is now
    forwarded to both code paths (both already accept it, defaulting to "").

    Raises:
        gr.Error: If no video has been selected.
    """
    if video is None:
        raise gr.Error("input a video!!")
    key = get_key_from_videopath(video)
    if key is None:
        return get_score_report_calculated(video, progress=progress, diveNum=diveNum)
    return get_score_report_precomputed(video, key, progress=progress, diveNum=diveNum)
| |
|
| |
|
def get_html_from_video(video, diveNum=""):
    """Full pipeline for an uploaded video, streamed as two updates.

    Yields the intermediate symbol report first, then a combined
    symbols-plus-score report wrapped in a scrollable container.
    """
    if video is None:
        raise gr.Error("input a video!!")
    frames = extract_frames(video)
    dive_data = abstractSymbols(frames, platform_detector=platform_detector, splash_detector=splash_detector, diver_detector=diver_detector, pose_model=pose_model)
    dive_data['frames'] = frames.copy()
    html_intermediate = generate_symbols_report("intermediate_steps.html", dive_data, frames)
    # First update: show the abstracted symbols while scoring continues.
    yield html_intermediate
    dive_data = getAllErrorsAndSegmentation_newVids(frames, dive_data, diveNum=diveNum, platform_detector=platform_detector, splash_detector=splash_detector, diver_detector=diver_detector, pose_model=pose_model)
    report_scores = get_all_report_scores(dive_data)
    print('getting html...')
    report_body = generate_report_from_frames(template_path, report_scores, frames)
    combined = "".join([
        "<div style='max-width:100%; max-height:360px; overflow:auto'>",
        html_intermediate,
        report_body,
        "</div>",
    ])
    print("returning...")
    # Second update: replace the symbols view with the full report.
    yield combined
| |
|
def get_html_from_finedivingkey(first_folder, second_folder):
    """Run full error analysis for a FineDiving clip and return report HTML.

    Populates the module-level ``dive_data`` dict in place with every output
    of getAllErrorsAndSegmentation, then renders the score report.

    Args:
        first_folder: FineDiving competition folder name.
        second_folder: Clip number within that folder (int-convertible).

    Returns:
        Score-report HTML wrapped in a scrollable container.
    """
    board_side = "left"
    key = (first_folder, int(second_folder))
    local_directory = "FineDiving/datasets/FINADiving_MTL_256s/{}/{}".format(key[0], key[1])
    directory = "file:///Users/lokamoto/Comprehensive_AQA/FineDiving/datasets/FINADiving_MTL_256s/{}/{}".format(key[0], key[1])
    print("key:", key)
    diveNum = dive_annotation_data[key][0]
    results = getAllErrorsAndSegmentation(first_folder, second_folder, diveNum, board_side=board_side, platform_detector=platform_detector, splash_detector=splash_detector, diver_detector=diver_detector, pose_model=pose_model)
    # Field names in the exact order getAllErrorsAndSegmentation returns its
    # 16 results; replaces the previous 16 copy-pasted assignments.
    fields = (
        'pose_pred', 'takeoff', 'twist', 'som', 'entry',
        'distance_from_board', 'position_tightness', 'feet_apart',
        'over_under_rotation', 'splash', 'above_boards', 'on_boards',
        'som_counts', 'twist_counts', 'board_end_coords', 'diver_boxes',
    )
    dive_data.update(zip(fields, results))
    dive_data['diveNum'] = diveNum
    dive_data['board_side'] = board_side

    report_scores = get_all_report_scores(dive_data)
    html = generate_report(template_path, report_scores, directory, local_directory)
    html = (
        "<div style='max-width:100%; max-height:360px; overflow:auto'>"
        + html
        + "</div>")

    return html
| |
|
| | |
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
def enable_get_score_btn(get_score_btn):
    """Event handler: return button props that switch the target on (primary)."""
    # The incoming component value is ignored; only fresh props matter.
    return gr.Button(variant="primary", interactive=True)
| |
|
def disable_get_score_btn(get_score_btn):
    """Event handler: return button props that grey the target out (secondary)."""
    # The incoming component value is ignored; only fresh props matter.
    return gr.Button(variant="secondary", interactive=False)
| | |
| | |
# ---------------------------------------------------------------------------
# Gradio UI: two-step demo — (1) abstract symbols, (2) generate score report.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo_precomputed:
    # Header: title, authors, links to video, technical report, and code.
    gr.Markdown(
        """
        # Neuro-Symbolic Olympic Diving Judge
        Authors: Lauren Okamoto, Paritosh Parmar \n
        This system not only scores an Olympic dive, but outputs a detailed report summarizing each component of the dive and how we evaluated it. We first abstract the necessary symbols, and then proceed to score the dive.\n
        See more details on the system by watching this [Video](https://youtu.be/NDtdtguUzjQ). \n
        Technical report: [Report](https://arxiv.org/abs/2403.13798). \n
        Github: [Code](https://github.com/laurenok24/NSAQA).
        """)

    # Step 1 explainer.
    gr.Markdown(
        """
        ## Step 1: Neural Symbol Abstraction
        We first abstract the necessary visual elements from the provided diving video. This includes the platform, splash, and the pose estimation of the diver.
        """
    )

    # Static table illustrating the three symbol types (platform, pose, splash).
    # NOTE(review): the 'file/...' image paths rely on launch(allowed_paths=["."]).
    gr.HTML(
        """
        <table>
            <tr>
                <td>
                    Platform
                    <img src='file/platform.png' height='90'>
                </td>
                <td>
                    The location of the platform is crucial to determine when the diver leaves the platform, thus starting their dive.
                    It is also important to assess how close the diver comes to its edge, which is relevant to scoring.
                </td>
                <td>
                    Pose Estimation of Diver
                    <img src='file/pose_estimation.png' height='70'>
                </td>
                <td>
                    The pose of the diver in the sequence of video frames is critical to understanding and assessing the dive.
                    We obtain 2D pose data with locations of various body parts to recognize sub-actions being performed by the diver, such as a somersault, a twist, or an entry, and also assess the quality of that sub-action.
                </td>
                <td>
                    Splash
                    <img src='file/splash.png' height='90'>
                </td>
                <td>
                    Splash at entry into the pool is a conspicuous visual feature of a dive.
                    The size of the splash is an important element in traditional scoring of dives.
                    A large splash mars the end of a dive and also likely indicates a flaw in form at water entry.
                </td>
            </tr>
        </table>
        """
    )
    # Step 1 instructions.
    gr.Markdown(
        """
        1. Select one of the example diving videos.
        2. Hit the **Abstract Symbols** button. The symbols abstracted will appear to the right of the diving video.
        """
    )
    with gr.Row(variant='panel'):
        with gr.Column():
            # Video player for the chosen example; direct upload is disabled
            # (interactive=False) — users pick from the examples below.
            video = gr.Video(label="Video", format="mp4", include_audio=False, sources=["upload"], interactive=False)
            examples = gr.Examples(examples = [['01_10.mp4'], ['01_11.mp4'], ['01_16.mp4'], ['01_33.mp4'], ['01_76.mp4'], ['01_140.mp4']], inputs=[video], label="Click on one of the following diving videos")
        # Intermediate symbol report, rendered beside the video.
        symbol_output = gr.HTML(label="Output")
    abstract_symbols_btn = gr.Button("Abstract Symbols", variant='secondary')
    # Step 2 explainer and instructions.
    gr.Markdown(
        """
        ## Step 2: Calculate Rules-Based Errors and Generate Detailed Score Report

        Using the abstracted symbols, we calculate different "errors" of the dive.
        These errors are: **feet apart; height off board; distance from board; somersault position tightness; knee straightness; twist position straightness; over/under rotation; straightness of body during entry; and splash size.**
        Each error is scored on a scale of 0-10, and are then averaged to reach a final score for the dive.

        We then programmatically generate a detailed performance report containing different aspects of the dive, their percentile scores, and visual evidence.
        It can be helpful for a number of reasons including as a support to human judges and as an educational tool to teach coaches, athletes, and judges how to score.

        1. Click the **Generate Score Report** button. The Score Report will be generated below. (Abstract Symbols first if you haven't already!)
        """
    )

    # Disabled until symbol abstraction has completed successfully.
    get_score_btn = gr.Button("Generate Score Report", interactive=False)
    score_report = gr.HTML(label="Report")

    # Event wiring: picking a new video invalidates any previous report, so
    # the score button is disabled and the abstract button re-enabled.
    video.change(fn=disable_get_score_btn, inputs=get_score_btn, outputs=get_score_btn)
    video.change(fn=enable_get_score_btn, inputs=abstract_symbols_btn, outputs=abstract_symbols_btn)
    # Scoring is only enabled once abstraction succeeds (.success chain) or
    # once the symbol output actually changes.
    abstract_symbols_btn.click(fn=get_abstracted_symbols, inputs=video, outputs=symbol_output).success(fn=enable_get_score_btn, inputs=get_score_btn, outputs=get_score_btn)
    symbol_output.change(fn=disable_get_score_btn, inputs=abstract_symbols_btn, outputs=abstract_symbols_btn)
    symbol_output.change(fn=enable_get_score_btn, inputs=get_score_btn, outputs=get_score_btn)
    get_score_btn.click(fn=get_score_report, inputs=video, outputs=score_report)

# Enable request queueing (used by the gr.Progress trackers above) and launch
# with a public share link; allowed_paths serves the static images referenced
# in the HTML table.
demo_precomputed.queue()
demo_precomputed.launch(share=True, allowed_paths=["."])
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| |
|