Upload 3 files
Browse files- Dockerfile +20 -0
- inference.py +103 -0
- main.py +55 -0
Dockerfile
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
FROM tensorflow/tensorflow:2.15.0-gpu

# OpenCV needs libGL and glib at runtime (the TF GPU image does not ship them);
# clean the apt lists afterwards to keep the layer small.
RUN apt-get update \
    && apt-get install -y --no-install-recommends libgl1-mesa-glx libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Install Python dependencies BEFORE copying the source so code-only changes
# do not invalidate the (slow) pip layer.
# NOTE: "uvicorn[standard]" must be quoted — unquoted brackets are a shell
# glob pattern and can break or silently change the install.
RUN pip install --no-cache-dir opencv-python fastapi "uvicorn[standard]" websockets

# Copy application code (main.py, inference.py, and the saved model directory).
COPY . .

# Defaults read by inference.py via os.getenv(); override with `docker run -e ...`.
ENV MODEL_PATH="model" \
    RESOLUTION="172"

# Expose the FastAPI/uvicorn port.
EXPOSE 8000

# Launch the FastAPI server.
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
inference.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# inference.py
#
# Streaming violence-detection inference: loads a MoViNet SavedModel once at
# import time and exposes VideoProcessor for per-frame RTSP processing.
import cv2
import tensorflow as tf
import time  # NOTE(review): appears unused in this module — candidate for removal
import base64
import datetime
import os

# Runtime configuration, overridable via environment (see Dockerfile ENV).
# MODEL_PATH: directory containing the TF SavedModel.
MODEL_PATH = os.getenv('MODEL_PATH', 'model')
# RESOLUTION: square input size fed to the model (presumably matches the
# resolution the MoViNet variant was trained at — confirm against the model).
RESOLUTION = int(os.getenv('RESOLUTION', 172))
# Minimum fight-class probability required to flag a frame as violent.
CONFIDENCE_THRESHOLD = 0.65

# Load the model exactly once, when this module is first imported.
print(f"Loading MoViNet from {MODEL_PATH}...")
model = tf.saved_model.load(MODEL_PATH)
infer = model.signatures['serving_default']
print("Model loaded!")
def get_init_states():
    """Return the initial recurrent states for the streaming MoViNet model.

    `model.init_states` takes the input shape as a 1-D int32 tensor
    [batch, frames, height, width, channels]. Build that shape tensor
    directly instead of allocating a full dummy zeros frame only to call
    tf.shape() on it — same result, no wasted allocation.
    """
    input_shape = tf.constant([1, 1, RESOLUTION, RESOLUTION, 3], dtype=tf.int32)
    return model.init_states(input_shape)
