Spaces:
Runtime error
Runtime error
| import subprocess | |
| import os | |
if os.getenv('SYSTEM') == 'spaces':
    # Hugging Face Spaces bootstrap: install the pinned runtime stack before the
    # heavy imports that follow. Commands are passed as argument lists (no
    # shell), so specifiers such as "mmcv>=2.0.0" reach mim literally instead
    # of being interpreted as shell redirection.
    _SPACES_SETUP_COMMANDS = (
        'pip install gradio==4.29.0',
        'pip install -U openmim',
        'pip install python-dotenv',
        'pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113',
        'mim install mmcv>=2.0.0',
        'mim install mmengine==0.7.2',
        'mim install mmdet==3.0.0',
        'pip install opencv-python',
        'pip install git+https://github.com/cocodataset/panopticapi.git',
    )
    for _setup_command in _SPACES_SETUP_COMMANDS:
        _exit_code = subprocess.call(_setup_command.split())
        if _exit_code != 0:
            # Best-effort install: keep going, but make the failure visible
            # instead of letting it surface later as an unrelated ImportError.
            print(f'WARNING: setup command failed (exit {_exit_code}): {_setup_command}')
| import gradio as gr | |
| from huggingface_hub import snapshot_download | |
| import cv2 | |
| import dotenv | |
| dotenv.load_dotenv() | |
| import numpy as np | |
| import gradio as gr | |
| import glob | |
| from inference import inference_frame,inference_frame_serial | |
| from inference import inference_frame_par_ready | |
| from inference import process_frame | |
| from inference import classes | |
| from inference import class_sizes_lower | |
| from metrics import process_results_for_plot | |
| from metrics import prediction_dashboard | |
| import os | |
| import pathlib | |
| import multiprocessing as mp | |
| from time import time | |
# Fetch the example clips (a Hugging Face *dataset* repo) only on first start-up;
# the SHARK_MODEL environment variable is used as the access token.
if not os.path.exists('videos_example'):
    REPO_ID = 'SharkSpace/videos_examples'
    snapshot_download(
        repo_id=REPO_ID,
        repo_type='dataset',
        local_dir='videos_example',
        token=os.environ.get('SHARK_MODEL'),
    )

# Shared Gradio theme for the demo UI.
theme = gr.themes.Soft(neutral_hue="slate", primary_hue="sky")
def add_border(frame, color=(255, 0, 0), thickness=None):
    """Pad ``frame`` with a solid ``color`` border of equal width on all four sides.

    Args:
        frame: Image array whose first two dimensions are height and width.
        color: Border colour passed to ``cv2.copyMakeBorder``; its channel
            order must match ``frame``.
        thickness: Border width in pixels. ``None`` (the default) uses 2.5% of
            the frame's longer side, which is the width this function always
            produced before ``thickness`` was honoured.

    Returns:
        A new, larger image; each spatial dimension grows by ``2 * thickness``.
    """
    if thickness is None:
        thickness = int(max(frame.shape[0], frame.shape[1]) * 0.025)
    # Negative padding is meaningless for copyMakeBorder; clamp to zero.
    thickness = max(int(thickness), 0)
    return cv2.copyMakeBorder(
        frame, thickness, thickness, thickness, thickness,
        cv2.BORDER_CONSTANT, value=color,
    )
