File size: 6,556 Bytes
6eb9596
9453af9
 
 
 
 
6eb9596
 
 
 
 
 
 
 
9453af9
 
 
 
 
 
 
6eb9596
9453af9
 
 
 
 
 
 
 
 
 
 
 
6eb9596
9453af9
 
 
 
 
 
 
 
 
 
6eb9596
9453af9
6eb9596
9453af9
 
6eb9596
 
 
 
 
9453af9
 
 
6eb9596
 
9453af9
 
6eb9596
9453af9
6eb9596
9453af9
6eb9596
 
 
 
 
 
9453af9
 
 
 
 
 
6eb9596
 
 
9453af9
6eb9596
9453af9
 
 
6eb9596
 
9453af9
 
 
 
 
 
6eb9596
 
 
9453af9
6eb9596
9453af9
 
 
 
 
 
 
 
 
 
 
 
 
6eb9596
9453af9
 
6eb9596
 
9453af9
6eb9596
9453af9
 
 
 
6eb9596
9453af9
6eb9596
 
9453af9
 
 
 
 
 
 
 
6eb9596
9453af9
 
6eb9596
 
9453af9
 
6eb9596
9453af9
 
 
 
 
 
 
 
 
 
6eb9596
9453af9
 
 
 
 
 
 
6eb9596
 
 
 
 
9453af9
 
6eb9596
9453af9
 
 
6eb9596
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import os, cv2, json, time, tempfile, numpy as np, gradio as gr
from ultralytics import YOLO
from filterpy.kalman import KalmanFilter
from scipy.optimize import linear_sum_assignment

# ------------------------------------------------------------
# 🔧 Safe-load fix for PyTorch 2.6 (torch.load now defaults to weights_only=True)
# ------------------------------------------------------------
import torch
import ultralytics.nn.tasks as ultralytics_tasks
# Allow-list ultralytics' DetectionModel for torch's weights_only unpickler so
# that YOLO(MODEL_PATH) below can load the checkpoint under PyTorch >= 2.6.
# NOTE(review): presumably the .pt file pickles DetectionModel; other classes
# inside the checkpoint may still need allow-listing — confirm against the
# installed ultralytics version. add_safe_globals is absent on older torch
# releases, so this line would raise there — confirm the pinned torch version.
torch.serialization.add_safe_globals([ultralytics_tasks.DetectionModel])

# ------------------------------------------------------------
# ⚙️ Model + constants
# ------------------------------------------------------------
# Weights file for the detector; a bare filename, so it is resolved relative
# to the working directory (ultralytics may fetch it if missing — confirm).
MODEL_PATH = "yolov8n.pt"
# Loaded once at import time; the same instance runs detection on every frame
# of every processed video.
model = YOLO(MODEL_PATH)
# Detector class indices treated as vehicles; all other classes are ignored.
# NOTE(review): presumably COCO ids as used by the pretrained yolov8n head.
VEHICLE_CLASSES = [2, 3, 5, 7]  # car, motorbike, bus, truck


# ------------------------------------------------------------
# 🧩 Kalman Tracker
# ------------------------------------------------------------
class Track:
    """A single tracked vehicle backed by a constant-velocity Kalman filter.

    The filter state is ``[cx, cy, vx, vy]``: the bounding-box centre plus a
    per-frame velocity. Only the centre is observed. Every call to
    :meth:`update` appends the filtered centre to ``trace``.
    """

    def __init__(self, bbox, tid):
        self.id = tid
        kf = KalmanFilter(dim_x=4, dim_z=2)
        # Transition: position advances by velocity each frame (dt = 1).
        kf.F = np.array([[1, 0, 1, 0],
                         [0, 1, 0, 1],
                         [0, 0, 1, 0],
                         [0, 0, 0, 1]])
        # Measurement picks out the two position components only.
        kf.H = np.array([[1, 0, 0, 0],
                         [0, 1, 0, 0]])
        kf.P *= 1000.0  # inflate the initial state covariance
        kf.R *= 10.0    # scale up the measurement noise
        # Seed the position with the first box centre; velocity stays zero.
        kf.x[:2] = np.array(self.centroid(bbox)).reshape(2, 1)
        self.kf = kf
        self.trace = []

    def centroid(self, b):
        """Return the centre ``[cx, cy]`` of an ``(x1, y1, x2, y2)`` box."""
        left, top, right, bottom = b
        return [(left + right) / 2, (top + bottom) / 2]

    def _position(self):
        """Current filtered centre as a length-2 array."""
        return self.kf.x[:2].reshape(2)

    def predict(self):
        """Advance the filter by one frame and return the predicted centre."""
        self.kf.predict()
        return self._position()

    def update(self, b):
        """Correct the filter with box ``b``; record and return the new centre."""
        measurement = np.array(self.centroid(b)).reshape(2, 1)
        self.kf.update(measurement)
        px, py = self._position()
        self.trace.append((float(px), float(py)))
        return (px, py)


