Abubakar740 committed on
Commit
1f6e06f
·
1 Parent(s): 27e651e

update app

Browse files
Files changed (1) hide show
  1. main.py +84 -54
main.py CHANGED
@@ -4,6 +4,7 @@ import cv2
4
  import torch
5
  import torch.nn as nn
6
  import numpy as np
 
7
  from collections import deque
8
  from fastapi import FastAPI, UploadFile, File, BackgroundTasks, HTTPException
9
  from fastapi.responses import FileResponse, RedirectResponse
@@ -21,11 +22,12 @@ CLIP_LEN = 32
21
  IMG_SIZE = 224
22
  THEFT_THRESHOLD = 0.6
23
 
24
- # Ensure directories exist
25
  os.makedirs(UPLOAD_DIR, exist_ok=True)
26
  os.makedirs(OUTPUT_DIR, exist_ok=True)
27
 
28
  # In-memory job store
 
29
  jobs = {}
30
 
31
  app = FastAPI(title="AI Theft Detection System")
@@ -45,9 +47,9 @@ if os.path.exists(MODEL_PATH):
45
  ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
46
  state_dict = ckpt["model"] if "model" in ckpt else ckpt
47
  slowfast_model.load_state_dict(state_dict)
48
- print("SlowFast weights loaded.")
49
  else:
50
- print(f"Warning: {MODEL_PATH} not found. Running with unitialized weights.")
51
 
52
  slowfast_model = slowfast_model.to(DEVICE).eval()
53
 
@@ -73,46 +75,48 @@ def draw_corner_rect(img, pt1, pt2, color, thickness, r, d):
73
  cv2.line(img, (x2, y2 - r), (x2, y2 - r - d), color, thickness)
74
  cv2.ellipse(img, (x2 - r, y2 - r), (r, r), 0, 0, 90, color, thickness)
75
 
76
- def draw_fancy_overlay(frame, avg_prob, theft_flag, frame_counter):
77
- h, w, _ = frame.shape
78
-
79
- # 1. Semi-transparent Header bar
 
 
 
 
80
  overlay = frame.copy()