def overlay_text_on_image(image, text_list, font=cv2.FONT_HERSHEY_SIMPLEX, font_size=0.5, font_thickness=1, margin=10, color=(255, 255, 255), box_color=(0,0,0)):
    """Draw each string of ``text_list`` right-aligned near the top-right corner.

    Rows start at ``y = margin`` and are spaced by 10% of the image's shorter
    side. Every row is drawn over a filled ``box_color`` rectangle so it stays
    readable on top of video. Rows containing "Shark" or "Human" are rendered
    with a 1.2x larger font scale.

    The image is modified in place and also returned.
    """
    img_w = image.shape[1]
    row_gap = int(min(image.shape[0], img_w) * 0.1)
    right_edge = img_w - margin
    for row, text in enumerate(text_list):
        baseline_y = margin + row * row_gap
        emphasised = 'Shark' in text or 'Human' in text
        scale = font_size * 1.2 if emphasised else font_size
        (text_w, text_h), _baseline = cv2.getTextSize(text, font, scale, font_thickness)
        left_x = right_edge - text_w
        # Background box: 5 px of padding left/right/below the text.
        cv2.rectangle(image, (left_x - 5, baseline_y - text_h), (right_edge + 5, baseline_y + 5), box_color, -1)
        cv2.putText(image, text, (left_x, baseline_y), font, scale, color, font_thickness, lineType=cv2.LINE_AA)
    return image
def overlay_logo(frame, logo, position=(10, 10)):
    """Paste ``logo`` onto ``frame`` in place, treating bright pixels as transparent.

    A logo pixel is skipped when its channel-1 value is above 150; every other
    logo pixel replaces the corresponding frame pixel (a binary alpha mask).
    The logo is clipped to the frame, so a logo that overhangs an edge is
    drawn partially instead of raising a broadcasting error.

    Parameters:
    - frame: H x W x C image array (C >= 3) to draw on; modified in place.
    - logo: h x w x 3 image array (already loaded and resized by the caller).
    - position: (x, y) of the logo's top-left corner in frame coordinates.

    Returns the same ``frame`` object.

    Raises ValueError if ``logo`` is not a 3-channel image.
    """
    if logo.ndim != 3 or logo.shape[2] != 3:
        raise ValueError(f"Expected a 3-channel logo image, got shape {logo.shape}")

    x, y = int(position[0]), int(position[1])
    frame_h, frame_w = frame.shape[:2]
    # Visible rectangle of the logo after clipping to the frame bounds.
    x0, y0 = max(x, 0), max(y, 0)
    x1 = min(x + logo.shape[1], frame_w)
    y1 = min(y + logo.shape[0], frame_h)
    if x0 >= x1 or y0 >= y1:
        return frame  # logo lies entirely outside the frame

    visible = logo[y0 - y:y1 - y, x0 - x:x1 - x]
    roi = frame[y0:y1, x0:x1]  # basic slice -> view, so writes land in frame
    opaque = visible[:, :, 1] <= 150
    roi[opaque, :3] = visible[opaque]
    return frame
_DANGER_SYMBOL_PATH = 'static/danger_symbol.jpeg'
_danger_symbol_cache = {}


def _load_danger_symbol():
    """Return the danger-symbol image as read by ``cv2.imread``, reading the file only once."""
    symbol = _danger_symbol_cache.get(_DANGER_SYMBOL_PATH)
    if symbol is None:
        symbol = cv2.imread(_DANGER_SYMBOL_PATH)
        if symbol is None:
            # cv2.imread reports a missing/unreadable file by returning None.
            raise FileNotFoundError(f'Could not read danger symbol image: {_DANGER_SYMBOL_PATH}')
        _danger_symbol_cache[_DANGER_SYMBOL_PATH] = symbol
    return symbol


def add_danger_symbol_from_image(frame, top_pred):
    """Overlay the danger symbol near the top-left corner when a dangerous shark is seen.

    The symbol is drawn only if both ``top_pred['shark_sighted']`` and
    ``top_pred['dangerous_dist']`` are truthy. It is resized to a square with a
    side of 10% of the frame's longer side and placed at an offset of 5% of
    that side from the top-left corner.

    Returns the (possibly annotated) frame.
    """
    if top_pred['shark_sighted'] and top_pred['dangerous_dist']:
        relative = max(frame.shape[0], frame.shape[1])
        side = max(int(relative * 0.1), 1)  # cv2.resize cannot produce a 0-pixel image
        # resize returns a new array, so the cached original is never modified.
        # [:, :, ::-1] reverses the channel order (imread loads BGR); process_video
        # converts frames to RGB before drawing.
        symbol = cv2.resize(_load_danger_symbol(), (side, side), interpolation=cv2.INTER_AREA)[:, :, ::-1]
        offset = int(relative * 0.05)
        frame = overlay_logo(frame, symbol, position=(offset, offset))
    return frame
