# Spaces: Running
# File size: 7,964 Bytes
# 996d929 -- identifier carried over from the page header (presumably the revision id)
import os
import subprocess
import time

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
# --- 1. MODEL ARCHITECTURE ---
class LDobjModel(nn.Module):
    """Small U-Net-style encoder/decoder producing a one-channel sigmoid map.

    Two encoder stages (3->16->32 channels, each followed by 2x2 max pooling),
    a 64-channel bottleneck, and two decoder stages that upsample with
    transposed convolutions and concatenate the matching encoder output
    (skip connection) before convolving again.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept stable so that saved
        # state dicts (keyed by these names) continue to load.
        self.enc1 = self.conv_block(3, 16)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = self.conv_block(16, 32)
        self.pool2 = nn.MaxPool2d(2)
        self.bottleneck = self.conv_block(32, 64)
        self.up1 = nn.ConvTranspose2d(64, 32, 2, 2)
        self.dec1 = self.conv_block(64, 32)
        self.up2 = nn.ConvTranspose2d(32, 16, 2, 2)
        self.dec2 = self.conv_block(32, 16)
        self.final = nn.Sequential(nn.Conv2d(16, 1, 1), nn.Sigmoid())

    def conv_block(self, in_c, out_c):
        """Return two 3x3 same-padding convolutions, each followed by ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_c, out_c, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(out_c, out_c, 3, 1, 1),
            nn.ReLU(),
        )

    @staticmethod
    def _pad_to_match(upsampled, skip):
        """Zero-pad ``upsampled`` so its last two dims equal those of ``skip``.

        Max pooling floors odd sizes, so a 2x upsample can come back one
        pixel short of the encoder feature map; without padding the channel
        concatenation would raise a size-mismatch error.
        """
        dh = skip.shape[-2] - upsampled.shape[-2]
        dw = skip.shape[-1] - upsampled.shape[-1]
        if dh == 0 and dw == 0:
            return upsampled
        # F.pad order is (left, right, top, bottom) for the last two dims.
        return nn.functional.pad(
            upsampled, (dw // 2, dw - dw // 2, dh // 2, dh - dh // 2)
        )

    def forward(self, x):
        """Run the network on a batch of 3-channel images.

        Args:
            x: Tensor of shape (N, 3, H, W).

        Returns:
            Tensor of shape (N, 1, H, W) with values in [0, 1] (sigmoid).
            H and W no longer have to be multiples of 4; for multiples of 4
            the result is identical to the un-padded computation.
        """
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        b = self.bottleneck(self.pool2(e2))

        u1 = self._pad_to_match(self.up1(b), e2)
        d1 = self.dec1(torch.cat((e2, u1), dim=1))

        u2 = self._pad_to_match(self.up2(d1), e1)
        d2 = self.dec2(torch.cat((e1, u2), dim=1))
        return self.final(d2)
# --- 2. INITIALIZATION ---
# Inference is pinned to the CPU; checkpoint tensors are mapped onto it.
device = torch.device('cpu')
model = LDobjModel().to(device)

# Load trained weights only when the checkpoint file is present; otherwise
# the network keeps its freshly initialised parameters.
if os.path.exists('LDobj_weights.pth'):
    model.load_state_dict(
        torch.load('LDobj_weights.pth', map_location=device)
    )
model.eval()

# Frame preprocessing: array -> PIL image -> 288x800 resize -> tensor.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((288, 800)),
        transforms.ToTensor(),
    ]
)
# --- 3. ROBUST PROCESSING LOGIC (Temporal Smoothing) ---
def _lane_centroid_drifts(mask_full, width, height, drift_threshold):
    """Return True when the mask centroid is farther than the threshold from centre.

    Only the bottom quarter of the frame (rows from ``0.75 * height`` down) is
    examined; an empty mask there counts as "no drift".
    """
    moments = cv2.moments(mask_full[int(height * 0.75):, :])
    if moments["m00"] <= 0:
        return False
    cx = int(moments["m10"] / moments["m00"])
    return abs(cx - width // 2) > drift_threshold


def _draw_alert(frame, mask_full, width, height):
    """Return ``frame`` with the red mask tint, warning banner and border drawn."""
    overlay = frame.copy()
    overlay[mask_full > 0] = (0, 0, 255)
    frame = cv2.addWeighted(frame, 0.7, overlay, 0.3, 0)
    # Black banner with the warning text on top.
    cv2.rectangle(frame, (0, 0), (width, 120), (0, 0, 0), -1)
    cv2.putText(frame, "CRITICAL: SUSTAINED DEPARTURE", (30, 80),
                cv2.FONT_HERSHEY_DUPLEX, 1.5, (0, 0, 255), 3)
    # Red border around the whole frame.
    cv2.rectangle(frame, (0, 0), (width, height), (0, 0, 255), 10)
    return frame


def _transcode_for_web(src, dst):
    """Re-encode ``src`` to H.264 at ``dst`` with ffmpeg; return True on success.

    The command is passed as an argument list (no shell), and a missing
    ffmpeg binary or a non-zero exit status is reported as failure instead of
    being ignored.
    """
    cmd = [
        "ffmpeg", "-y", "-i", src,
        "-c:v", "libx264", "-preset", "fast",
        "-pix_fmt", "yuv420p", "-movflags", "+faststart",
        dst,
    ]
    try:
        result = subprocess.run(cmd, check=False)
    except OSError:
        return False
    return result.returncode == 0 and os.path.exists(dst)


def analyze_video(input_video_path, sensitivity, required_frames, progress=gr.Progress()):
    """Run lane segmentation on a video and mark sustained lane departures.

    Each frame is segmented by the module-level ``model``; the centroid of the
    mask in the bottom quarter of the frame is compared with the frame centre.
    A drift counter goes up by one on every drifting frame and down by two
    otherwise (never below zero). The alert switches on once the counter
    reaches ``required_frames`` and switches off only when it returns to zero.

    Args:
        input_video_path: Path of the video to analyse; falsy means "nothing
            uploaded".
        sensitivity: Allowed centroid offset as a percentage of frame width.
        required_frames: Counter value at which the alert is raised.
        progress: Gradio progress tracker (the ``gr.Progress()`` default is
            kept as in the original signature).

    Returns:
        ``(video_path, report)``. ``video_path`` is ``None`` when there is no
        input or the video cannot be opened/written; it is the un-transcoded
        file when ffmpeg fails, so a stale result from an earlier run is never
        returned.
    """
    if not input_video_path:
        return None, "⚠️ Please upload a video first."

    start_time = time.time()
    raw_output = "temp_raw.mp4"
    frame_count = 0
    alerts_triggered = 0

    cap = cv2.VideoCapture(input_video_path)
    out = None
    try:
        if not cap.isOpened():
            return None, "⚠️ Could not open the uploaded video."

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # Missing/invalid FPS metadata (0 or NaN): fall back to a nominal
            # rate so the writer can still be opened.
            fps = 30.0
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(raw_output, fourcc, fps, (width, height))
        if not out.isOpened():
            return None, "⚠️ Could not create the output video."

        morph_kernel = np.ones((5, 5), np.uint8)
        drift_threshold = width * (sensitivity / 100.0)

        # Temporal smoothing state.
        consecutive_drift_frames = 0
        is_currently_alerting = False

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1

            if frame_count % 5 == 0:
                if total_frames > 0:
                    # The reported frame count can be too low; clamp to 1.0.
                    progress(min(frame_count / total_frames, 1.0),
                             desc=f"Analyzing Frame {frame_count}/{total_frames}")
                else:
                    # Total unknown: report steps completed without a total.
                    progress((frame_count, None),
                             desc=f"Analyzing Frame {frame_count}")

            # AI prediction
            input_img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            img_tensor = transform(input_img).unsqueeze(0).to(device)
            with torch.no_grad():
                pred = model(img_tensor).squeeze().numpy()

            # Mask cleaning: threshold, scale back to frame size, open.
            mask = (pred > 0.5).astype(np.uint8)
            mask_full = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
            mask_full = cv2.morphologyEx(mask_full, cv2.MORPH_OPEN, morph_kernel)

            # Departure must be sustained to trigger; recentering cools the
            # counter down twice as fast as drifting heats it up.
            if _lane_centroid_drifts(mask_full, width, height, drift_threshold):
                consecutive_drift_frames += 1
            else:
                consecutive_drift_frames = max(0, consecutive_drift_frames - 2)

            if consecutive_drift_frames >= required_frames:
                is_currently_alerting = True
            elif consecutive_drift_frames == 0:
                is_currently_alerting = False

            if is_currently_alerting:
                alerts_triggered += 1
                frame = _draw_alert(frame, mask_full, width, height)

            out.write(frame)
    finally:
        # Release capture and writer even if a frame fails mid-loop.
        cap.release()
        if out is not None:
            out.release()

    progress(0.95, desc="Optimizing Video for Web...")
    web_output = "ldobj_final.mp4"
    transcoded = _transcode_for_web(raw_output, web_output)
    result_path = web_output if transcoded else raw_output

    process_time = time.time() - start_time
    avg_fps = frame_count / process_time if process_time > 0 else 0

    telemetry_report = (
        f"✅ Analysis Complete\n"
        f"------------------------\n"
        f"⏱️ Processing Time: {process_time:.1f} sec\n"
        f"🚀 AI Speed: {avg_fps:.1f} FPS\n"
        f"🚨 Critical Alert Frames: {alerts_triggered}"
    )
    if not transcoded:
        telemetry_report += "\n⚠️ ffmpeg transcoding failed; returning the unoptimized video."

    return result_path, telemetry_report
# --- 4. ULTIMATE FRONTEND DESIGN ---
# Stylesheet for the front end: sizes and frames the two video panes
# (#video-in / #video-out), caps the container width, and styles the
# glow-title / sub-title headings.
custom_css = (
    "\n"
    "#video-in, #video-out { min-height: 450px; border-radius: 10px; border: 1px solid #333; }\n"
    ".gradio-container { max-width: 1200px !important; margin: auto; }\n"
    ".glow-title { color: #ff4a4a; text-shadow: 0px 0px 15px rgba(255, 74, 74, 0.5); text-align: center; margin-bottom: 5px; }\n"
    ".sub-title { text-align: center; color: #888; margin-top: 0px; margin-bottom: 30px; }\n"
)
# Page layout: title banner, then an input/controls column beside an
# output/telemetry column, all wired to analyze_video by the scan button.
with gr.Blocks() as app:
    gr.HTML("<h1 class='glow-title'>🛡️ LDobj ADAS Command Center</h1>")
    gr.HTML("<h3 class='sub-title'>Advanced Driver Assistance System • Neural Lane Tracking</h3>")

    with gr.Group():
        with gr.Row():
            # Left column: video upload and alert tuning controls.
            with gr.Column(scale=4):
                gr.Markdown("### 1. Input Source")
                video_in = gr.Video(
                    label="Dashcam Feed",
                    elem_id="video-in",
                )

                gr.Markdown("### 2. Serious Alert Parameters")
                sensitivity_slider = gr.Slider(
                    minimum=5,
                    maximum=30,
                    value=12,
                    step=1,
                    label="Drift Distance Threshold (%)",
                    info="How far off-center the car must be before it's considered drifting.",
                )
                frames_slider = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=7,
                    step=1,
                    label="Sustained Drift Timer (Frames)",
                    info="How many consecutive frames the car must be drifting before triggering the CRITICAL alert (prevents glitchy flashing).",
                )
                run_btn = gr.Button(
                    "INITIALIZE SCAN",
                    variant="primary",
                    size="lg",
                )

            # Right column: processed video and the text report.
            with gr.Column(scale=5):
                gr.Markdown("### Live Output Feed")
                video_out = gr.Video(
                    label="LDobj Processed Feed",
                    interactive=False,
                    autoplay=True,
                    elem_id="video-out",
                )

                gr.Markdown("### System Telemetry")
                telemetry_out = gr.Textbox(
                    label="Analytics Console",
                    lines=6,
                    interactive=False,
                )

    # Button click runs the analysis with the video and both slider values.
    run_btn.click(
        fn=analyze_video,
        inputs=[video_in, sensitivity_slider, frames_slider],
        outputs=[video_out, telemetry_out],
    )
if __name__ == "__main__":
    # Start the app only when run as a script. Theme and stylesheet are
    # supplied at launch time; the empty footer_links list presumably hides
    # the default footer links -- confirm against the installed Gradio version.
    app.launch(
        theme=gr.themes.Glass(primary_hue="red"),
        css=custom_css,
        footer_links=[],
    )