# ------------------------------------------------------------
# 🧮 Direction check
# ------------------------------------------------------------
def analyze_direction(trace, centers, wrong_threshold=0.0, lag=2):
    """Classify a track's recent motion against the known traffic-flow directions.

    The motion vector is the displacement from the point ``lag`` steps back in
    ``trace`` to the latest point. It is normalised and dotted with every row
    of ``centers``. When those rows are unit vectors (as ``load_flow_centers``
    produces), each dot product is a cosine similarity.

    Args:
        trace: Sequence of ``(x, y)`` points, oldest first.
        centers: Array-like of flow direction vectors, shape ``(k, 2)``.
        wrong_threshold: The track is ``"WRONG"`` when its best similarity is
            strictly below this value. The default ``0.0`` means "no flow
            direction within 90 degrees".
        lag: Number of trace steps the motion vector spans. Must be >= 1.

    Returns:
        ``(status, similarity)``. ``status`` is ``"OK"``, ``"WRONG"`` or
        ``"NA"``. ``"NA"`` means too few points, no measurable motion, or no
        flow centres; in that case the similarity is ``1.0``.

    Raises:
        ValueError: if ``lag`` is less than 1.
    """
    if lag < 1:
        raise ValueError("lag must be >= 1")
    if len(trace) < lag + 1:
        return "NA", 1.0
    motion = np.asarray(trace[-1], dtype=float) - np.asarray(trace[-1 - lag], dtype=float)
    speed = np.linalg.norm(motion)
    if speed < 1e-6:  # (near-)stationary: direction is undefined
        return "NA", 1.0
    flows = np.asarray(centers, dtype=float).reshape(-1, 2)
    if flows.shape[0] == 0:  # nothing to compare against; np.max would raise
        return "NA", 1.0
    best = float(np.max(flows @ (motion / speed)))
    status = "WRONG" if best < wrong_threshold else "OK"
    return status, best


# ------------------------------------------------------------
# 🚦 Main Processing
# ------------------------------------------------------------
def load_flow_centers(flow_json):
    """Load the Stage-2 flow directions from JSON and normalise them to unit length.

    Args:
        flow_json: Path to a JSON file whose ``"flow_centers"`` key holds a
            list of 2-component direction vectors.

    Returns:
        Float array of shape ``(k, 2)`` with each row divided by its norm.
        A small epsilon is added to the norm so an all-zero row stays zero
        instead of becoming NaN.

    Raises:
        KeyError: if ``"flow_centers"`` is missing from the file.
        ValueError: if the centres are empty or are not 2-component vectors.
    """
    # Context manager so the handle is closed even if parsing fails.
    with open(flow_json, encoding="utf-8") as fh:
        data = json.load(fh)
    centers = np.asarray(data["flow_centers"], dtype=float)
    if centers.ndim != 2 or centers.shape[0] == 0 or centers.shape[1] != 2:
        raise ValueError(
            "'flow_centers' must be a non-empty list of 2-component vectors, "
            f"got array of shape {centers.shape}"
        )
    norms = np.linalg.norm(centers, axis=1, keepdims=True)
    return centers / (norms + 1e-6)

# BGR drawing colours per direction status; any other status (e.g. "NA") is grey.
_STATUS_COLORS = {"OK": (0, 255, 0), "WRONG": (0, 0, 255)}
_NA_COLOR = (200, 200, 200)