def draw_cockpit(frame, top_pred, cnt):
    """Overlay the detection "cockpit" (status text plus a danger cue) on a frame.

    Args:
        frame: Image to annotate (process_video passes an RGB frame).
        top_pred: Per-frame summary dict. This function reads 'shark_sighted',
            'shark_suspected', 'human_sighted', 'human_suspected',
            'biggest_shark_size' and 'dangerous_dist_confirmed';
            add_danger_symbol_from_image additionally reads 'dangerous_dist'.
        cnt: Frame counter; the danger symbol is only considered on even values.

    Returns:
        The annotated frame. When the danger-symbol path is not taken, a black
        border is added, so the result is larger than the input (the caller
        resizes it back).
    """
    if top_pred['shark_sighted'] > 0:
        shark_status = 'Shark Sighted !'
    elif top_pred['shark_suspected'] > 0:
        shark_status = 'Shark Suspected !'
    else:
        shark_status = 'No Sharks ...'

    if top_pred['human_sighted'] > 0:
        human_status = 'Human Sighted !'
    elif top_pred['human_suspected'] > 0:
        human_status = 'Human Suspected !'
    else:
        human_status = 'No Humans ...'

    biggest_size = top_pred['biggest_shark_size']
    size_line = 'Biggest shark size: ' + (str(biggest_size) if biggest_size else '...')
    danger_line = 'Danger Level: ' + ('High' if top_pred['dangerous_dist_confirmed'] else 'Low')
    strings = [shark_status, human_status, size_line, danger_line]

    relative = max(frame.shape[0], frame.shape[1])
    # A sighted shark on an even count takes the danger-symbol path (the helper
    # itself decides whether to draw, based on 'dangerous_dist'); otherwise a
    # black border is added. The original code reached the same branch for
    # both confirmed and unconfirmed danger, so one condition suffices.
    if top_pred['shark_sighted'] and cnt % 2 == 0:
        frame = add_danger_symbol_from_image(frame, top_pred)
    else:
        frame = add_border(frame, color=(0, 0, 0), thickness=int(relative * 0.025))
    # overlay_text_on_image draws in place on frame.
    overlay_text_on_image(frame, strings, font=cv2.FONT_HERSHEY_SIMPLEX, font_size=relative * 0.0007,
                          font_thickness=1, margin=int(relative * 0.05), color=(255, 255, 255))
    return frame
