Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import spaces | |
| from super_gradients.training import models | |
| from deep_sort_torch.deep_sort.deep_sort import DeepSort | |
| from super_gradients.training import models | |
| from super_gradients.training.pipelines.pipelines import DetectionPipeline | |
| from model_tools import get_prediction, get_color | |
| import cv2 | |
| import datetime | |
| import torch | |
| import os | |
| import gradio as gr | |
| import numpy as np | |
# Re-create the removed NumPy scalar aliases (np.float, np.int, np.object, np.bool)
# as the plain builtins, presumably for dependencies that still reference them.
np.float, np.int, np.object, np.bool = float, int, object, bool
# Directory holding the demo clips; later also used as the output directory.
dir = f"{os.getcwd()}/uploads/"
# PIL image components (the interface built in run() uses gr.Video instead).
inp = gr.Image(type="pil")
output = gr.Image(type="pil")
# Clickable examples for the UI: [clip path, caption] per entry.
examples = [
    [dir + clip, caption]
    for clip, caption in (
        ("cafe_fall.mp4", "Fall in cafe"),
        ("slip.mp4", "Run and Fall2"),
        ("skate.mp4", "Skate and Fall"),
        ("kitchen.mp4", "Fall in kitchen"),
        ("studycam.mp4", "Experiment fall"),
    )
]
# Single-class YOLO-NAS-S detector restored from the local fine-tuned checkpoint,
# moved to the GPU when one is available and switched to inference mode.
ckpt_path = f"{os.getcwd()}/checkpoints/best181-8376/ckpt_latest.pth"
best_model = models.get(
    'yolo_nas_s',
    num_classes=1,
    checkpoint_path=ckpt_path,
).to("cuda" if torch.cuda.is_available() else "cpu")
best_model.eval()
#### Initialize the DeepSort tracker from its local checkpoint
tracker_model = f"{os.getcwd()}/checkpoints/ckpt.t7"
tracker = DeepSort(
    model_path=tracker_model,
    max_dist=0.2,
    max_iou_distance=0.7,
    max_age=30,  # presumably frames a lost track is kept alive; verify in DeepSort
    nn_budget=100,
)
# Rendered videos are written next to the uploaded demo clips.
out_path = dir
filename = 'demo.webm'
# UI description. Built from adjacent literals: the previous backslash
# continuation inside the string dropped the space ("fallen.If").
description = (
    "Yolo model to detect if a person is falling or fallen with deepsort to track "
    "how long the subject has fallen. If the duration crosses a threshold of 5s, "
    "the bounding box will turn red and the subject will be labelled as IMMOBILE."
)
# Tunables for vid_predict.
_FRAME_SIZE = 640              # frames are resized to a _FRAME_SIZE x _FRAME_SIZE square
_IMMOBILE_AFTER_S = 5          # whole seconds fallen before a track is labelled IMMOBILE
_RESET_AFTER_S = 3.0           # a track unseen for longer than this restarts its timer
_FALLBACK_FPS = 30.0           # used when the container reports no usable frame rate
_OVERLAY_ALPHA = 0.6           # weight of the overlay when blended onto the frame
_IMMOBILE_COLOR = (0, 0, 255)  # BGR red
_CLASSNAMES = ['Fall-Detected']


def _make_pipeline():
    """Return a DetectionPipeline running ``best_model`` with this demo's NMS settings."""
    return DetectionPipeline(
        model=best_model,
        image_processor=best_model._image_processor,
        post_prediction_callback=best_model.get_post_prediction_callback(
            iou=0.25,
            conf=0.70,
            nms_top_k=100,
            max_predictions=50,
            multi_label_per_box=False,
            class_agnostic_nms=False,
        ),
        class_names=best_model._class_names,
    )


