| import os |
| import cv2 |
| from ultralytics import YOLO |
| import gradio as gr |
|
|
def get_model(path):
    """Construct an Ultralytics ``YOLO`` model from ``path``.

    Args:
        path: Model location, passed unchanged to the ``YOLO`` constructor.

    Returns:
        The newly constructed ``YOLO`` instance.
    """
    loaded_model = YOLO(path)
    return loaded_model
|
|
def format_time(seconds):
    """Format a duration in seconds as ``M:SS.ss``.

    The value is rounded to hundredths of a second *before* it is split into
    minutes and seconds. Rounding only at format time could turn e.g. 59.999
    into the invalid "0:60.00"; this version yields "1:00.00" instead.

    Args:
        seconds: Duration in seconds (int or float); intended for
            non-negative values such as video timestamps.

    Returns:
        Unpadded minutes, a colon, then seconds zero-padded to two integer
        digits with two decimals, e.g. ``format_time(125.5) == "2:05.50"``.
    """
    minutes, secs = divmod(round(seconds, 2), 60)
    return f"{int(minutes)}:{secs:05.2f}"
|
|
def merge_crash_events(crash_events, max_gap=5.0):
    """Combine crash intervals that overlap or are separated by a short gap.

    Args:
        crash_events: Iterable of ``(start, end)`` pairs in seconds. The pairs
            need not be sorted; they are ordered by start time before merging.
            An empty or ``None`` input yields an empty list.
        max_gap: Largest gap in seconds between one interval's end and the
            next interval's start for the two to be merged. Defaults to 5.0.

    Returns:
        A new list of ``(start, end)`` tuples ordered by start time. The input
        is not modified.
    """
    if not crash_events:
        return []

    merged = []
    # Sorting makes the single left-to-right sweep correct for any input order.
    for start, end in sorted(crash_events):
        if merged and start - merged[-1][1] <= max_gap:
            last_start, last_end = merged[-1]
            # max() keeps the longer end when one interval contains the next.
            merged[-1] = (last_start, max(last_end, end))
        else:
            merged.append((start, end))
    return merged
|
|
# Class names indexed by the classifier's top-1 output index; any index
# outside this range is reported as "Unknown".
_CLASS_LABELS = (
    "Crash",
    "Flight",
    "No drone",
    "No signal",
    "No started",
    "Started",
    "Unstable",
    "Landing",
)


def _top1_label(results):
    """Return the class name of the top-1 prediction in ``results``, or "Unknown"."""
    if not results:
        return "Unknown"
    probs = getattr(results[0], "probs", None)
    if probs is None:
        return "Unknown"
    index = probs.top1
    if 0 <= index < len(_CLASS_LABELS):
        return _CLASS_LABELS[index]
    return "Unknown"


def _draw_label(frame, text):
    """Draw ``text`` in white over a filled black box near the top-left of ``frame`` (in place)."""
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    thickness = 2
    x, y = 10, 30
    (text_w, text_h), _baseline = cv2.getTextSize(text, font, font_scale, thickness)
    cv2.rectangle(frame, (x - 5, y - text_h - 5), (x + text_w + 5, y + 5), (0, 0, 0), -1)
    cv2.putText(frame, text, (x, y), font, font_scale, (255, 255, 255), thickness)