81
- cv2.rectangle(overlay, (0, 0), (w, 80), (30, 30, 30), -1)
82
- cv2.addWeighted(overlay, 0.7, frame, 0.3, 0, frame)
83
-
84
- # 2. Scanning Dot (Pulsing)
85
- color_status = (0, 255, 0) if not theft_flag else (0, 0, 255)
86
- dot_alpha = (np.sin(frame_counter / 4) + 1) / 2
87
- if dot_alpha > 0.4:
88
- cv2.circle(frame, (40, 40), 10, color_status, -1)
89
- cv2.putText(frame, "AI SURVEILLANCE LIVE", (70, 48), cv2.FONT_HERSHEY_DUPLEX, 0.7, (255, 255, 255), 1)
90
-
91
- # 3. Confidence Meter
92
- bar_x, bar_y, bar_w, bar_h = w - 350, 30, 300, 25
93
- cv2.rectangle(frame, (bar_x, bar_y), (bar_x + bar_w, bar_y + bar_h), (60, 60, 60), -1)
94
- fill_w = int(bar_w * avg_prob)
95
- # Color transitions: Green -> Orange -> Red
96
- bar_color = (0, 255, 0) if avg_prob < 0.4 else (0, 165, 255) if avg_prob < THEFT_THRESHOLD else (0, 0, 255)
97
- cv2.rectangle(frame, (bar_x, bar_y), (bar_x + fill_w, bar_y + bar_h), bar_color, -1)
98
- cv2.putText(frame, f"Risk Score: {int(avg_prob*100)}%", (bar_x, bar_y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
99
-
100
- # 4. Theft Alert Banner
101
- if theft_flag:
102
- alert_overlay = frame.copy()
103
- cv2.rectangle(alert_overlay, (0, h//2 - 60), (w, h//2 + 60), (0, 0, 200), -1)
104
- cv2.addWeighted(alert_overlay, 0.5, frame, 0.5, 0, frame)
105
- cv2.putText(frame, "CRITICAL ALERT: THEFT DETECTED", (w//2 - 280, h//2 + 15),
106
- cv2.FONT_HERSHEY_TRIPLEX, 1.2, (255, 255, 255), 2)
 
 
107
 
108
  # --- PROCESSING LOGIC ---
109
 
110
  def preprocess(frames):
111
- processed = []
112
- for frame in frames:
113
- frame = cv2.resize(frame, (IMG_SIZE, IMG_SIZE))
114
- frame = frame[:, :, ::-1] / 255.0
115
- processed.append(frame)
116
  clip = np.transpose(np.array(processed), (3, 0, 1, 2))
117
  return torch.tensor(clip).float().unsqueeze(0)
118
 
@@ -128,6 +132,11 @@ def process_video_task(job_id: str, input_path: str, output_path: str):
128
  frame_counter = 0
129
 
130
  while cap.isOpened():
 
 
 
 
 
131
  ret, frame = cap.read()
132
  if not ret: break
133
  frame_counter += 1
@@ -139,7 +148,7 @@ def process_video_task(job_id: str, input_path: str, output_path: str):
139
  for r in results:
140
  if r.boxes is None: continue
141
  for box in r.boxes:
142
- if int(box.cls[0]) != 0: continue # Only Person
143
 
144
  x1, y1, x2, y2 = map(int, box.xyxy[0])
145
  crop = frame[y1:y2, x1:x2]
@@ -148,25 +157,27 @@ def process_video_task(job_id: str, input_path: str, output_path: str):
148
  frame_buffer.append(crop)
149
  if len(frame_buffer) == CLIP_LEN:
150
  clip = preprocess(frame_buffer).to(DEVICE)
151
- inputs = [clip[:, :, ::4, :, :], clip]
152
  with torch.no_grad():
153
- probs = torch.softmax(slowfast_model(inputs), dim=1)
154
  prediction_buffer.append(probs[0][1].item())
155
  avg_prob = np.mean(prediction_buffer)
156
 
157
- # Determine visual state
158
- active_theft = avg_prob > THEFT_THRESHOLD
159
- color = (0, 0, 255) if active_theft else (0, 255, 0)
160
  draw_corner_rect(frame, (x1, y1), (x2, y2), color, 2, 15, 25)
161
- if active_theft: theft_flag = True
162
 
163
- draw_fancy_overlay(frame, avg_prob, theft_flag, frame_counter)
 
 
164
  out.write(frame)
165
  jobs[job_id]["progress"] = int((frame_counter / total_frames) * 100)
166
 
167
  cap.release()
168
  out.release()
169
- jobs[job_id]["status"] = "completed"
 
 
170
  except Exception as e:
171
  jobs[job_id]["status"] = f"failed: {str(e)}"
172
 
@@ -176,29 +187,48 @@ def process_video_task(job_id: str, input_path: str, output_path: str):
176
  async def root():
177
  return RedirectResponse(url="/docs")
178
 
 
 
 
 
179
  @app.post("/detect")
180
  async def detect(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
181
  job_id = str(uuid.uuid4())
182
- input_path = os.path.join(UPLOAD_DIR, f"{job_id}_{file.filename}")
 
183
  output_path = os.path.join(OUTPUT_DIR, f"result_{job_id}.mp4")
184
 
185
  with open(input_path, "wb") as f:
186
  f.write(await file.read())
187
 
188
- jobs[job_id] = {"status": "processing", "progress": 0, "output_path": output_path}
189
- background_tasks.add_task(process_video_task, job_id, input_path, output_path)
 
 
 
 
 
190
 
191
- return {"job_id": job_id, "message": "Video analysis started"}
 
192
 
193
  @app.get("/status/{job_id}")
194
  async def get_status(job_id: str):
195
- if job_id not in jobs: raise HTTPException(404, "Job not found")
196
  return jobs[job_id]
197
 
 
 
 
 
 
 
 
 
198
  @app.get("/download/{job_id}")
199
  async def download(job_id: str):
200
- if job_id not in jobs or jobs[job_id]["status"] != "completed":
201
- raise HTTPException(400, "File not ready or job not found")
202
  return FileResponse(jobs[job_id]["output_path"], filename=f"analyzed_{job_id}.mp4")
203
 
204
  if __name__ == "__main__":
 
4
  import torch
5
  import torch.nn as nn
6
  import numpy as np
7
+ import datetime
8
  from collections import deque
9
  from fastapi import FastAPI, UploadFile, File, BackgroundTasks, HTTPException
10
  from fastapi.responses import FileResponse, RedirectResponse
 
22
  IMG_SIZE = 224
23
  THEFT_THRESHOLD = 0.6
24
 
25
+ # Ensure directories exist and are writable
26
  os.makedirs(UPLOAD_DIR, exist_ok=True)
27
  os.makedirs(OUTPUT_DIR, exist_ok=True)
28
 
29
  # In-memory job store
30
+ # Structure: { job_id: { status, progress, output_path, stop_requested, start_time } }
31
  jobs = {}
32
 
33
  app = FastAPI(title="AI Theft Detection System")
 
47
  ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
48
  state_dict = ckpt["model"] if "model" in ckpt else ckpt
49
  slowfast_model.load_state_dict(state_dict)
50
+ print("SlowFast weights loaded successfully.")
51
  else:
52
+ print(f"Warning: {MODEL_PATH} not found. Running with uninitialized weights.")
53
 
54
  slowfast_model = slowfast_model.to(DEVICE).eval()
55
 
 
75
  cv2.line(img, (x2, y2 - r), (x2, y2 - r - d), color, thickness)
76
  cv2.ellipse(img, (x2 - r, y2 - r), (r, r), 0, 0, 90, color, thickness)
77
 
78
def draw_security_card(frame, avg_prob, theft_flag):
    """Render the semi-transparent monitoring card onto *frame* in place.

    The card shows a title, the current alert status, a wall-clock
    timestamp, and a risk bar filled in proportion to *avg_prob*.
    NOTE(review): the timestamp is server processing time, not the
    video's own timeline — confirm that is the intended display.
    """
    x1, y1 = 30, 30
    x2, y2 = x1 + 500, y1 + 180
    accent = (0, 165, 255)  # BGR orange
    text_x = x1 + 20

    # Dark translucent backdrop: blend 80% panel over 20% scene, in place.
    panel = frame.copy()
    cv2.rectangle(panel, (x1, y1), (x2, y2), (30, 30, 30), -1)
    cv2.addWeighted(panel, 0.8, frame, 0.2, 0, frame)

    # Accent border around the card.
    cv2.rectangle(frame, (x1, y1), (x2, y2), accent, 2)

    # Status line reflects whether theft has been flagged so far in the video.
    if theft_flag:
        status_text, status_color = "ALERT: THEFT DETECTED", (0, 0, 255)
    else:
        status_text, status_color = "STATUS: NO THEFT", (220, 220, 220)
    stamp = datetime.datetime.now().strftime("%b %d, %Y, %I:%M:%S %p")

    cv2.putText(frame, "THEFT MONITORING SYSTEM", (text_x, y1 + 40),
                cv2.FONT_HERSHEY_DUPLEX, 0.8, (255, 255, 255), 2)
    cv2.putText(frame, status_text, (text_x, y1 + 85),
                cv2.FONT_HERSHEY_DUPLEX, 1.0, status_color, 2)
    cv2.putText(frame, f"TIME: {stamp}", (text_x, y1 + 125),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (200, 200, 200), 1)
    cv2.putText(frame, "THEFT DETECTION: ON", (text_x, y1 + 160),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

    # Risk bar along the card's bottom edge.
    track_w = 500 - 40
    fill_w = int(track_w * avg_prob)
    cv2.rectangle(frame, (text_x, y2 - 15), (text_x + track_w, y2 - 10), (50, 50, 50), -1)
    cv2.rectangle(frame, (text_x, y2 - 15), (text_x + fill_w, y2 - 10), accent, -1)
115
 
116
  # --- PROCESSING LOGIC ---
117
 
118
def preprocess(frames):
    """Convert a list of BGR frames into a (1, C, T, H, W) float clip tensor.

    Each frame is resized to IMG_SIZE x IMG_SIZE, flipped BGR -> RGB, and
    scaled into [0, 1] before stacking and adding the batch dimension.
    """
    resized = []
    for bgr in frames:
        rgb = cv2.resize(bgr, (IMG_SIZE, IMG_SIZE))[:, :, ::-1]
        resized.append(rgb / 255.0)
    stacked = np.array(resized)                 # (T, H, W, C)
    clip = np.transpose(stacked, (3, 0, 1, 2))  # (C, T, H, W)
    return torch.tensor(clip).float().unsqueeze(0)
122
 
 
132
  frame_counter = 0
133
 
134
  while cap.isOpened():
135
+ # Check for Stop Request
136
+ if jobs.get(job_id, {}).get("stop_requested"):
137
+ jobs[job_id]["status"] = "stopped"
138
+ break
139
+
140
  ret, frame = cap.read()
141
  if not ret: break
142
  frame_counter += 1
 
148
  for r in results:
149
  if r.boxes is None: continue
150
  for box in r.boxes:
151
+ if int(box.cls[0]) != 0: continue # Person only
152
 
153
  x1, y1, x2, y2 = map(int, box.xyxy[0])
154
  crop = frame[y1:y2, x1:x2]
 
157
  frame_buffer.append(crop)
158
  if len(frame_buffer) == CLIP_LEN:
159
  clip = preprocess(frame_buffer).to(DEVICE)
 
160
  with torch.no_grad():
161
+ probs = torch.softmax(slowfast_model([clip[:, :, ::4, :, :], clip]), dim=1)
162
  prediction_buffer.append(probs[0][1].item())
163
  avg_prob = np.mean(prediction_buffer)
164
 
165
+ is_theft = avg_prob > THEFT_THRESHOLD
166
+ color = (0, 0, 255) if is_theft else (0, 255, 0)
 
167
  draw_corner_rect(frame, (x1, y1), (x2, y2), color, 2, 15, 25)
168
+ if is_theft: theft_flag = True
169
 
170
+ # Draw the Security Card UI
171
+ draw_security_card(frame, avg_prob, theft_flag)
172
+
173
  out.write(frame)
174
  jobs[job_id]["progress"] = int((frame_counter / total_frames) * 100)
175
 
176
  cap.release()
177
  out.release()
178
+ if jobs[job_id]["status"] != "stopped":
179
+ jobs[job_id]["status"] = "completed"
180
+
181
  except Exception as e:
182
  jobs[job_id]["status"] = f"failed: {str(e)}"
183
 
 
187
  async def root():
188
  return RedirectResponse(url="/docs")
189
 
190
@app.get("/jobs")
async def list_jobs():
    """List every known job with its status and percent progress."""
    summaries = []
    for jid, data in jobs.items():
        summaries.append({
            "job_id": jid,
            "status": data["status"],
            "progress": f"{data['progress']}%",
        })
    return summaries
193
+
194
@app.post("/detect")
async def detect(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
    """Accept an uploaded video, persist it, and queue background analysis.

    Returns the job_id used to poll /status/{job_id} and later fetch the
    annotated result via /download/{job_id}.
    """
    job_id = str(uuid.uuid4())
    # SECURITY: file.filename is client-controlled; basename() strips any
    # directory components so a name like "../../x.mp4" cannot escape
    # UPLOAD_DIR via path traversal.
    safe_name = os.path.basename(file.filename or "upload.mp4")
    input_path = os.path.join(UPLOAD_DIR, f"{job_id}_{safe_name}")
    output_path = os.path.join(OUTPUT_DIR, f"result_{job_id}.mp4")

    with open(input_path, "wb") as f:
        f.write(await file.read())

    jobs[job_id] = {
        "status": "processing",
        "progress": 0,
        "output_path": output_path,
        "stop_requested": False,
        "filename": file.filename
    }

    background_tasks.add_task(process_video_task, job_id, input_path, output_path)
    return {"job_id": job_id, "message": "Video analysis queued"}
214
 
215
@app.get("/status/{job_id}")
async def get_status(job_id: str):
    """Return the stored record (status/progress/...) for *job_id*; 404 if unknown."""
    try:
        return jobs[job_id]
    except KeyError:
        raise HTTPException(404, "Job ID not found")
219
 
220
@app.post("/stop/{job_id}")
async def stop_job(job_id: str):
    """Cooperatively request that a running job stop; 404 if the job is unknown."""
    if job_id not in jobs:
        raise HTTPException(404, "Job ID not found")
    current = jobs[job_id]["status"]
    if current != "processing":
        # Nothing to stop: job already finished, failed, or was stopped.
        return {"message": f"Job is already {current}"}
    # The processing loop polls this flag and exits at the next frame.
    jobs[job_id]["stop_requested"] = True
    return {"message": "Stop signal sent to processing thread."}
227
+
228
@app.get("/download/{job_id}")
async def download(job_id: str):
    """Serve the annotated result video once the job has completed or been stopped.

    Raises 400 while the job is still processing/failed/unknown, and 404 if the
    expected output file is missing on disk.
    """
    if job_id not in jobs or jobs[job_id]["status"] not in ["completed", "stopped"]:
        raise HTTPException(400, "File not ready for download")
    output_path = jobs[job_id]["output_path"]
    # A job stopped before the writer emitted any frames (or a cleaned-up
    # output directory) leaves no file; report that instead of returning a
    # FileResponse that errors when streamed.
    if not os.path.exists(output_path):
        raise HTTPException(404, "Output file not found on server")
    return FileResponse(output_path, filename=f"analyzed_{job_id}.mp4")
233
 
234
  if __name__ == "__main__":