class VideoProcessor:
    """Reads frames from an RTSP stream, runs streaming MoViNet inference,
    and pushes violence START/END event dicts into a result queue.

    The loop runs until the stream ends or ``self.running`` is set to False
    from another thread (see main.py's WebSocket handler).
    """

    def __init__(self):
        # Cooperative stop flag, flipped from the WebSocket thread.
        self.running = False

    @staticmethod
    def _frame_to_input(frame):
        """Convert a BGR OpenCV frame to a [1, 1, H, W, 3] float32 tensor in [0, 1]."""
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        resized = tf.image.resize(rgb, [RESOLUTION, RESOLUTION])
        tensor = tf.cast(resized, tf.float32) / 255.0
        return tensor[tf.newaxis, tf.newaxis, ...]

    @staticmethod
    def _encode_evidence(frame):
        """JPEG-encode a 640x360 copy of the frame as base64 evidence for START events."""
        small_frame = cv2.resize(frame, (640, 360))
        _, buffer = cv2.imencode('.jpg', small_frame)
        return base64.b64encode(buffer).decode('utf-8')

    def start_processing(self, rtsp_url, result_queue):
        """Run the inference loop over the RTSP stream.

        Puts dicts on ``result_queue``:
          {"error": ...}                       if the stream cannot be opened
          {"type": "START", ...}               first violent frame of an event
          {"type": "END", "duration": ...}     after the calm-frame cooldown
          {"status": "Stream stopped"}         always, when the loop exits
        """
        self.running = True
        cap = cv2.VideoCapture(rtsp_url)
        if not cap.isOpened():
            cap.release()  # BUGFIX: release the capture even on the failure path
            result_queue.put({"error": "Cannot open RTSP URL"})
            return

        states = get_init_states()

        # Event state machine (cooldown-based debouncing).
        in_event = False
        event_start_time = None
        cooldown_counter = 0
        COOLDOWN_LIMIT = 30  # consecutive calm frames (~2-3 s at typical stream FPS)

        try:
            while self.running:
                ret, frame = cap.read()
                if not ret:
                    break

                # 1. Inference. The model is streaming: the recurrent `states`
                # returned by one call are fed into the next.
                outputs = infer(image=self._frame_to_input(frame), **states)
                logits = outputs['logits']
                states = {k: v for k, v in outputs.items() if k != 'logits'}
                probs = tf.nn.softmax(logits, axis=-1)[0]

                # Index 0 = fight, index 1 = normal — assumed from the
                # training label order; TODO confirm against the model.
                fight_conf = float(probs[0])
                norm_conf = float(probs[1])
                is_violence = (fight_conf > norm_conf) and (fight_conf > CONFIDENCE_THRESHOLD)

                # 2. Event logic: START on the first violent frame, END only
                # after COOLDOWN_LIMIT consecutive calm frames.
                current_time = datetime.datetime.now()
                msg = None

                if is_violence:
                    cooldown_counter = 0
                    if not in_event:
                        in_event = True
                        event_start_time = current_time
                        msg = {
                            "type": "START",
                            "timestamp": current_time.isoformat(),
                            "score": fight_conf,
                            "image": self._encode_evidence(frame),
                        }
                elif in_event:
                    cooldown_counter += 1
                    if cooldown_counter >= COOLDOWN_LIMIT:
                        msg = {
                            "type": "END",
                            "timestamp": current_time.isoformat(),
                            "duration": (current_time - event_start_time).total_seconds(),
                        }
                        in_event = False

                # Forward any important message to the client queue.
                if msg:
                    result_queue.put(msg)

            # BUGFIX: if the stream ended (or we were stopped) mid-event, close
            # the event so the client is not left with a dangling START.
            if in_event:
                now = datetime.datetime.now()
                result_queue.put({
                    "type": "END",
                    "timestamp": now.isoformat(),
                    "duration": (now - event_start_time).total_seconds(),
                })
        finally:
            # BUGFIX: release the capture even if inference raises.
            cap.release()
        result_queue.put({"status": "Stream stopped"})
main.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# main.py
|
| 2 |
+
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
| 3 |
+
from inference import VideoProcessor
|
| 4 |
+
import threading
|
| 5 |
+
import queue
|
| 6 |
+
import asyncio
|
| 7 |
+
import json
|
# ASGI application served by uvicorn (see Dockerfile CMD: "main:app").
app = FastAPI()
@app.websocket("/ws/inference")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint driving one RTSP inference session.

    Protocol:
      1. Client sends a JSON text message: {"url": "<rtsp://...>"}.
      2. Inference runs in a background thread; every event dict it puts on
         the queue is forwarded to the client as JSON.
      3. The session ends when the stream finishes or the client disconnects.
    """
    await websocket.accept()
    processor = VideoProcessor()
    result_queue = queue.Queue()
    # BUGFIX: bind before the try so the handlers below can test it safely
    # even when a failure happens before the thread is created (the original
    # raised NameError on process_thread.join() in that case).
    process_thread = None

    try:
        # 1. Receive the RTSP URL from the client.
        data = await websocket.receive_text()
        config = json.loads(data)
        rtsp_url = config.get("url")
        if not rtsp_url:
            # BUGFIX: reject a missing/empty URL up front instead of handing
            # None to cv2.VideoCapture inside the worker.
            await websocket.send_json({"error": "Missing 'url' in request"})
            return

        print(f"Received request for: {rtsp_url}")

        # 2. Run the model in a background thread so the event loop (and this
        # WebSocket) stays responsive.
        process_thread = threading.Thread(
            target=processor.start_processing,
            args=(rtsp_url, result_queue),
            daemon=True,  # never block interpreter shutdown on a stuck stream
        )
        process_thread.start()

        # 3. Forward results to the client as they arrive.
        while True:
            try:
                # Non-blocking get so we never stall the event loop.
                result = result_queue.get_nowait()
                await websocket.send_json(result)
            except queue.Empty:
                # Nothing pending: yield briefly instead of busy-waiting.
                await asyncio.sleep(0.1)

            # BUGFIX: only finish once the worker is done AND the queue is
            # drained — the original broke as soon as the thread died, which
            # could drop final messages still sitting in the queue.
            if not process_thread.is_alive() and result_queue.empty():
                await websocket.send_json({"status": "Processing finished"})
                break

    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
        print(f"Error: {e}")
    finally:
        # Single shutdown path: always ask the worker to stop and wait for it
        # (the original's generic-exception path never joined the thread).
        processor.running = False
        if process_thread is not None and process_thread.is_alive():
            process_thread.join()