def video_classification(video_path, label_vid_output, crash_vid_output, model,
                         min_crash_duration=2.0, grace_period=1.0):
    """Classify every frame of a video and write labelled and crash-only copies.

    This is a generator. For each decoded frame it yields
    ``{'type': 'frame', 'frame': <RGB image>, 'progress_text': str}``. After the
    last frame it yields a single
    ``{'type': 'results', 'label_counts': dict, 'crash_events': list}``, where
    ``crash_events`` holds ``(start_sec, end_sec)`` pairs merged by
    ``merge_crash_events``.

    The capture and both writers are released even if the consumer stops
    iterating early or the model raises. This matters because an unreleased
    ``VideoWriter`` leaves an unfinalised output file.

    Args:
        video_path: Source accepted by ``cv2.VideoCapture``.
        label_vid_output: Output path for the video with the predicted label
            drawn on each frame (``mp4v`` codec, source FPS and size).
        crash_vid_output: Output path for a video containing only the frames
            predicted as "Crash", written without the label overlay.
        model: Object with ``predict(source=..., imgsz=..., verbose=...)`` whose
            first result exposes ``probs.top1``. This is presumably an
            Ultralytics classification model; confirm with the caller.
        min_crash_duration: Minimum length in seconds, measured from the first
            to the last crash frame of a run, for the run to be reported.
        grace_period: Seconds of consecutive non-crash frames (converted to a
            frame count with the source FPS) needed to end a crash run.
            Shorter interruptions are absorbed into the surrounding run.

    Raises:
        ValueError: If ``video_path`` cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Could not open video: {video_path}")

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    labeled_out = cv2.VideoWriter(label_vid_output, fourcc, fps, (width, height))
    crash_out = cv2.VideoWriter(crash_vid_output, fourcc, fps, (width, height))

    # Same key order as before: the eight classes, then "Unknown".
    label_counts = dict.fromkeys(_CLASS_LABELS, 0)
    label_counts["Unknown"] = 0

    crash_events = []
    crash_start_time = None  # timestamp of the first frame of the open crash run; None if no run
    last_crash_time = None   # timestamp of the most recent crash frame in the open run
    # At least 1 so a zero/missing FPS cannot end a run on the very first non-crash frame check.
    non_crash_frame_threshold = max(1, int(fps * grace_period))
    non_crash_frame_count = 0

    count = 0
    try:
        while True:
            ret, og_frame = cap.read()
            if not ret:
                break

            progress_text = f"Processing frame {count + 1}/{total_frames}"
            print(f"\r{progress_text}", end='', flush=True)

            # Timestamp reported by the capture backend after this read.
            current_time_sec = cap.get(cv2.CAP_PROP_POS_MSEC) / 1000.0

            frame = cv2.resize(og_frame, (320, 320))
            results = model.predict(source=frame, imgsz=640, verbose=False)
            current_label = _top1_label(results)
            # Count every frame, including ones with no usable prediction.
            label_counts[current_label] += 1

            if current_label == "Crash":
                # Written before the overlay is drawn, so the crash video stays unannotated.
                crash_out.write(og_frame)
                if crash_start_time is None:
                    crash_start_time = current_time_sec
                last_crash_time = current_time_sec
                non_crash_frame_count = 0
            elif crash_start_time is not None:
                non_crash_frame_count += 1
                if non_crash_frame_count >= non_crash_frame_threshold:
                    # The run ends at its last crash frame, not where the grace period expired.
                    if last_crash_time - crash_start_time >= min_crash_duration:
                        crash_events.append((crash_start_time, last_crash_time))
                    crash_start_time = None
                    last_crash_time = None
                    non_crash_frame_count = 0

            _draw_label(og_frame, current_label)
            labeled_out.write(og_frame)

            frame_out = cv2.cvtColor(og_frame, cv2.COLOR_BGR2RGB)
            yield {'type': 'frame', 'frame': frame_out, 'progress_text': progress_text}

            count += 1
    finally:
        # Runs on normal completion, on exceptions, and on generator.close().
        cap.release()
        labeled_out.release()
        crash_out.release()

    print()  # terminate the carriage-return progress line

    # Close a crash run that was still open when the video ended.
    if crash_start_time is not None and last_crash_time - crash_start_time >= min_crash_duration:
        crash_events.append((crash_start_time, last_crash_time))

    merged_crash_events = merge_crash_events(crash_events)
    yield {'type': 'results', 'label_counts': label_counts, 'crash_events': merged_crash_events}