Spaces:
Build error
Build error
| from webbrowser import get | |
| import gradio as gr | |
| import os | |
| import os.path as op | |
| import json | |
| from typing import Dict, List, Tuple | |
| import torch | |
| import numpy as np | |
| import datetime | |
| import logging | |
| import warnings | |
| # imports for gps & map | |
| import osmnx as ox | |
| import folium | |
| import pandas as pd | |
| import geopandas | |
| import codecs | |
| from tracking.gps import * | |
| # imports for tracking | |
| from plasticorigins.tools.files import download_from_url, create_unique_folder, load_trash_icons | |
| from plasticorigins.tools.video_readers import IterableFrameReader, SimpleVideoReader | |
| from plasticorigins.detection.centernet.networks.mobilenet import get_mobilenet_v3_small | |
| from plasticorigins.detection.yolo import load_model, predict_yolo | |
| from plasticorigins.detection.detect import detect | |
| from plasticorigins.tracking.postprocess_and_count_tracks import filter_tracks, postprocess_for_api, count_objects | |
| from plasticorigins.tracking.utils import get_detections_for_video, write_tracking_results_to_file, read_tracking_results, generate_video_with_annotations | |
| from plasticorigins.tracking.track_video import track_video | |
| from plasticorigins.tracking.trackers import get_tracker | |
# Top-level Gradio layout: a Blocks app with two tabs, each opening with a
# static HTML panel.
# NOTE(review): this file's indentation was lost; the nesting below is
# reconstructed from the `with` statements. Confirm whether the statements
# that follow this block (model loading, gr.Interface, ...) are meant to live
# inside the "Surfnet AI" tab context.
# NOTE(review): some emoji in the HTML were mis-decoded. Those whose bytes and
# context identify them were restored (laptop, label, e-mail); the remaining
# garbled marks ("π", "π€", "β") are ambiguous and left as found - confirm
# against the deployed page.
demo = gr.Blocks()
title = "Surfnet AI Demo"
with demo:
    with gr.Tabs():
        with gr.TabItem("The Project"):
            gr.HTML(""" <!DOCTYPE html>
            <html lang="en-us">
            <head>
            <meta charset="utf-8">
            <meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=no">
            <title>Surfnet AI Demo</title>
            </head>
            <body>
            <h1 style="text-align: center;"> <b>Welcome to the Surfnet demo, the algo that tracks Plastic Pollution</b> π</h1>
            <img src="https://play-lh.googleusercontent.com/3SCReOTmp19uyohvMqaaJOtzkR8DnPLk-OqL8nGpTj9Ilu6-oSS9no9jeR3fDVh0dYo" style="width:10%; display:block; margin:20px auto;"/>
            <p style="text-align: center;"> We all dream about swimming in clear blue waters and walking
            bare-footed on a beautiful white sand beach. But our dream is
            threatened. Plastics are invading every corner of the earth,
            from remote alpine lakes to the deepest oceanic trench.
            Thankfully, there are many things we can do.π€
            <b>Plastic Origins</b>, a citizen science project from <a href="https://surfrider.eu" target = "_blank"><b><u>Surfrider Europe</u></b></a>,
            using artificial intelligence to map river plastic pollution, is one of them.
            This demo is here for you to test the AI model we use to detect and count litter items on riverbanks.
            </p>
            <br>
            <p style="text-align: center;">
            β Read more on <a href="https://plasticorigins.eu" target = "_blank"> <b><u>www.plasticorigins.eu</u></b></a>
            <br>
            💻 Join the dev team on <a href="https://github.com/surfriderfoundationeurope/The-Plastic-Origins-Project" target = "_blank"> <b><u>Github</u></b></a>
            <br>
            🏷️ Help us label images on <a href="https://www.trashroulette.com/#/" target = "_blank"> <b><u>www.trashroulette.com</u></b></a>
            </p>
            <br>
            <br>
            <p style="text-align: center">
            📧 contact :
            <br>
            <a href="mailto:plasticorigins@surfrider.eu"> <b><u>plasticorigins@surfrider.eu</u></b></a>
            </p>
            </body>
            </html>""")
        with gr.TabItem("Surfnet AI"):
            gr.HTML(""" <!DOCTYPE html>
            <html lang="en">
            <head>
            <title>Left Side Panel</title>
            <meta charset="utf-8">
            </head>
            <body>
            <p style="text-align: center;">
            <b>Surfnet</b> is an AI model that detects trash on riverbanks.
            We use it to map river plastic pollution and act to reduce the introduction of litter into the environment.
            Developed &amp; maintained by a bunch of amazing volunteers from the NGO Surfrider Foundation Europe.
            </p>
            </body>
            </html>""")