def _detect_vehicles(frame, conf_thres):
    """Run ``model`` on one frame and return vehicle boxes as ``(x1, y1, x2, y2)`` arrays."""
    boxes = model(frame, verbose=False)[0].boxes
    if boxes is None or len(boxes) == 0:
        return []
    cls_ids = boxes.cls.cpu().numpy().astype(int)
    confs = boxes.conf.cpu().numpy()
    xyxy = boxes.xyxy.cpu().numpy()
    keep = np.isin(cls_ids, VEHICLE_CLASSES) & (confs > conf_thres)
    return list(xyxy[keep])


def _match_tracks(predicted, detections, max_dist):
    """Hungarian-match predicted track centres to detection centres.

    Returns ``(track_index, detection_index)`` pairs whose centre distance is
    strictly below ``max_dist``.
    """
    if len(predicted) == 0 or len(detections) == 0:
        return []
    boxes = np.asarray(detections, dtype=float)
    det_centres = np.column_stack(
        ((boxes[:, 0] + boxes[:, 2]) / 2, (boxes[:, 1] + boxes[:, 3]) / 2)
    )
    # Pairwise Euclidean distances, shape (n_tracks, n_detections).
    cost = np.linalg.norm(predicted[:, None, :] - det_centres[None, :, :], axis=2)
    rows, cols = linear_sum_assignment(cost)
    return [(int(i), int(j)) for i, j in zip(rows, cols) if cost[i, j] < max_dist]


def _draw_track(frame, trk, status):
    """Draw a track's latest position, ``ID:<id> <status>`` label and trajectory in place."""
    color = _STATUS_COLORS.get(status, _NA_COLOR)
    x, y = map(int, trk.trace[-1])
    cv2.circle(frame, (x, y), 4, color, -1)
    cv2.putText(frame, f"ID:{trk.id} {status}", (x - 20, y - 10),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
    # Truncate with int() exactly as the per-segment drawing did.
    pts = np.array([(int(px), int(py)) for px, py in trk.trace], dtype=np.int32)
    cv2.polylines(frame, [pts.reshape(-1, 1, 2)], False, color, 1)


def process_video(video_path, flow_json, conf_thres=0.3, max_dist=80.0, max_missed=30):
    """Detect, track and direction-check vehicles in a video.

    Each frame goes through the YOLO ``model``. Vehicle boxes (``VEHICLE_CLASSES``
    with confidence above ``conf_thres``) are associated with existing Kalman
    tracks by Hungarian matching on centre distance. Tracks with at least
    three filtered points are classified with ``analyze_direction`` against the
    flow centres, drawn onto the frame and logged.

    Args:
        video_path: Path to the input video.
        flow_json: Path to the Stage-2 ``flow_stats.json``.
        conf_thres: Minimum detection confidence (exclusive).
        max_dist: Largest centre distance in pixels (exclusive) that still
            counts as a track/detection match.
        max_missed: A track unmatched for more than this many consecutive
            frames is dropped. Pass ``float("inf")`` to keep tracks forever.

    Returns:
        ``(out_path, log_path)``. ``out_path`` is the annotated MP4.
        ``log_path`` is a JSON list of per-frame, per-track records
        ``{"id", "status", "cos_sim", "frame"}``.

    Raises:
        ValueError: if the video cannot be opened or the writer cannot be created.
    """
    centers = load_flow_centers(flow_json)
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Could not open video: {video_path}")
    fps = cap.get(cv2.CAP_PROP_FPS) or 25
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # mkstemp + close avoids leaking the handle of a discarded NamedTemporaryFile.
    fd, out_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)
    out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))

    tracks, next_id, log = [], 0, []
    missed = {}  # track id -> consecutive frames without a matched detection
    try:
        if not out.isOpened():
            raise ValueError(f"Could not create output video: {out_path}")
        frame_idx = 0
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            detections = _detect_vehicles(frame, conf_thres)

            # Predict every live track one frame ahead, then associate.
            predicted = np.array([t.predict() for t in tracks], dtype=float).reshape(-1, 2)
            matched_tracks, matched_dets = set(), set()
            for i, j in _match_tracks(predicted, detections, max_dist):
                tracks[i].update(detections[j])
                matched_tracks.add(i)
                matched_dets.add(j)

            # Retire tracks unseen for too long. Otherwise they linger as
            # ghosts, keep getting logged, and can steal new detections.
            live = []
            for i, trk in enumerate(tracks):
                missed[trk.id] = 0 if i in matched_tracks else missed.get(trk.id, 0) + 1
                if missed[trk.id] <= max_missed:
                    live.append(trk)
                else:
                    del missed[trk.id]
            tracks = live

            # Unmatched detections start new tracks.
            for j, det in enumerate(detections):
                if j not in matched_dets:
                    trk = Track(det, next_id)
                    next_id += 1
                    trk.update(det)
                    tracks.append(trk)
                    missed[trk.id] = 0

            for trk in tracks:
                if len(trk.trace) < 3:
                    continue
                status, sim = analyze_direction(trk.trace, centers)
                _draw_track(frame, trk, status)
                log.append({"id": trk.id, "status": status,
                            "cos_sim": round(sim, 3), "frame": frame_idx})

            out.write(frame)
            frame_idx += 1
    finally:
        # Release native handles even if detection or tracking raises.
        cap.release()
        out.release()

    fd, log_path = tempfile.mkstemp(suffix=".json")
    with os.fdopen(fd, "w") as f:
        json.dump(log, f, indent=2)
    return out_path, log_path


