Abubakar740 committed on
Commit
27e651e
·
1 Parent(s): a8169fb

update app

Browse files
Files changed (1) hide show
  1. main.py +109 -105
main.py CHANGED
@@ -10,32 +10,28 @@ from fastapi.responses import FileResponse, RedirectResponse
10
  from pytorchvideo.models.hub import slowfast_r50
11
  from ultralytics import YOLO
12
 
13
app = FastAPI()

# --- CONFIG & GLOBALS ---
# Anchor all paths to the app directory so behavior does not depend on the
# process working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
UPLOAD_DIR = os.path.join(BASE_DIR, "uploads")
OUTPUT_DIR = os.path.join(BASE_DIR, "outputs")

# Ensure they exist (as a backup).
# Fixed: the previous code re-bound UPLOAD_DIR/OUTPUT_DIR to the relative
# strings "uploads"/"outputs" AFTER os.makedirs ran on the absolute paths,
# so directories were created in one location and files written to another.
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_PATH = "best_slowfast_theft.pth"
CLIP_LEN = 32          # frames fed to SlowFast per clip
IMG_SIZE = 224         # square side each crop is resized to
THEFT_THRESHOLD = 0.6  # smoothed probability above which a frame is flagged

# In-memory store for job progress (in production, use Redis/a database).
jobs = {}
 
 
 
 
 
 
36
 
37
- # --- LOAD MODELS GLOBALLY (Once) ---
38
- print(f"Loading models on {DEVICE}...")
39
  yolo = YOLO("yolov8n.pt")
40
 
41
  slowfast_model = slowfast_r50(pretrained=False)
@@ -45,99 +41,132 @@ slowfast_model.blocks[-1].proj = nn.Sequential(
45
  nn.Linear(in_features, 2)
46
  )
47
 
48
# --- LOAD WEIGHTS ---
# Fixed: fail soft when the checkpoint is absent instead of crashing with
# FileNotFoundError at import time (the app can still serve /docs etc.).
if os.path.exists(MODEL_PATH):
    ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
    # Some training scripts wrap the weights under a "model" key.
    state_dict = ckpt["model"] if "model" in ckpt else ckpt
    slowfast_model.load_state_dict(state_dict)
else:
    print(f"Warning: {MODEL_PATH} not found; model keeps random init weights.")

slowfast_model = slowfast_model.to(DEVICE).eval()
print("Models loaded successfully.")
54
 
55
- # --- HELPER FUNCTIONS ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
def preprocess(frames):
    """Convert a sequence of BGR crops into a SlowFast input tensor.

    Each crop is resized to IMG_SIZE x IMG_SIZE, converted BGR -> RGB and
    scaled to [0, 1]; the stacked clip is returned as a float tensor of
    shape (1, C, T, H, W).
    """
    resized = []
    for img in frames:
        img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
        resized.append(img[:, :, ::-1] / 255.0)  # BGR -> RGB, normalize
    # (T, H, W, C) -> (C, T, H, W), then add the batch dimension.
    clip = np.transpose(np.array(resized), (3, 0, 1, 2))
    return torch.tensor(clip).float().unsqueeze(0)
67
 
68
def process_video_task(job_id: str, input_path: str, output_path: str):
    """Annotate a video with per-person theft predictions (background job).

    Runs YOLO person detection frame by frame, feeds CLIP_LEN-frame crop
    buffers through SlowFast, draws boxes/overlays, writes the annotated
    video to output_path, and reports progress via the shared ``jobs`` dict.
    Failures are recorded in jobs[job_id]["status"] rather than raised.
    """
    cap = None
    out = None
    try:
        cap = cv2.VideoCapture(input_path)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Fixed: some containers report FPS as 0, which breaks VideoWriter.
        fps = int(cap.get(cv2.CAP_PROP_FPS)) or 25
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        frame_buffer = deque(maxlen=CLIP_LEN)
        prediction_buffer = deque(maxlen=10)
        frame_counter = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_counter += 1
            theft_flag = False
            avg_prob = 0.0

            results = yolo(frame, verbose=False)
            for r in results:
                if r.boxes is None:
                    continue
                for box in r.boxes:
                    if int(box.cls[0]) != 0:  # COCO class 0 == person
                        continue
                    x1, y1, x2, y2 = map(int, box.xyxy[0])
                    crop = frame[y1:y2, x1:x2]
                    if crop.size == 0:
                        continue

                    frame_buffer.append(crop)
                    if len(frame_buffer) == CLIP_LEN:
                        clip = preprocess(frame_buffer).to(DEVICE)
                        # SlowFast pathways: slow = every 4th frame, fast = full clip.
                        inputs = [clip[:, :, ::4, :, :], clip]
                        with torch.no_grad():
                            outputs = slowfast_model(inputs)
                            probs = torch.softmax(outputs, dim=1)
                        prediction_buffer.append(probs[0][1].item())
                        # Smooth over the last 10 clips to damp flicker.
                        avg_prob = np.mean(prediction_buffer)
                        if avg_prob > THEFT_THRESHOLD:
                            theft_flag = True

                    color = (0, 0, 255) if theft_flag else (0, 255, 0)
                    cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

            # UI overlays (frame-level; reflect the last computed avg_prob).
            card_text = f"Class: {'THEFT' if avg_prob > THEFT_THRESHOLD else 'Normal'} | Score: {avg_prob:.2f}"
            cv2.rectangle(frame, (10, 10), (310, 70), (50, 50, 50), -1)
            cv2.putText(frame, card_text, (20, 45), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
            if theft_flag:
                cv2.putText(frame, "THEFT ALERT", (50, 100), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 0, 255), 3)

            out.write(frame)
            # Fixed: guard against ZeroDivisionError when the container does
            # not expose a frame count.
            if total_frames > 0:
                jobs[job_id]["progress"] = int((frame_counter / total_frames) * 100)

        jobs[job_id]["status"] = "completed"
        jobs[job_id]["progress"] = 100
    except Exception as e:
        jobs[job_id]["status"] = f"failed: {str(e)}"
    finally:
        # Fixed: release capture/writer on ALL paths, not just on success.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