# Logging setup: the root logger and a stream handler both accept DEBUG and
# above; records are rendered as "<time> - <logger name> - <level> - <message>".
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

ch = logging.StreamHandler()
ch.setFormatter(formatter)
ch.setLevel(logging.DEBUG)

logger = logging.getLogger()
logger.addHandler(ch)
logger.setLevel(logging.DEBUG)
class DotDict(dict):
    """Dictionary whose items can also be read, set and deleted as attributes.

    Reading a missing attribute yields ``None`` (same as ``dict.get``);
    deleting a missing attribute raises ``KeyError``.
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        return self.get(name)

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        del self[name]
# Short display name for each numeric category id (0-9); the trailing comment
# on a line gives the longer description of that category.
id_categories = dict(enumerate((
    'Fragment',        # 'Sheet / tarp / plastic bag / fragment'
    'Insulating',      # 'Insulating material'
    'Bottle',          # 'Bottle-shaped'
    'Can',             # 'Can-shaped'
    'Drum',
    'Packaging',       # 'Other packaging'
    'Tire',
    'Fishing net',     # 'Fishing net / cord'
    'Easily namable',
    'Unclear',
)))
# Shared detection/tracking settings, accessed attribute-style through DotDict.
# run_model overwrites some of these fields on every request.
config_track = DotDict(
    yolo_conf_thrld=0.35,
    yolo_iou_thrld=0.5,
    confidence_threshold=0.004,  # for the tracking part
    detection_threshold=0.3,  # for centernet
    downsampling_factor=4,
    noise_covariances_path="data/tracking_parameters",
    output_shape=(960, 544),
    size=768,
    skip_frames=3,  # 3
    arch="mobilenet_v3_small",
    device="cpu",
    detection_batch_size=1,
    display=0,
    kappa=4,  # 4
    tau=3,  # 4
    max_length=240,
    downscale_output=2,
)
# --- Yolo detector -----------------------------------------------------------
logger.info('---Yolo model...')
# Yolo has warning problems, so we set an env variable to remove it
os.environ["VERBOSE"] = "False"
URL_MODEL = "https://github.com/surfriderfoundationeurope/IA_Pau/releases/download/v0.1/yolov5.pt"
FILE_MODEL = "yolov5.pt"
model_path = download_from_url(URL_MODEL, FILE_MODEL, "./models", logger)
model_yolo = load_model(
    model_path,
    config_track.device,
    config_track.yolo_conf_thrld,
    config_track.yolo_iou_thrld,
)

# --- Centernet detector ------------------------------------------------------
logger.info('---Centernet model...')
URL_MODEL = "https://partage.imt.fr/index.php/s/sJi22N6gedN6T4q/download"
FILE_MODEL = "mobilenet_v3_pretrained.pth"
model_path = download_from_url(URL_MODEL, FILE_MODEL, "./models", logger)
model = get_mobilenet_v3_small(num_layers=0, heads={'hm': 1}, head_conv=256)
# NOTE(review): torch.load unpickles the downloaded checkpoint, which is only
# safe as long as the URL above is trusted.
checkpoint = torch.load(model_path, map_location="cpu")
model.load_state_dict(checkpoint['model'], strict=True)

# --- Demo videos used as Interface examples ----------------------------------
URL_DEMO1 = "https://etlplasticostorageacc.blob.core.windows.net/surfnetbenchmark/video_niv15.mp4"
FILE_DEMO1 = "video_niv15.mp4"
URL_DEMO2 = "https://etlplasticostorageacc.blob.core.windows.net/surfnetbenchmark/video_midouze15.mp4"
FILE_DEMO2 = "video_midouze15.mp4"
URL_DEMO3 = "https://etlplasticostorageacc.blob.core.windows.net/surfnetbenchmark/video_antoine15.mp4"
FILE_DEMO3 = "video_antoine15.mp4"

# Fetch the three videos in their original order.
for _demo_url, _demo_file in (
    (URL_DEMO1, FILE_DEMO1),
    (URL_DEMO2, FILE_DEMO2),
    (URL_DEMO3, FILE_DEMO3),
):
    download_from_url(_demo_url, _demo_file, "./data/", logger)

video1_path = op.join("./data", FILE_DEMO1)
video2_path = op.join("./data", FILE_DEMO2)
video3_path = op.join("./data", FILE_DEMO3)

# Folder holding the example GPS JSON tracks, and the icon set for annotations.
JSON_FILE_PATH = "data/"
labels2icons = load_trash_icons("./data/icons/")
def track(args):
    """Detect and track objects in one video and return the filtered tracks.

    Args:
        args: attribute-style configuration (in this file, ``config_track``).
            Fields read here: ``model_type`` ("yolo" or "centernet"),
            ``detection_threshold``, ``noise_covariances_path``,
            ``video_path``, ``skip_frames``, ``output_shape``, ``max_length``,
            ``downsampling_factor`` and ``detection_batch_size``. The object
            is also handed on to ``track_video``.

    Returns:
        The output of ``filter_tracks`` applied to the tracking results.

    Raises:
        ValueError: if ``args.model_type`` is neither "yolo" nor "centernet".

    Side effects:
        Writes ``<video>_<timestamp>_unfiltered.txt`` and
        ``<video>_<timestamp>_filtered.txt`` next to the input video.
    """
    # Fail early: an unknown model type would otherwise leave no detector and
    # an empty detection list, and only fail (or return nothing) much later.
    if args.model_type not in ("yolo", "centernet"):
        raise ValueError(
            f"Unknown model_type {args.model_type!r}; expected 'yolo' or 'centernet'"
        )
    is_yolo = args.model_type == "yolo"

    device = torch.device("cpu")
    engine = get_tracker('EKF')

    if is_yolo:
        logger.info("---Using Yolo")

        def detector(frame):
            return predict_yolo(model_yolo, frame, size=config_track.size, augment=False)
    else:
        logger.info("---Using Centernet")

        def detector(frame):
            return detect(frame, threshold=args.detection_threshold, model=model)

    transition_variance = np.load(op.join(args.noise_covariances_path, 'transition_variance.npy'))
    observation_variance = np.load(op.join(args.noise_covariances_path, 'observation_variance.npy'))

    logger.info(f'---Processing {args.video_path}')
    reader = IterableFrameReader(video_filename=args.video_path,
                                 skip_frames=args.skip_frames,
                                 output_shape=args.output_shape,
                                 progress_bar=True,
                                 preload=False,
                                 max_frame=args.max_length)
    try:
        input_shape = reader.input_shape
        output_shape = reader.output_shape
        # Scale factors passed to the result writer (input size relative to
        # the downsampled output size).
        ratio_y = input_shape[0] / (output_shape[0] // args.downsampling_factor)
        ratio_x = input_shape[1] / (output_shape[1] // args.downsampling_factor)

        logger.info('---Detecting...')
        if is_yolo:
            # Yolo is noisy with warnings; silence them for the detection pass.
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore")
                detections = [detector(frame) for frame in reader]
        else:
            detections = get_detections_for_video(reader, detector,
                                                  batch_size=args.detection_batch_size,
                                                  device=device)

        logger.info('---Tracking...')
        display = None
        results = track_video(reader, iter(detections), args, engine,
                              transition_variance, observation_variance,
                              display, is_yolo=is_yolo)
    finally:
        # Release the video handle even when detection or tracking fails.
        reader.video.release()

    # store unfiltered results
    datestr = datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')
    base_name = op.splitext(args.video_path)[0] + "_" + datestr
    output_filename = base_name + '_unfiltered.txt'
    write_tracking_results_to_file(results, ratio_x=ratio_x, ratio_y=ratio_y,
                                   output_filename=output_filename)

    logger.info('---Filtering...')
    # Read the results back from the file before filtering.
    results = read_tracking_results(output_filename)
    filtered_results = filter_tracks(results, config_track.kappa, config_track.tau)

    # store filtered results
    output_filename = base_name + '_filtered.txt'
    write_tracking_results_to_file(filtered_results, ratio_x=ratio_x, ratio_y=ratio_y,
                                   output_filename=output_filename)
    return filtered_results
def _map_html_to_iframe(map_path):
    """Read the saved map page and wrap it in a sandboxed ``<iframe srcdoc>``."""
    from html import escape

    # Context manager closes the file; explicit UTF-8 avoids depending on the
    # platform's default encoding.
    with open(map_path, 'r', encoding='utf-8') as html_content:
        map_html = html_content.read()
    # The document is embedded inside an HTML attribute, so it must be escaped:
    # an unescaped quote in the map page would terminate the attribute early.
    escaped_map_html = escape(map_html, quote=True)
    return f"""<iframe style="width: 100%; height: 480px" name="result" allow="midi; geolocation; microphone; camera;
    display-capture; encrypted-media;" sandbox="allow-modals allow-forms
    allow-scripts allow-same-origin allow-popups
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
    allowpaymentrequest="" frameborder="0" srcdoc="{escaped_map_html}"></iframe>"""


def run_model(video_path, model_type, seconds, skip, tau, kappa, gps_file):
    """Run detection and tracking on a video and build the demo outputs.

    Args:
        video_path: path of the input video.
        model_type: detector name, forwarded to ``track`` via ``config_track``.
        seconds: how much of the video to process; converted to a frame limit.
        skip: number of frames to skip (stored as ``skip_frames``).
        tau: stored as ``config_track.tau`` (used by track filtering).
        kappa: stored as ``config_track.kappa`` (used by track filtering).
        gps_file: optional GPS JSON upload, or ``None``.

    Returns:
        Tuple ``(output_path, map_frame, output_label, output_json_path)``:
        the annotated video path, the map as iframe HTML (``None`` when there
        is no GPS data or no prediction), the object counts, and the path of
        the JSON predictions file.

    Note:
        The request parameters are written into the module-level
        ``config_track``, so that state is shared between calls.
    """
    # Lazy %-formatting also tolerates a missing (None) path when logging.
    logger.info('---video filename: %s', video_path)

    # launch the tracking
    config_track.video_path = video_path
    config_track.model_type = model_type
    config_track.skip_frames = int(skip)
    config_track.tau = int(tau)
    config_track.kappa = int(kappa)
    # NOTE(review): 24 looks like an assumed frames-per-second value; the
    # reader's actual fps is not used for this limit - confirm.
    config_track.max_length = int(seconds)*24

    out_folder = create_unique_folder("/tmp/", "output")
    output_path = op.join(out_folder, "video.mp4")
    filtered_results = track(config_track)

    # postprocess
    logger.info('---Postprocessing...')
    output_json_path = op.join(out_folder, "output.json")
    output_json = postprocess_for_api(filtered_results, id_categories)
    with open(output_json_path, 'w') as f_out:
        json.dump(output_json, f_out)

    # build video output
    logger.info('---Generating new video...')
    reader = IterableFrameReader(video_filename=config_track.video_path,
                                 skip_frames=0,
                                 progress_bar=True,
                                 preload=False,
                                 max_frame=config_track.max_length)

    # Get GPS Data
    video_duration = reader.total_num_frames / reader.fps
    gps_data = get_filled_gps(gps_file, video_duration)

    # Generate new video
    generate_video_with_annotations(reader, output_json, output_path,
                                    config_track.skip_frames, config_track.max_length,
                                    config_track.downscale_output, logger, gps_data=gps_data,
                                    labels2icons=labels2icons)
    output_label = count_objects(output_json, id_categories)

    # Get Plastic Map
    map_frame = None  # default value in case no GPS file
    if gps_data is not None:
        logger.info('---Creating Plastic Map...')
        # Get Trash Prediction (re-read from the JSON file written above)
        with open(output_json_path) as json_file:
            predictions = json.load(json_file)
        trash_df = get_df_prediction(predictions, reader.fps)
        if len(trash_df) != 0:
            # Get Trash prediction alongside GPS data
            trash_gps_df = get_trash_gps_df(trash_df, gps_data)
            trash_gps_geo_df = get_trash_gps_geo_df(trash_gps_df)
            # Create Map, centred on the first located item
            center_lat = trash_gps_df.iloc[0]['Latitude']
            center_long = trash_gps_df.iloc[0]['Longitude']
            map_path = get_plastic_map(center_lat, center_long, trash_gps_geo_df, out_folder)
            map_frame = _map_html_to_iframe(map_path)

    logger.info('---Surfnet End processing...')
    return output_path, map_frame, output_label, output_json_path
def get_filled_gps(file_obj, video_duration)->list:
    """Build the filled GPS point list from an uploaded GPS JSON track.

    Args:
        file_obj: a file_obj from gradio input File type, or None
        video_duration: in seconds

    Returns:
        gps_data (list): the GPS filled data as a list, or None when no file
        was provided
    """
    if file_obj is None:
        return None
    # parse the upload -> extract its GPS points -> fill them over the duration
    gps_points = get_json_gps_list(parse_json(file_obj))
    return fill_gps(gps_points, video_duration)
def get_plastic_map(center_lat,center_long,trash_gps_gdf,out_folder)->str:
    """Draw every detected item on a map and save the map as an HTML page.

    Args:
        center_lat (float): latitude to center map
        center_long (float): longitude to center map
        trash_gps_gdf (DataFrame): trash & gps geo dataframe
        out_folder (str): folder to save html map

    Returns:
        map_html_path (str): full path to html map
    """
    plastic_map = folium.Map([center_lat, center_long], zoom_start=16)
    # Markers take (y, x) pairs from the geometry; each one gets its label as popup.
    coordinates = zip(trash_gps_gdf.geometry.y, trash_gps_gdf.geometry.x)
    popup_texts = list(trash_gps_gdf['label'])
    for point, text in zip(coordinates, popup_texts):
        marker = folium.CircleMarker(location=point)
        marker.add_child(folium.Popup(text)).add_to(plastic_map)
    map_html_path = op.join(out_folder, "plasticmap.html")
    plastic_map.save(map_html_path)
    return map_html_path
# Input widgets, in the same order as run_model's parameters
# (video_path, model_type, seconds, skip, tau, kappa, gps_file).
# NOTE(review): the file's indentation was lost. If this Interface is meant to
# show up in the "Surfnet AI" tab, these statements must run inside that
# `with gr.TabItem(...)` context - confirm against the original layout.
video_in = gr.inputs.Video(type="mp4", source="upload", label="Video Upload", optional=False)
model_type = gr.inputs.Dropdown(choices=["centernet", "yolo"], type="value", default="yolo", label="model")
skip_slider = gr.inputs.Slider(minimum=0, maximum=15, step=1, default=3, label="skip frames")
tau_slider = gr.inputs.Slider(minimum=1, maximum=7, step=1, default=3, label="tau")
kappa_slider = gr.inputs.Slider(minimum=1, maximum=7, step=1, default=4, label="kappa")
seconds_num = gr.inputs.Number(default=10, label="seconds")
gps_in = gr.inputs.File(type="file", label="GPS Upload", optional=True)

# Outputs map onto run_model's return tuple:
# annotated video, map iframe HTML, object counts, JSON predictions file.
gr.Interface(fn=run_model,
             inputs=[video_in, model_type, seconds_num, skip_slider, tau_slider, kappa_slider, gps_in],
             outputs=["playable_video", "html", "label", "file"],
             examples=[[video1_path, "yolo", 10, 3, 3, 4, op.join(JSON_FILE_PATH, "gavepau.json")],
                       [video2_path, "yolo", 10, 3, 3, 4, op.join(JSON_FILE_PATH, "midouze.json")],
                       # NOTE(review): this example reuses gavepau.json with the
                       # third video - confirm that is the intended GPS track.
                       [video3_path, "yolo", 10, 3, 3, 4, op.join(JSON_FILE_PATH, "gavepau.json")]],
             description="Upload a video, optionally a GPS file and you'll get Plastic detection on river.",
             theme="huggingface",
             allow_screenshot=False, allow_flagging="never")

demo.launch()