def process_video(input_video, out_fps = 'auto', skip_frames = 12):
    """Run detection on every ``skip_frames``-th frame of a video and stream the results.

    This is a generator (wired to a Gradio button). It yields each annotated
    RGB frame as soon as it is ready, writes the same frames to ``output.mp4``
    in the working directory, and finally yields ``None``.

    Args:
        input_video: A path string, or an object exposing the path as ``.name``.
        out_fps: Output frame rate. An ``int`` is used as given; ``'auto'`` uses
            the source rate divided by ``skip_frames``; any other value uses
            the source rate. The result is never below 1.
        skip_frames: Process one frame out of every ``skip_frames`` (must be >= 1).

    Raises:
        ValueError: if no video is given, ``skip_frames`` < 1, or the video
            cannot be opened.
    """
    from collections import deque

    if input_video is None:
        raise ValueError('No input video provided.')
    if skip_frames < 1:
        raise ValueError(f'skip_frames must be >= 1, got {skip_frames}')

    print('Processing video: ')
    video_path = getattr(input_video, 'name', input_video)
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f'Could not open video: {video_path}')

    output_path = "output.mp4"
    if isinstance(out_fps, int):
        fps = out_fps
    else:
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        if out_fps == 'auto':
            fps = int(fps / skip_frames)
    # A low source rate divided by skip_frames can round down to 0, which is
    # not a valid output frame rate.
    fps = max(fps, 1)

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Very large inputs are downscaled by 4x in each dimension.
    if width > 2200 or height > 2000:
        width //= 4
        height //= 4
    size = (width, height)
    video = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, size)

    # Rolling windows over the last 5 processed frames: a detection counts as
    # "sighted"/"confirmed" only when it occurred in at least 3 of them.
    window = 5
    recent_shark = deque(maxlen=window)
    recent_human = deque(maxlen=window)
    recent_danger = deque(maxlen=window)

    try:
        cnt = 0
        iterating, frame = cap.read()
        while iterating:
            print('overall count ', cnt)
            if cnt % skip_frames == 0:
                frame = cv2.resize(frame, size)
                print('starting Frame: ', cnt)
                display_frame, result = inference_frame_serial(frame)
                top_pred = process_results_for_plot(predictions=result.numpy(),
                                                    classes=classes,
                                                    class_sizes=class_sizes_lower)
                recent_shark.append(int(top_pred['shark_n'] > 0))
                recent_human.append(int(top_pred['human_n'] > 0))
                recent_danger.append(int(top_pred['dangerous_dist'] > 0))
                top_pred['shark_sighted'] = int(sum(recent_shark) > 2)
                top_pred['human_sighted'] = int(sum(recent_human) > 2)
                top_pred['dangerous_dist_confirmed'] = int(sum(recent_danger) > 2)

                # Draw and yield in RGB; converted back to BGR only for VideoWriter.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                prediction_frame = cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB)
                # NOTE(review): cnt is a multiple of skip_frames here, so with an even
                # skip_frames (the default 12) this test -- and draw_cockpit's
                # cnt % 2 check -- is always true. Kept as-is; confirm the intent.
                if cnt * skip_frames % 2 == 0:
                    frame = cv2.resize(prediction_frame, size)
                frame = draw_cockpit(frame, top_pred, cnt * skip_frames)
                # draw_cockpit may add a border, so restore the output size.
                frame = cv2.resize(frame, size)
                video.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
                # NOTE(review): the dashboard result is unused; the call is kept in
                # case prediction_dashboard has side effects -- drop it if not.
                prediction_dashboard(top_pred=top_pred)
                print('finalizing frame:', cnt)
                yield frame
            cnt += 1
            iterating, frame = cap.read()
    finally:
        # Runs on normal completion and when the consumer closes the generator
        # early, so neither the reader nor the writer is leaked.
        cap.release()
        video.release()
    yield None
with gr.Blocks(theme=theme) as demo:
    gr.Markdown("Alpha Demo of the Sharkpatrol Oceanlife Detector.")

    # Upload widget plus the live view of frames streamed by process_video.
    with gr.Row():
        input_video = gr.File(label="Input", height=50)
        original_frames = gr.Image(label="Processed Frame")

    # Clickable examples: every .mp4 under videos_example/ whose path contains "raw_videos".
    with gr.Row():
        paths = sorted(pathlib.Path('videos_example/').rglob('*.mp4'))
        samples = [[clip.as_posix()] for clip in paths if 'raw_videos' in str(clip)]
        examples = gr.Examples(samples, inputs=input_video)
        process_video_btn = gr.Button("Process Video")

    process_video_btn.click(fn=process_video, inputs=input_video, outputs=[original_frames])
# Enable Gradio's request queue; process_video is a generator that streams frames.
demo.queue()
if os.getenv('SYSTEM') == 'spaces':
    # On Hugging Face Spaces the demo is password-protected via environment secrets.
    _auth = (os.environ.get('SHARK_USERNAME'), os.environ.get('SHARK_PASSWORD'))
    if not all(_auth):
        # Fail closed and loudly: with a missing secret the auth pair would
        # contain None and nobody could log in.
        raise RuntimeError('SHARK_USERNAME and SHARK_PASSWORD must be set when running on Spaces.')
    demo.launch(width='40%', auth=_auth)
else:
    # Local run: verbose errors and a temporary public share link.
    demo.launch(debug=True, share=True)