143
 
@@ -148,55 +177,30 @@ async def root():
148
  return RedirectResponse(url="/docs")
149
 
150
  @app.post("/detect")
151
- async def detect_theft(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
152
  job_id = str(uuid.uuid4())
153
- input_filename = f"{job_id}_{file.filename}"
154
- input_path = os.path.join(UPLOAD_DIR, input_filename)
155
  output_path = os.path.join(OUTPUT_DIR, f"result_{job_id}.mp4")
156
 
157
- # Save uploaded file
158
- with open(input_path, "wb") as buffer:
159
- buffer.write(await file.read())
160
 
161
- # Initialize job state
162
- jobs[job_id] = {
163
- "status": "processing",
164
- "progress": 0,
165
- "output_path": output_path
166
- }
167
-
168
- # Run processing in background
169
  background_tasks.add_task(process_video_task, job_id, input_path, output_path)
170
-
171
- return {"job_id": job_id, "message": "Video processing started"}
172
-
173
 
174
  @app.get("/status/{job_id}")
175
  async def get_status(job_id: str):
176
- if job_id not in jobs:
177
- raise HTTPException(status_code=404, detail="Job ID not found")
178
-
179
- return {
180
- "job_id": job_id,
181
- "status": jobs[job_id]["status"],
182
- "progress": f"{jobs[job_id]['progress']}%"
183
- }
184
-
185
 
186
  @app.get("/download/{job_id}")
187
- async def download_video(job_id: str):
188
- if job_id not in jobs:
189
- raise HTTPException(status_code=404, detail="Job ID not found")
190
-
191
- if jobs[job_id]["status"] != "completed":
192
- raise HTTPException(status_code=400, detail="Video is not processed yet")
193
-
194
- return FileResponse(
195
- path=jobs[job_id]["output_path"],
196
- filename=f"annotated_{job_id}.mp4",
197
- media_type='video/mp4'
198
- )
199
 
200
  if __name__ == "__main__":
201
  import uvicorn
202
- uvicorn.run(app, host="0.0.0.0", port=8002)
 
10
  from pytorchvideo.models.hub import slowfast_r50
11
  from ultralytics import YOLO
12
 
13
# --- CONFIGURATION ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_PATH = "best_slowfast_theft.pth"

# All paths anchored to the app directory, independent of the CWD.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
UPLOAD_DIR = os.path.join(BASE_DIR, "uploads")
OUTPUT_DIR = os.path.join(BASE_DIR, "outputs")
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

CLIP_LEN = 32          # frames per SlowFast clip
IMG_SIZE = 224         # side length crops are resized to
THEFT_THRESHOLD = 0.6  # smoothed-probability alert threshold

# In-memory job store keyed by job_id.
jobs = {}

app = FastAPI(title="AI Theft Detection System")
32
 
33
# --- MODEL LOADING ---
# YOLOv8-nano supplies the per-frame person detections.
print(f"Loading Models on {DEVICE}...")
yolo = YOLO("yolov8n.pt")
36
 
37
  slowfast_model = slowfast_r50(pretrained=False)
 
41
  nn.Linear(in_features, 2)
42
  )
43
 
