File size: 6,413 Bytes
e52a7fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7306b6
e52a7fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import os
import cv2
from ultralytics import YOLO
import gradio as gr

def get_model(path):
    return YOLO(path)

def format_time(seconds):
    # seconds to minutes:seconds
    return f"{int(seconds // 60)}:{seconds % 60:05.2f}"

def merge_crash_events(crash_events):
    # many crash events have the same start time
    if not crash_events:
        return []

    merged_events = [crash_events[0]] # start with the first crash event

    for current_start, current_end in crash_events[1:]:
        last_start, last_end = merged_events[-1] # get the last start and end
        if current_start - last_end <= 5.0: # check if less than merge time threshold
            merged_events[-1] = (last_start, max(last_end, current_end)) # update end time to the latest time
        else: 
            merged_events.append((current_start, current_end)) # no merging needed
    return merged_events

def video_classification(video_path,label_vid_output,crash_vid_output, model ,min_crash_duration=2.0): # min_crash_duration weeds out false flags
    # start video process
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    count = 0

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    labeled_out = cv2.VideoWriter(label_vid_output, fourcc, fps, (width, height))
    crash_out = cv2.VideoWriter(crash_vid_output, fourcc, fps, (width, height))

    # label dictionary
    label_counts = {
        "Crash": 0,
        "Flight": 0,
        "No drone": 0,
        "No signal": 0,
        "No started": 0,
        "Started": 0,
        "Unstable": 0,
        "Landing": 0,
        "Unknown": 0
    }

    is_crash = False # is current frame a crash
    crash_events = [] # store crash event tuples (start,end)
    crash_start_time = None # keep track of start time

    non_crash_frame_threshold = int(fps * 1.0) # number of frames to consider crash over
    non_crash_frame_count = 0 # count current non crash frames

    while True:
        ret, og_frame = cap.read() # read current frame
        if not ret:
            break

        print(f"\rProcessing frame {count + 1}/{total_frames}", end='', flush=True)
        
        # preprocess frame for prediction
        frame = cv2.resize(og_frame, (320, 320)) # reducing to help processing time on cpu
        #frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # get timestamp of the frame
        current_time_ms = cap.get(cv2.CAP_PROP_POS_MSEC)
        current_time_sec = current_time_ms / 1000.0

        # make prediction
        results = model.predict(source=frame, imgsz=640, verbose=False)
        current_label = "Unknown"

        # update frame's label
        if results and hasattr(results[0], 'probs') and results[0].probs is not None:
            top1_index = results[0].probs.top1

            if top1_index == 0:
                label_counts["Crash"] += 1
                current_label = "Crash"
            elif top1_index == 1:
                label_counts["Flight"] += 1
                current_label = "Flight"
            elif top1_index == 2:
                label_counts["No drone"] += 1
                current_label = "No drone"
            elif top1_index == 3:
                label_counts["No signal"] += 1
                current_label = "No signal"
            elif top1_index == 4:
                label_counts["No started"] += 1
                current_label = "No started"
            elif top1_index == 5:
                label_counts["Started"] += 1
                current_label = "Started"
            elif top1_index == 6:
                label_counts["Unstable"] += 1
                current_label = "Unstable"
            elif top1_index == 7:
                label_counts["Landing"] += 1
                current_label = "Landing"
            else:
                label_counts["Unknown"] += 1

        if current_label == "Crash":
            crash_out.write(og_frame)
            if not is_crash:
                # start new crash event
                crash_start_time = current_time_sec
                is_crash = True
            non_crash_frame_count = 0
        else:
            if is_crash:
                # currently in crash event
                non_crash_frame_count += 1
                if non_crash_frame_count >= non_crash_frame_threshold:
                    # end crash event if non_crash is more than threshold
                    crash_end_time = current_time_sec
                    crash_duration = crash_end_time - crash_start_time
                    if crash_duration >= min_crash_duration:
                        crash_events.append((crash_start_time, crash_end_time))
                    is_crash = False
                    crash_start_time = None
                    non_crash_frame_count = 0
            else:
                # not in a crash event
                non_crash_frame_count = 0

        # write label to frame
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.6
        font_color = (255, 255, 255)
        thickness = 2
        position = (10, 30)
        text_size = cv2.getTextSize(current_label, font, font_scale, thickness)[0]
        text_x, text_y = position
        cv2.rectangle(og_frame, (text_x - 5, text_y - text_size[1] - 5), (text_x + text_size[0] + 5, text_y + 5), (0, 0, 0), -1)
        cv2.putText(og_frame, current_label, position, font, font_scale, font_color, thickness)
        labeled_out.write(og_frame)

        frame_out = cv2.cvtColor(og_frame, cv2.COLOR_BGR2RGB)
        progress_text = f"Processing frame {count + 1}/{total_frames}"
        yield {'type': 'frame', 'frame': frame_out, 'progress_text': progress_text}

        count += 1 # frame is over
    
    # video is over
    if is_crash:
        # ended on a crash frame
        crash_end_time = total_frames / fps
        crash_duration = crash_end_time - crash_start_time
        if crash_duration >= min_crash_duration:
            crash_events.append((crash_start_time, crash_end_time))

    cap.release()
    labeled_out.release()
    crash_out.release()
    cv2.destroyAllWindows()

    # merge events to only have unique crashes
    merged_crash_events = merge_crash_events(crash_events)

    yield {'type': 'results', 'label_counts': label_counts, 'crash_events': merged_crash_events}