dataguychill commited on
Commit
164cfb7
·
verified ·
1 Parent(s): bbfd07b

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +20 -0
  2. inference.py +103 -0
  3. main.py +55 -0
Dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM tensorflow/tensorflow:2.15.0-gpu
2
+
3
+ RUN apt-get update && apt-get install -y libgl1-mesa-glx libglib2.0-0 && rm -rf /var/lib/apt/lists/*
4
+
5
+ WORKDIR /app
6
+
7
+ # Copy code
8
+ COPY . .
9
+
10
+ # Install dependencies
11
+ RUN pip install opencv-python fastapi uvicorn[standard] websockets
12
+
13
+ ENV MODEL_PATH="model"
14
+ ENV RESOLUTION="172"
15
+
16
+ # Mở port 8000
17
+ EXPOSE 8000
18
+
19
+ # Chạy Server FastAPI
20
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
inference.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # inference.py
2
+ import cv2
3
+ import tensorflow as tf
4
+ import time
5
+ import base64
6
+ import datetime
7
+ import os
8
+
9
+ # Load model 1 lần duy nhất khi import
10
+ MODEL_PATH = os.getenv('MODEL_PATH', 'model')
11
+ RESOLUTION = int(os.getenv('RESOLUTION', 172))
12
+ CONFIDENCE_THRESHOLD = 0.65
13
+
14
+ print(f"Loading MoViNet from {MODEL_PATH}...")
15
+ model = tf.saved_model.load(MODEL_PATH)
16
+ infer = model.signatures['serving_default']
17
+ print("Model loaded!")
18
+
19
+ def get_init_states():
20
+ dummy = tf.zeros([1, 1, RESOLUTION, RESOLUTION, 3], dtype=tf.float32)
21
+ return model.init_states(tf.shape(dummy))
22
+
23
+ class VideoProcessor:
24
+ def __init__(self):
25
+ self.running = False
26
+
27
+ def start_processing(self, rtsp_url, result_queue):
28
+ self.running = True
29
+ cap = cv2.VideoCapture(rtsp_url)
30
+ if not cap.isOpened():
31
+ result_queue.put({"error": "Cannot open RTSP URL"})
32
+ return
33
+
34
+ states = get_init_states()
35
+
36
+ # Logic quản lý sự kiện (Cooldown)
37
+ in_event = False
38
+ event_start_time = None
39
+ cooldown_counter = 0
40
+ COOLDOWN_LIMIT = 30 # Frames (~2-3s)
41
+
42
+ while self.running:
43
+ ret, frame = cap.read()
44
+ if not ret:
45
+ break
46
+
47
+ # 1. Inference
48
+ rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
49
+ resized = tf.image.resize(rgb, [RESOLUTION, RESOLUTION])
50
+ input_tensor = tf.cast(resized, tf.float32) / 255.0
51
+ input_tensor = input_tensor[tf.newaxis, tf.newaxis, ...]
52
+
53
+ outputs = infer(image=input_tensor, **states)
54
+ logits = outputs['logits']
55
+ states = {k: v for k, v in outputs.items() if k != 'logits'}
56
+ probs = tf.nn.softmax(logits, axis=-1)[0]
57
+
58
+ fight_conf = float(probs[0])
59
+ norm_conf = float(probs[1])
60
+ is_violence = (fight_conf > norm_conf) and (fight_conf > CONFIDENCE_THRESHOLD)
61
+
62
+ # 2. Logic xử lý kết quả để gửi về
63
+ current_time = datetime.datetime.now()
64
+ msg = None
65
+
66
+ if is_violence:
67
+ cooldown_counter = 0
68
+ if not in_event:
69
+ in_event = True
70
+ event_start_time = current_time
71
+
72
+ # START: Gửi ảnh bằng chứng
73
+ small_frame = cv2.resize(frame, (640, 360))
74
+ _, buffer = cv2.imencode('.jpg', small_frame)
75
+ img_base64 = base64.b64encode(buffer).decode('utf-8')
76
+
77
+ msg = {
78
+ "type": "START",
79
+ "timestamp": current_time.isoformat(),
80
+ "score": fight_conf,
81
+ "image": img_base64
82
+ }
83
+ else:
84
+ if in_event:
85
+ cooldown_counter += 1
86
+ if cooldown_counter >= COOLDOWN_LIMIT:
87
+ # END: Gửi thời lượng
88
+ duration = (current_time - event_start_time).total_seconds()
89
+ msg = {
90
+ "type": "END",
91
+ "timestamp": current_time.isoformat(),
92
+ "duration": duration
93
+ }
94
+ in_event = False
95
+
96
+ # Nếu có tin quan trọng thì đẩy vào hàng đợi gửi về Laptop
97
+ if msg:
98
+ result_queue.put(msg)
99
+
100
+ # Tùy chọn: Gửi Heartbeat mỗi 5s để biết model vẫn sống (nếu cần)
101
+
102
+ cap.release()
103
+ result_queue.put({"status": "Stream stopped"})
main.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # main.py
2
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
3
+ from inference import VideoProcessor
4
+ import threading
5
+ import queue
6
+ import asyncio
7
+ import json
8
+
9
+ app = FastAPI()
10
+
11
+ @app.websocket("/ws/inference")
12
+ async def websocket_endpoint(websocket: WebSocket):
13
+ await websocket.accept()
14
+ processor = VideoProcessor()
15
+ result_queue = queue.Queue()
16
+
17
+ try:
18
+ # 1. Nhận link RTSP từ Laptop
19
+ data = await websocket.receive_text()
20
+ config = json.loads(data)
21
+ rtsp_url = config.get("url")
22
+
23
+ print(f"Received request for: {rtsp_url}")
24
+
25
+ # 2. Chạy Model trong luồng riêng (Background Thread)
26
+ # Để không chặn WebSocket
27
+ process_thread = threading.Thread(
28
+ target=processor.start_processing,
29
+ args=(rtsp_url, result_queue)
30
+ )
31
+ process_thread.start()
32
+
33
+ # 3. Vòng lặp gửi kết quả về Laptop
34
+ while True:
35
+ # Kiểm tra queue xem có kết quả mới từ Model không
36
+ try:
37
+ # Non-blocking get
38
+ result = result_queue.get_nowait()
39
+ await websocket.send_json(result)
40
+ except queue.Empty:
41
+ # Nếu không có kết quả, ngủ xíu để đỡ tốn CPU
42
+ await asyncio.sleep(0.1)
43
+
44
+ # Kiểm tra nếu client ngắt kết nối thì dừng model
45
+ if not process_thread.is_alive():
46
+ await websocket.send_json({"status": "Processing finished"})
47
+ break
48
+
49
+ except WebSocketDisconnect:
50
+ print("Client disconnected")
51
+ processor.running = False # Ra lệnh dừng model
52
+ process_thread.join()
53
+ except Exception as e:
54
+ print(f"Error: {e}")
55
+ processor.running = False