44
if os.path.exists(MODEL_PATH):
    # NOTE(review): torch.load on a pickle checkpoint can execute arbitrary
    # code — only load checkpoints from trusted sources (or pass
    # weights_only=True on torch >= 1.13).
    ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
    # Some training scripts nest the weights under a "model" key.
    state_dict = ckpt["model"] if "model" in ckpt else ckpt
    slowfast_model.load_state_dict(state_dict)
    print("SlowFast weights loaded.")
else:
    # Fixed typo in the user-facing warning: "unitialized" -> "uninitialized".
    print(f"Warning: {MODEL_PATH} not found. Running with uninitialized weights.")

slowfast_model = slowfast_model.to(DEVICE).eval()
 
53
 
54
+ # --- VISUALIZATION HELPERS ---
55
+
56
def draw_corner_rect(img, pt1, pt2, color, thickness, r, d):
    """Draw a rectangle rendered only as rounded corner brackets.

    ``r`` is the corner arc radius and ``d`` the length of the straight
    segment extending from each arc. Mutates ``img`` in place.
    """
    x1, y1 = pt1
    x2, y2 = pt2
    # (corner x, corner y, x direction, y direction, arc start angle)
    corners = (
        (x1, y1, 1, 1, 180),   # top-left
        (x2, y1, -1, 1, 270),  # top-right
        (x1, y2, 1, -1, 90),   # bottom-left
        (x2, y2, -1, -1, 0),   # bottom-right
    )
    for cx, cy, sx, sy, angle in corners:
        cv2.line(img, (cx + sx * r, cy), (cx + sx * (r + d), cy), color, thickness)
        cv2.line(img, (cx, cy + sy * r), (cx, cy + sy * (r + d)), color, thickness)
        cv2.ellipse(img, (cx + sx * r, cy + sy * r), (r, r), angle, 0, 90, color, thickness)