def _tracker_inputs(results):
    """Convert one frame's detections into DeepSort inputs.

    Returns ``(xywhs, confs, labels)``:
    - an (N, 4) tensor of integer centre-x, centre-y, width, height boxes;
    - an (N, 1) tensor of confidences rounded to two decimals;
    - a list of N integer class labels.
    Both tensors keep their 2-D shape even when N == 0.
    """
    detects, confs, labels = [], [], []
    bboxes = results.bboxes_xyxy
    if len(bboxes) >= 1:
        for bbox, conf, label in zip(bboxes, results.confidence, results.labels):
            x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
            bw, bh = abs(x1 - x2), abs(y1 - y2)
            detects.append([x1 + bw // 2, y1 + bh // 2, bw, bh])
            confs.append([float(np.round(conf, decimals=2))])
            labels.append(int(label))
    return (torch.tensor(detects).reshape(-1, 4),
            torch.tensor(confs).reshape(-1, 1),
            labels)


def _draw_track(img, overlay, box, text, color):
    """Draw ``box`` and its label ``text`` onto ``img`` and ``overlay``.

    The outline and text go on both images. The filled label background goes
    only on ``overlay``, so it is partially transparent after blending. The
    label is clamped to stay inside the frame.
    """
    x1, y1, x2, y2 = box
    frame_width = img.shape[1]
    (text_w, _), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 1)
    text_x = min(x1, frame_width - text_w)
    for canvas in (img, overlay):
        cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 1)
    cv2.rectangle(overlay, (text_x, max(1, y1 - 20)),
                  (min(x1 + text_w, frame_width), max(21, y1)), color, cv2.FILLED)
    for canvas in (img, overlay):
        cv2.putText(canvas, text, (text_x, max(21, y1)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 2)


def vid_predict(media):
    """Annotate falls in a video and return the path of the rendered result.

    Each frame is resized to a square, run through the YOLO-NAS detector and fed
    to the DeepSort tracker. Every tracked fall is boxed and labelled with how
    long it has lasted. Once it lasts ``_IMMOBILE_AFTER_S`` seconds, the box turns
    red and the label reads IMMOBILE. Durations are measured in video time
    (frame index / fps), so results do not depend on how fast the host runs.

    Args:
        media: path of the input video (what the ``gr.Video`` input supplies).

    Returns:
        Path of the VP8-encoded output video (``out_path`` joined with ``filename``).

    Raises:
        gr.Error: if no video was supplied or it cannot be opened.
    """
    if media is None:
        raise gr.Error("Please upload a video first.")
    pipeline = _make_pipeline()
    print("Running Predict")
    save_to = os.path.join(out_path, filename)

    cap = cv2.VideoCapture(media)
    if not cap.isOpened():
        cap.release()
        raise gr.Error(f"Could not open video: {media}")
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps > 0:  # some containers report 0 (or NaN) as the frame rate
        fps = _FALLBACK_FPS
    print('fps:', fps)
    print('frames count:', cap.get(cv2.CAP_PROP_FRAME_COUNT))

    out = cv2.VideoWriter(save_to, cv2.VideoWriter_fourcc(*'VP08'), fps,
                          (_FRAME_SIZE, _FRAME_SIZE))
    # track id -> {'start': frame the fall began, 'present': frame last seen,
    #              'status': 'IMMOBILE' or None}
    fall_records = {}
    frame_id = 0
    try:
        # Read until the stream ends instead of trusting CAP_PROP_FRAME_COUNT,
        # which is only an estimate for many containers.
        while True:
            ret, img = cap.read()
            if not ret:
                break
            frame_id += 1
            img = cv2.resize(img, (_FRAME_SIZE, _FRAME_SIZE),
                             interpolation=cv2.INTER_AREA)
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # model/tracker take RGB
            overlay = img.copy()

            results = get_prediction(best_model, img_rgb, pipeline)
            xywhs, confs, labels = _tracker_inputs(results)
            tracker_results = tracker.update(xywhs, confs, img_rgb)

            # Restart the timer of any track not seen as fallen for too long.
            fall_records = {
                track_id: record
                if (frame_id - record['present']) / fps <= _RESET_AFTER_S
                else {'start': frame_id, 'present': frame_id}
                for track_id, record in fall_records.items()
            }

            # NOTE(review): tracks are paired with detection confidences/labels by
            # position, as before. DeepSort's output order need not match detection
            # order, so the displayed confidence may belong to another detection.
            for track, conf, label in zip(tracker_results, confs, labels):
                x1, y1, x2, y2, track_id = track
                box = (int(x1), int(y1), int(x2), int(y2))
                conf = conf.numpy()[0]
                record = fall_records.get(track_id)
                if record is None:
                    fall_records[track_id] = {'start': frame_id, 'present': frame_id}
                    duration = 0
                else:
                    record['present'] = frame_id
                    duration = int((frame_id - record['start']) / fps)
                    record['status'] = 'IMMOBILE' if duration >= _IMMOBILE_AFTER_S else None
                    print(f"Frame:{frame_id} ID: {track_id} Conf: {conf} "
                          f"Duration:{duration} Status: {record['status']}")
                minute, sec = divmod(duration, 60)
                prefix = f"{_CLASSNAMES[label]} ({track_id}) {conf}"
                if duration < _IMMOBILE_AFTER_S:
                    text = f"{prefix} Elapsed: {minute}min{sec}s"
                    color = get_color(track_id * 20)
                else:
                    text = f"{prefix} IMMOBILE: {minute}min{sec}s "
                    color = _IMMOBILE_COLOR
                _draw_track(img, overlay, box, text, color)

            out.write(cv2.addWeighted(overlay, _OVERLAY_ALPHA, img,
                                      1 - _OVERLAY_ALPHA, 0))
    finally:
        # Always release the handles, even if inference raises mid-video.
        cap.release()
        out.release()
    return save_to
def run():
    """Wrap vid_predict in a Gradio video interface and serve it on port 7860."""
    interface = gr.Interface(
        fn=vid_predict,
        inputs=gr.Video(format='mp4'),
        outputs=gr.Video(),
        examples=examples,
        description=description,
        cache_examples=False,
        title='Fall detection and tracking with deep sort',
    )
    interface.launch(server_port=7860)


if __name__ == "__main__":
    run()