# ------------------------------------------------------------
# 🖥️ Gradio UI
# ------------------------------------------------------------
def run_app(video, flow_file):
    """Gradio callback: run the pipeline and build a per-run summary.

    Args:
        video: Filepath of the uploaded video from ``gr.Video``.
        flow_file: The uploaded ``flow_stats.json``. This is a filepath, or on
            older Gradio versions a tempfile-like object exposing ``.name``.

    Returns:
        ``(out_path, log_data, summary)`` for the three output components.
        ``summary["vehicles_analyzed"]`` counts distinct track ids, not log rows.

    Raises:
        gr.Error: if either upload is missing.
    """
    if not video or flow_file is None:
        raise gr.Error("Please upload both a traffic video and flow_stats.json.")
    if not isinstance(flow_file, (str, os.PathLike)):
        flow_file = getattr(flow_file, "name", flow_file)
    out_path, log_path = process_video(video, flow_file)
    with open(log_path, encoding="utf-8") as fh:
        log_data = json.load(fh)
    # The log holds one row per track per frame, so de-duplicate by id.
    vehicle_ids = {entry["id"] for entry in log_data}
    wrong_ids = {entry["id"] for entry in log_data if entry["status"] == "WRONG"}
    summary = {
        "vehicles_analyzed": len(vehicle_ids),
        "wrong_direction_vehicles": len(wrong_ids),
        "log_entries": len(log_data),
    }
    return out_path, log_data, summary

# Markdown rendered under the app title (the trailing double space forces a line break).
description_text = """
### 🚦 Wrong-Direction Detection (Stage 3)
Upload your traffic video and the **flow_stats.json** produced by Stage 2.  
Outputs an annotated video with ✅ OK / 🚫 WRONG labels and a JSON log.
"""

# Optional bundled example inputs; offered only when both files exist locally.
example_vid = "10.mp4" if os.path.exists("10.mp4") else None
example_flow = "flow_stats.json" if os.path.exists("flow_stats.json") else None
examples = [[example_vid, example_flow]] if example_vid and example_flow else None

demo = gr.Interface(
    fn=run_app,
    inputs=[gr.Video(label="Upload Traffic Video (.mp4)"),
            gr.File(label="Upload flow_stats.json (Stage 2 Output)")],
    outputs=[gr.Video(label="Violation Output Video"),
             gr.JSON(label="Per-Vehicle Log"),
             gr.JSON(label="Summary")],
    title="🚗 Wrong-Direction Detection – Stage 3",
    description=description_text,
    examples=examples,
    # Show examples but never pre-compute them at startup. Caching would run
    # the full YOLO pipeline over the example video before the app serves.
    cache_examples=False,
)

if __name__ == "__main__":
    os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
    # `demo` was constructed above, and its analytics flag is resolved at
    # construction time from the env var (per Gradio's `analytics_enabled`
    # parameter docs). Setting the env var here is therefore too late on its
    # own, so disable analytics on the instance as well.
    demo.analytics_enabled = False
    demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False, show_api=False)