75
+
76
def draw_fancy_overlay(frame, avg_prob, theft_flag, frame_counter):
    """Render the HUD: header bar, pulsing status dot, risk meter, alert banner.

    Mutates ``frame`` in place. ``avg_prob`` drives the meter fill/color and
    ``frame_counter`` drives the dot's pulse phase.
    """
    h, w, _ = frame.shape

    # Semi-transparent header strip across the top.
    header = frame.copy()
    cv2.rectangle(header, (0, 0), (w, 80), (30, 30, 30), -1)
    cv2.addWeighted(header, 0.7, frame, 0.3, 0, frame)

    # Pulsing status dot: green while clear, red while a theft is flagged.
    color_status = (0, 0, 255) if theft_flag else (0, 255, 0)
    pulse = (np.sin(frame_counter / 4) + 1) / 2
    if pulse > 0.4:
        cv2.circle(frame, (40, 40), 10, color_status, -1)
    cv2.putText(frame, "AI SURVEILLANCE LIVE", (70, 48), cv2.FONT_HERSHEY_DUPLEX, 0.7, (255, 255, 255), 1)

    # Risk meter: gray track, fill colored green -> orange -> red by score.
    bar_x, bar_y, bar_w, bar_h = w - 350, 30, 300, 25
    cv2.rectangle(frame, (bar_x, bar_y), (bar_x + bar_w, bar_y + bar_h), (60, 60, 60), -1)
    if avg_prob < 0.4:
        bar_color = (0, 255, 0)
    elif avg_prob < THEFT_THRESHOLD:
        bar_color = (0, 165, 255)
    else:
        bar_color = (0, 0, 255)
    fill_w = int(bar_w * avg_prob)
    cv2.rectangle(frame, (bar_x, bar_y), (bar_x + fill_w, bar_y + bar_h), bar_color, -1)
    cv2.putText(frame, f"Risk Score: {int(avg_prob*100)}%", (bar_x, bar_y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

    # Full-width translucent alert banner when a theft is flagged.
    if theft_flag:
        banner = frame.copy()
        cv2.rectangle(banner, (0, h//2 - 60), (w, h//2 + 60), (0, 0, 200), -1)
        cv2.addWeighted(banner, 0.5, frame, 0.5, 0, frame)
        cv2.putText(frame, "CRITICAL ALERT: THEFT DETECTED", (w//2 - 280, h//2 + 15),
                    cv2.FONT_HERSHEY_TRIPLEX, 1.2, (255, 255, 255), 2)
107
+
108
+ # --- PROCESSING LOGIC ---
109
 
110
def preprocess(frames):
    """Stack BGR crops into a normalized (1, C, T, H, W) float tensor."""
    clip_frames = [
        # Resize, flip BGR -> RGB, scale to [0, 1].
        cv2.resize(f, (IMG_SIZE, IMG_SIZE))[:, :, ::-1] / 255.0
        for f in frames
    ]
    # (T, H, W, C) -> (C, T, H, W), then prepend the batch axis.
    clip = np.transpose(np.array(clip_frames), (3, 0, 1, 2))
    return torch.tensor(clip).float().unsqueeze(0)
118
 
119
def process_video_task(job_id: str, input_path: str, output_path: str):
    """Annotate a video with per-person theft predictions (background job).

    Detects people with YOLO, classifies CLIP_LEN-frame crop buffers with
    SlowFast, draws corner boxes and the HUD overlay, writes the result to
    output_path, and updates the shared ``jobs`` store. Failures are
    recorded in jobs[job_id]["status"] rather than raised.
    """
    cap = None
    out = None
    try:
        cap = cv2.VideoCapture(input_path)
        # Fixed: named CAP_PROP_* constants instead of magic indices 3/4/5/7.
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Fixed: some containers report FPS as 0, which breaks VideoWriter.
        fps = int(cap.get(cv2.CAP_PROP_FPS)) or 25
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))

        frame_buffer = deque(maxlen=CLIP_LEN)
        prediction_buffer = deque(maxlen=10)
        frame_counter = 0

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            frame_counter += 1
            theft_flag = False
            avg_prob = 0.0

            results = yolo(frame, verbose=False)
            for r in results:
                if r.boxes is None:
                    continue
                for box in r.boxes:
                    if int(box.cls[0]) != 0:  # COCO class 0 == person
                        continue
                    x1, y1, x2, y2 = map(int, box.xyxy[0])
                    crop = frame[y1:y2, x1:x2]
                    if crop.size == 0:
                        continue

                    frame_buffer.append(crop)
                    if len(frame_buffer) == CLIP_LEN:
                        clip = preprocess(frame_buffer).to(DEVICE)
                        # SlowFast pathways: slow = every 4th frame, fast = full clip.
                        inputs = [clip[:, :, ::4, :, :], clip]
                        with torch.no_grad():
                            probs = torch.softmax(slowfast_model(inputs), dim=1)
                        prediction_buffer.append(probs[0][1].item())
                        # Smooth over the last 10 clips to damp flicker.
                        avg_prob = np.mean(prediction_buffer)

                    # Per-box visual state from the smoothed probability.
                    active_theft = avg_prob > THEFT_THRESHOLD
                    color = (0, 0, 255) if active_theft else (0, 255, 0)
                    draw_corner_rect(frame, (x1, y1), (x2, y2), color, 2, 15, 25)
                    if active_theft:
                        theft_flag = True

            draw_fancy_overlay(frame, avg_prob, theft_flag, frame_counter)
            out.write(frame)
            # Fixed: guard against ZeroDivisionError when the container does
            # not expose a frame count.
            if total_frames > 0:
                jobs[job_id]["progress"] = int((frame_counter / total_frames) * 100)

        jobs[job_id]["status"] = "completed"
        # Fixed: pin progress to 100 on completion (frame-count estimates drift).
        jobs[job_id]["progress"] = 100
    except Exception as e:
        jobs[job_id]["status"] = f"failed: {str(e)}"
    finally:
        # Fixed: release capture/writer on ALL paths, not just on success.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
172
 
 
177
  return RedirectResponse(url="/docs")
178
 
179
  @app.post("/detect")
180
+ async def detect(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
181
  job_id = str(uuid.uuid4())
182
+ input_path = os.path.join(UPLOAD_DIR, f"{job_id}_{file.filename}")
 
183
  output_path = os.path.join(OUTPUT_DIR, f"result_{job_id}.mp4")
184
 
185
+ with open(input_path, "wb") as f:
186
+ f.write(await file.read())
 
187
 
188
+ jobs[job_id] = {"status": "processing", "progress": 0, "output_path": output_path}
 
 
 
 
 
 
 
189
  background_tasks.add_task(process_video_task, job_id, input_path, output_path)
190
+
191
+ return {"job_id": job_id, "message": "Video analysis started"}
 
192
 
193
  @app.get("/status/{job_id}")
194
  async def get_status(job_id: str):
195
+ if job_id not in jobs: raise HTTPException(404, "Job not found")
196
+ return jobs[job_id]
 
 
 
 
 
 
 
197
 
198
  @app.get("/download/{job_id}")
199
+ async def download(job_id: str):
200
+ if job_id not in jobs or jobs[job_id]["status"] != "completed":
201
+ raise HTTPException(400, "File not ready or job not found")
202
+ return FileResponse(jobs[job_id]["output_path"], filename=f"analyzed_{job_id}.mp4")
 
 
 
 
 
 
 
 
203
 
204
  if __name__ == "__main__":
205
  import uvicorn
206
+ uvicorn.run(app, host="0.0.0.0", port=7860)