asrcoddeploy commited on
Commit
4128971
·
verified ·
1 Parent(s): 574d07a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -0
app.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from torchvision import transforms
7
+ import os
8
+
9
+ # --- 1. MODEL ARCHITECTURE ---
10
+ class LDobjModel(nn.Module):
11
+ def __init__(self):
12
+ super(LDobjModel, self).__init__()
13
+ self.enc1 = self.conv_block(3, 16); self.pool1 = nn.MaxPool2d(2)
14
+ self.enc2 = self.conv_block(16, 32); self.pool2 = nn.MaxPool2d(2)
15
+ self.bottleneck = self.conv_block(32, 64)
16
+ self.up1 = nn.ConvTranspose2d(64, 32, 2, 2)
17
+ self.dec1 = self.conv_block(64, 32)
18
+ self.up2 = nn.ConvTranspose2d(32, 16, 2, 2)
19
+ self.dec2 = self.conv_block(32, 16)
20
+ self.final = nn.Sequential(nn.Conv2d(16, 1, 1), nn.Sigmoid())
21
+
22
+ def conv_block(self, in_c, out_c):
23
+ return nn.Sequential(nn.Conv2d(in_c, out_c, 3, 1, 1), nn.ReLU(),
24
+ nn.Conv2d(out_c, out_c, 3, 1, 1), nn.ReLU())
25
+
26
+ def forward(self, x):
27
+ e1 = self.enc1(x); e2 = self.enc2(self.pool1(e1))
28
+ b = self.bottleneck(self.pool2(e2))
29
+ d1 = torch.cat((e2, self.up1(b)), dim=1); d1 = self.dec1(d1)
30
+ d2 = torch.cat((e1, self.up2(d1)), dim=1); d2 = self.dec2(d2)
31
+ return self.final(d2)
32
+
33
+ # --- 2. LOAD AI ON STARTUP ---
34
+ device = torch.device('cpu') # Hugging Face Free Tier uses CPU
35
+ model = LDobjModel().to(device)
36
+ # Load weights (Make sure the filename matches exactly what you uploaded)
37
+ model.load_state_dict(torch.load('LDobj_weights.pth', map_location=device))
38
+ model.eval()
39
+
40
+ transform = transforms.Compose([
41
+ transforms.ToPILImage(),
42
+ transforms.Resize((288, 800)),
43
+ transforms.ToTensor()
44
+ ])
45
+
46
+ # --- 3. VIDEO PROCESSING LOGIC ---
47
+ def analyze_video(input_video_path):
48
+ if input_video_path is None:
49
+ return None
50
+
51
+ cap = cv2.VideoCapture(input_video_path)
52
+
53
+ # Get video specs
54
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
55
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
56
+ fps = cap.get(cv2.CAP_PROP_FPS)
57
+
58
+ # Setup output writer
59
+ raw_output = "raw_output.mp4"
60
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
61
+ out = cv2.VideoWriter(raw_output, fourcc, fps, (width, height))
62
+
63
+ while cap.isOpened():
64
+ ret, frame = cap.read()
65
+ if not ret: break
66
+
67
+ # Pre-process frame
68
+ input_img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
69
+ img_tensor = transform(input_img).unsqueeze(0).to(device)
70
+
71
+ # AI Prediction
72
+ with torch.no_grad():
73
+ pred = model(img_tensor).squeeze().numpy()
74
+
75
+ # Binary Mask
76
+ mask = (pred > 0.5).astype(np.uint8)
77
+ mask_full = cv2.resize(mask, (width, height))
78
+
79
+ # Departure Logic
80
+ moments = cv2.moments(mask_full[int(height*0.8):, :])
81
+ alert_triggered = False
82
+
83
+ if moments["m00"] > 0:
84
+ lane_center_x = int(moments["m10"] / moments["m00"])
85
+ car_center_x = width // 2
86
+
87
+ # If car drifts > 10% of screen width
88
+ if abs(lane_center_x - car_center_x) > (width * 0.1):
89
+ alert_triggered = True
90
+
91
+ # ONLY MODIFY FRAME IF ALERT IS HAPPENING
92
+ if alert_triggered:
93
+ status_color = (0, 0, 255) # Red BGR
94
+ overlay = frame.copy()
95
+ overlay[mask_full > 0] = status_color
96
+
97
+ # Add UI Text
98
+ cv2.putText(frame, "WARNING: LANE DEPARTURE!", (width//10, 100),
99
+ cv2.FONT_HERSHEY_SIMPLEX, 1.5, status_color, 4)
100
+
101
+ # Blend frame with red lanes
102
+ final_frame = cv2.addWeighted(frame, 0.7, overlay, 0.3, 0)
103
+ out.write(final_frame)
104
+ else:
105
+ # Normal driving: return the clean, untouched dashcam footage
106
+ out.write(frame)
107
+
108
+ cap.release()
109
+ out.write(frame)
110
+ out.release()
111
+
112
+ # Convert to standard H264 for web browsers (Gradio requires this)
113
+ web_output = "final_output.mp4"
114
+ os.system(f"ffmpeg -y -i {raw_output} -vcodec libx264 {web_output}")
115
+
116
+ return web_output
117
+
118
+ # --- 4. GRADIO WEB INTERFACE ---
119
+ with gr.Blocks(theme=gr.themes.Monochrome()) as app:
120
+ gr.Markdown("# 🚗 LDobj: AI Lane Departure Alert System")
121
+ gr.Markdown("Upload a dashcam video. The AI will analyze the footage and **only overlay an alert** during actual lane departures.")
122
+
123
+ with gr.Row():
124
+ with gr.Column():
125
+ video_input = gr.Video(label="Upload Dashcam Video (.mp4)")
126
+ submit_btn = gr.Button("Analyze Video", variant="primary")
127
+
128
+ with gr.Column():
129
+ video_output = gr.Video(label="AI Analyzed Output")
130
+
131
+ submit_btn.click(fn=analyze_video, inputs=video_input, outputs=video_output)
132
+
133
+ app.launch()