Spaces:
Running
Running
Upload 53 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +14 -0
- web_demo/__pycache__/yolo_detection.cpython-311.pyc +0 -0
- web_demo/app.py +207 -0
- web_demo/models/c3d.pickle +3 -0
- web_demo/models/epoch_80000.pt +3 -0
- web_demo/models/yolo_my_model.pt +3 -0
- web_demo/network/MFNET.py +278 -0
- web_demo/network/TorchUtils.py +284 -0
- web_demo/network/__init__.py +0 -0
- web_demo/network/__pycache__/MFNET.cpython-311.pyc +0 -0
- web_demo/network/__pycache__/TorchUtils.cpython-311.pyc +0 -0
- web_demo/network/__pycache__/__init__.cpython-311.pyc +0 -0
- web_demo/network/__pycache__/anomaly_detector_model.cpython-311.pyc +0 -0
- web_demo/network/__pycache__/c3d.cpython-311.pyc +0 -0
- web_demo/network/__pycache__/resnet.cpython-311.pyc +0 -0
- web_demo/network/anomaly_detector_model.py +142 -0
- web_demo/network/c3d.py +129 -0
- web_demo/network/resnet.py +232 -0
- web_demo/requirements.txt +7 -0
- web_demo/static/css/style.css +112 -0
- web_demo/static/js/main.js +108 -0
- web_demo/static/script.js +41 -0
- web_demo/static/videos/Abuse.mp4 +3 -0
- web_demo/static/videos/Arrest.mp4 +3 -0
- web_demo/static/videos/Arson.mp4 +3 -0
- web_demo/static/videos/Assault.mp4 +3 -0
- web_demo/static/videos/Burglary.mp4 +3 -0
- web_demo/static/videos/Explosion.mp4 +3 -0
- web_demo/static/videos/Fighting.mp4 +3 -0
- web_demo/static/videos/Normal.mp4 +3 -0
- web_demo/static/videos/RoadAccidents.mp4 +3 -0
- web_demo/static/videos/Robbery.mp4 +3 -0
- web_demo/static/videos/Shooting.mp4 +3 -0
- web_demo/static/videos/Shoplifting.mp4 +3 -0
- web_demo/static/videos/Stealing.mp4 +3 -0
- web_demo/static/videos/Vandalism.mp4 +3 -0
- web_demo/templates/index.html +66 -0
- web_demo/utils/__init__.py +0 -0
- web_demo/utils/__pycache__/__init__.cpython-311.pyc +0 -0
- web_demo/utils/__pycache__/callbacks.cpython-311.pyc +0 -0
- web_demo/utils/__pycache__/functional_video.cpython-311.pyc +0 -0
- web_demo/utils/__pycache__/load_model.cpython-311.pyc +0 -0
- web_demo/utils/__pycache__/stack.cpython-311.pyc +0 -0
- web_demo/utils/__pycache__/transforms_video.cpython-311.pyc +0 -0
- web_demo/utils/__pycache__/types.cpython-311.pyc +0 -0
- web_demo/utils/__pycache__/utils.cpython-311.pyc +0 -0
- web_demo/utils/callbacks.py +197 -0
- web_demo/utils/functional_video.py +104 -0
- web_demo/utils/load_model.py +114 -0
- web_demo/utils/stack.py +33 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
web_demo/static/videos/Abuse.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
web_demo/static/videos/Arrest.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
web_demo/static/videos/Arson.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
web_demo/static/videos/Assault.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
web_demo/static/videos/Burglary.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
web_demo/static/videos/Explosion.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
web_demo/static/videos/Fighting.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
web_demo/static/videos/Normal.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
web_demo/static/videos/RoadAccidents.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
web_demo/static/videos/Robbery.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
web_demo/static/videos/Shooting.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
web_demo/static/videos/Shoplifting.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
web_demo/static/videos/Stealing.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
web_demo/static/videos/Vandalism.mp4 filter=lfs diff=lfs merge=lfs -text
|
web_demo/__pycache__/yolo_detection.cpython-311.pyc
ADDED
|
Binary file (3.89 kB). View file
|
|
|
web_demo/app.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import time
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
import threading
|
| 8 |
+
import base64
|
| 9 |
+
from werkzeug.utils import secure_filename
|
| 10 |
+
|
| 11 |
+
from flask import Flask, render_template, Response, request, jsonify
|
| 12 |
+
from flask_socketio import SocketIO
|
| 13 |
+
|
| 14 |
+
# Important: Make sure your custom utility scripts are accessible
|
| 15 |
+
from utils.load_model import load_models
|
| 16 |
+
from utils.utils import build_transforms
|
| 17 |
+
from network.TorchUtils import get_torch_device
|
| 18 |
+
from yolo_detection import analyze_video_with_yolo
|
| 19 |
+
|
| 20 |
+
# ---- App Setup ----
|
| 21 |
+
app = Flask(__name__)
|
| 22 |
+
app.config['SECRET_KEY'] = 'your_secret_key!'
|
| 23 |
+
|
| 24 |
+
# ADDED: Configuration for uploaded files
|
| 25 |
+
UPLOAD_FOLDER = 'uploads'
|
| 26 |
+
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
|
| 27 |
+
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
|
| 28 |
+
|
| 29 |
+
socketio = SocketIO(app, async_mode='eventlet')
|
| 30 |
+
|
| 31 |
+
# ---- Global Config & Model Loading ----
|
| 32 |
+
print("[INFO] Loading models...")
|
| 33 |
+
DEVICE = get_torch_device()
|
| 34 |
+
FEATURE_EXTRACTOR_PATH = r"S:\\ano_dec_pro\\AnomalyDetectionCVPR2018-Pytorch\\pretrained\\c3d.pickle"
|
| 35 |
+
AD_MODEL_PATH = r"S:\\ano_dec_pro\\AnomalyDetectionCVPR2018-Pytorch\\exps\\c3d\\models\\epoch_80000.pt"
|
| 36 |
+
YOLO_MODEL_PATH = r"S:\\ano_dec_pro\\AnomalyDetectionCVPR2018-Pytorch\\yolo_my_model.pt"
|
| 37 |
+
SAVE_DIR = "outputs/anomaly_frames"
|
| 38 |
+
os.makedirs(SAVE_DIR, exist_ok=True)
|
| 39 |
+
|
| 40 |
+
anomaly_detector, feature_extractor = load_models(
|
| 41 |
+
FEATURE_EXTRACTOR_PATH, AD_MODEL_PATH, features_method="c3d", device=DEVICE
|
| 42 |
+
)
|
| 43 |
+
feature_extractor.eval()
|
| 44 |
+
anomaly_detector.eval()
|
| 45 |
+
TRANSFORMS = build_transforms(mode="c3d")
|
| 46 |
+
ANOMALY_THRESHOLD = 0.4
|
| 47 |
+
print("[INFO] Models loaded successfully.")
|
| 48 |
+
|
| 49 |
+
VIDEO_PATHS = {
|
| 50 |
+
"Abuse": r"S:\\ano_dec_pro\AnomalyDetectionCVPR2018-Pytorch\web_demo\static\\videos\\Abuse.mp4",
|
| 51 |
+
"Arrest": r"S:\\ano_dec_pro\AnomalyDetectionCVPR2018-Pytorch\web_demo\static\\videos\\Arrest.mp4",
|
| 52 |
+
"Arson": r"S:\\ano_dec_pro\AnomalyDetectionCVPR2018-Pytorch\web_demo\static\\videos\\Arson.mp4",
|
| 53 |
+
"Assault": r"S:\\ano_dec_pro\AnomalyDetectionCVPR2018-Pytorch\web_demo\static\\videos\\Assault.mp4",
|
| 54 |
+
"Burglary": r"S:\\ano_dec_pro\AnomalyDetectionCVPR2018-Pytorch\web_demo\static\\videos\\Burglary.mp4",
|
| 55 |
+
"Explosion": r"S:\\ano_dec_pro\AnomalyDetectionCVPR2018-Pytorch\web_demo\static\\videos\\Explosion.mp4",
|
| 56 |
+
"Fighting": r"S:\\ano_dec_pro\AnomalyDetectionCVPR2018-Pytorch\web_demo\static\\videos\\Fighting.mp4",
|
| 57 |
+
"RoadAccidents": r"S:\\ano_dec_pro\AnomalyDetectionCVPR2018-Pytorch\web_demo\static\\videos\\RoadAccidents.mp4",
|
| 58 |
+
"Robbery": r"S:\\ano_dec_pro\AnomalyDetectionCVPR2018-Pytorch\web_demo\static\\videos\\Robbery.mp4",
|
| 59 |
+
"Shooting": r"S:\\ano_dec_pro\AnomalyDetectionCVPR2018-Pytorch\web_demo\static\\videos\\Shooting.mp4",
|
| 60 |
+
"Shoplifting": r"S:\\ano_dec_pro\AnomalyDetectionCVPR2018-Pytorch\web_demo\static\\videos\\Shoplifting.mp4",
|
| 61 |
+
"Stealing": r"S:\\ano_dec_pro\AnomalyDetectionCVPR2018-Pytorch\web_demo\static\\videos\\Stealing.mp4",
|
| 62 |
+
"Vandalism": r"S:\\ano_dec_pro\AnomalyDetectionCVPR2018-Pytorch\web_demo\static\\videos\\Vandalism.mp4",
|
| 63 |
+
"Normal": r"S:\\ano_dec_pro\AnomalyDetectionCVPR2018-Pytorch\web_demo\static\\videos\\Normal.mp4"
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
# --- Threading control ---
|
| 67 |
+
thread = None
|
| 68 |
+
thread_lock = threading.Lock()
|
| 69 |
+
stop_event = threading.Event()
|
| 70 |
+
|
| 71 |
+
# (The `smooth_score` and `video_processing_task` functions remain unchanged from the previous version)
|
| 72 |
+
def smooth_score(scores, new_score, window=5):
|
| 73 |
+
scores.append(new_score)
|
| 74 |
+
if len(scores) > window:
|
| 75 |
+
scores.pop(0)
|
| 76 |
+
return float(np.mean(scores))
|
| 77 |
+
|
| 78 |
+
def video_processing_task(video_path):
|
| 79 |
+
global thread
|
| 80 |
+
try:
|
| 81 |
+
cap = cv2.VideoCapture(video_path)
|
| 82 |
+
if not cap.isOpened():
|
| 83 |
+
socketio.emit('processing_error', {'error': f'Could not open video file.'})
|
| 84 |
+
return
|
| 85 |
+
frame_buffer = []
|
| 86 |
+
last_save_time = 0
|
| 87 |
+
recent_scores = []
|
| 88 |
+
FRAME_SKIP = 4
|
| 89 |
+
frame_count = 0
|
| 90 |
+
while cap.isOpened() and not stop_event.is_set():
|
| 91 |
+
socketio.sleep(0.001)
|
| 92 |
+
ret, frame = cap.read()
|
| 93 |
+
if not ret: break
|
| 94 |
+
frame_count += 1
|
| 95 |
+
if frame_count % (FRAME_SKIP + 1) != 0: continue
|
| 96 |
+
frame_buffer.append(frame.copy())
|
| 97 |
+
if len(frame_buffer) == 16:
|
| 98 |
+
frames_resized = [cv2.resize(f, (112, 112)) for f in frame_buffer]
|
| 99 |
+
clip_np = np.array(frames_resized, dtype=np.uint8)
|
| 100 |
+
clip_torch = torch.from_numpy(clip_np)
|
| 101 |
+
clip_torch = TRANSFORMS(clip_torch)
|
| 102 |
+
clip_torch = clip_torch.unsqueeze(0).to(DEVICE)
|
| 103 |
+
with torch.no_grad():
|
| 104 |
+
features = feature_extractor(clip_torch).detach()
|
| 105 |
+
score_tensor = anomaly_detector(features).detach()
|
| 106 |
+
score = float(score_tensor.view(-1)[0].item())
|
| 107 |
+
score = smooth_score(recent_scores, score)
|
| 108 |
+
score = float(np.clip(score, 0, 1))
|
| 109 |
+
socketio.emit('update_graph', {'score': score})
|
| 110 |
+
if score > ANOMALY_THRESHOLD and (time.time() - last_save_time) >= 30:
|
| 111 |
+
last_save_time = time.time()
|
| 112 |
+
socketio.emit('update_status', {'status': 'Anomaly detected! Saving clip...'})
|
| 113 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 114 |
+
clip_dir = os.path.join(SAVE_DIR, f"anomaly_{timestamp}")
|
| 115 |
+
os.makedirs(clip_dir, exist_ok=True)
|
| 116 |
+
first_frame_path = os.path.join(clip_dir, "anomaly_frame.jpg")
|
| 117 |
+
cv2.imwrite(first_frame_path, frame_buffer[0])
|
| 118 |
+
try:
|
| 119 |
+
yolo_result = analyze_video_with_yolo(first_frame_path, model_path=YOLO_MODEL_PATH, return_class=True)
|
| 120 |
+
socketio.emit('update_yolo_text', {'text': f"YOLO Class: {yolo_result}"})
|
| 121 |
+
_, buffer = cv2.imencode('.jpg', frame_buffer[0])
|
| 122 |
+
b64_str = base64.b64encode(buffer).decode('utf-8')
|
| 123 |
+
socketio.emit('update_yolo_image', {'image_data': b64_str})
|
| 124 |
+
except Exception as e:
|
| 125 |
+
socketio.emit('update_yolo_text', {'text': f'YOLO Error: {e}'})
|
| 126 |
+
frame_buffer.clear()
|
| 127 |
+
cap.release()
|
| 128 |
+
if not stop_event.is_set():
|
| 129 |
+
socketio.emit('processing_finished', {'message': 'Video finished.'})
|
| 130 |
+
finally:
|
| 131 |
+
with thread_lock:
|
| 132 |
+
thread = None
|
| 133 |
+
stop_event.clear()
|
| 134 |
+
|
| 135 |
+
@app.route('/')
|
| 136 |
+
def index():
|
| 137 |
+
return render_template('index.html', anomaly_names=VIDEO_PATHS.keys())
|
| 138 |
+
|
| 139 |
+
# ADDED: New route for handling video uploads
|
| 140 |
+
@app.route('/upload', methods=['POST'])
|
| 141 |
+
def upload_file():
|
| 142 |
+
if 'video' not in request.files:
|
| 143 |
+
return jsonify({'error': 'No video file found'}), 400
|
| 144 |
+
file = request.files['video']
|
| 145 |
+
if file.filename == '':
|
| 146 |
+
return jsonify({'error': 'No video file selected'}), 400
|
| 147 |
+
if file:
|
| 148 |
+
filename = secure_filename(file.filename)
|
| 149 |
+
# Add a timestamp to make filenames unique
|
| 150 |
+
unique_filename = f"{datetime.now().strftime('%Y%m%d%HM%S')}_{filename}"
|
| 151 |
+
save_path = os.path.join(app.config['UPLOAD_FOLDER'], unique_filename)
|
| 152 |
+
file.save(save_path)
|
| 153 |
+
return jsonify({'success': True, 'filename': unique_filename})
|
| 154 |
+
return jsonify({'error': 'File upload failed'}), 500
|
| 155 |
+
|
| 156 |
+
# MODIFIED: This route now streams both demo and uploaded videos
|
| 157 |
+
@app.route('/video_stream/<source>/<filename>')
|
| 158 |
+
def video_stream(source, filename):
|
| 159 |
+
if source == 'demo':
|
| 160 |
+
path = VIDEO_PATHS.get(filename)
|
| 161 |
+
elif source == 'upload':
|
| 162 |
+
path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
|
| 163 |
+
else:
|
| 164 |
+
return "Invalid source", 404
|
| 165 |
+
|
| 166 |
+
if not path or not os.path.exists(path):
|
| 167 |
+
return "Video not found", 404
|
| 168 |
+
|
| 169 |
+
def generate():
|
| 170 |
+
with open(path, "rb") as f:
|
| 171 |
+
while chunk := f.read(1024 * 1024):
|
| 172 |
+
yield chunk
|
| 173 |
+
|
| 174 |
+
return Response(generate(), mimetype="video/mp4")
|
| 175 |
+
|
| 176 |
+
@socketio.on('start_processing')
|
| 177 |
+
def handle_start_processing(data):
|
| 178 |
+
global thread
|
| 179 |
+
with thread_lock:
|
| 180 |
+
if thread is None:
|
| 181 |
+
stop_event.clear()
|
| 182 |
+
source = data.get('source')
|
| 183 |
+
filename = data.get('filename')
|
| 184 |
+
video_path = None
|
| 185 |
+
|
| 186 |
+
if source == 'demo':
|
| 187 |
+
video_path = VIDEO_PATHS.get(filename)
|
| 188 |
+
elif source == 'upload':
|
| 189 |
+
video_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
|
| 190 |
+
|
| 191 |
+
if video_path and os.path.exists(video_path):
|
| 192 |
+
print(f"[INFO] Starting processing for {filename} from {source}")
|
| 193 |
+
thread = socketio.start_background_task(target=video_processing_task, video_path=video_path)
|
| 194 |
+
else:
|
| 195 |
+
socketio.emit('processing_error', {'error': f'Video file not found!'})
|
| 196 |
+
|
| 197 |
+
@socketio.on('reset_system')
|
| 198 |
+
def handle_reset():
|
| 199 |
+
global thread
|
| 200 |
+
with thread_lock:
|
| 201 |
+
if thread is not None:
|
| 202 |
+
stop_event.set()
|
| 203 |
+
socketio.emit('system_reset_confirm')
|
| 204 |
+
|
| 205 |
+
if __name__ == '__main__':
|
| 206 |
+
print("[INFO] Starting Flask server...")
|
| 207 |
+
socketio.run(app, debug=True)
|
web_demo/models/c3d.pickle
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e082d1890be04df0600aebae68f8687f5f41ba7590d2556edaa9ca49513cadff
|
| 3 |
+
size 319966434
|
web_demo/models/epoch_80000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9cbffe0b8831ed2c5ac82be4b40f10699b1a27fba84226a40161c6a381832510
|
| 3 |
+
size 8460133
|
web_demo/models/yolo_my_model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ef4636cec13eb6e8f4f08aa10430acd25dabe394c0aadf97ad13e8f2c34074b6
|
| 3 |
+
size 19187290
|
web_demo/network/MFNET.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Author: Yunpeng Chen."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from collections import OrderedDict
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class BN_AC_CONV3D(nn.Module):
|
| 11 |
+
def __init__(
|
| 12 |
+
self,
|
| 13 |
+
num_in,
|
| 14 |
+
num_filter,
|
| 15 |
+
kernel=(1, 1, 1),
|
| 16 |
+
pad=(0, 0, 0),
|
| 17 |
+
stride=(1, 1, 1),
|
| 18 |
+
g=1,
|
| 19 |
+
bias=False,
|
| 20 |
+
):
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.bn = nn.BatchNorm3d(num_in)
|
| 23 |
+
self.relu = nn.ReLU(inplace=True)
|
| 24 |
+
self.conv = nn.Conv3d(
|
| 25 |
+
num_in,
|
| 26 |
+
num_filter,
|
| 27 |
+
kernel_size=kernel,
|
| 28 |
+
padding=pad,
|
| 29 |
+
stride=stride,
|
| 30 |
+
groups=g,
|
| 31 |
+
bias=bias,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
def forward(self, x):
|
| 35 |
+
h = self.relu(self.bn(x))
|
| 36 |
+
h = self.conv(h)
|
| 37 |
+
return h
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class MF_UNIT(nn.Module):
|
| 41 |
+
def __init__(
|
| 42 |
+
self,
|
| 43 |
+
num_in,
|
| 44 |
+
num_mid,
|
| 45 |
+
num_out,
|
| 46 |
+
g=1,
|
| 47 |
+
stride=(1, 1, 1),
|
| 48 |
+
first_block=False,
|
| 49 |
+
use_3d=True,
|
| 50 |
+
):
|
| 51 |
+
super().__init__()
|
| 52 |
+
num_ix = int(num_mid / 4)
|
| 53 |
+
kt, pt = (3, 1) if use_3d else (1, 0)
|
| 54 |
+
# prepare input
|
| 55 |
+
self.conv_i1 = BN_AC_CONV3D(
|
| 56 |
+
num_in=num_in, num_filter=num_ix, kernel=(1, 1, 1), pad=(0, 0, 0)
|
| 57 |
+
)
|
| 58 |
+
self.conv_i2 = BN_AC_CONV3D(
|
| 59 |
+
num_in=num_ix, num_filter=num_in, kernel=(1, 1, 1), pad=(0, 0, 0)
|
| 60 |
+
)
|
| 61 |
+
# main part
|
| 62 |
+
self.conv_m1 = BN_AC_CONV3D(
|
| 63 |
+
num_in=num_in,
|
| 64 |
+
num_filter=num_mid,
|
| 65 |
+
kernel=(kt, 3, 3),
|
| 66 |
+
pad=(pt, 1, 1),
|
| 67 |
+
stride=stride,
|
| 68 |
+
g=g,
|
| 69 |
+
)
|
| 70 |
+
if first_block:
|
| 71 |
+
self.conv_m2 = BN_AC_CONV3D(
|
| 72 |
+
num_in=num_mid, num_filter=num_out, kernel=(1, 1, 1), pad=(0, 0, 0)
|
| 73 |
+
)
|
| 74 |
+
else:
|
| 75 |
+
self.conv_m2 = BN_AC_CONV3D(
|
| 76 |
+
num_in=num_mid, num_filter=num_out, kernel=(1, 3, 3), pad=(0, 1, 1), g=g
|
| 77 |
+
)
|
| 78 |
+
# adapter
|
| 79 |
+
if first_block:
|
| 80 |
+
self.conv_w1 = BN_AC_CONV3D(
|
| 81 |
+
num_in=num_in,
|
| 82 |
+
num_filter=num_out,
|
| 83 |
+
kernel=(1, 1, 1),
|
| 84 |
+
pad=(0, 0, 0),
|
| 85 |
+
stride=stride,
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
def forward(self, x):
|
| 89 |
+
h = self.conv_i1(x)
|
| 90 |
+
x_in = x + self.conv_i2(h)
|
| 91 |
+
|
| 92 |
+
h = self.conv_m1(x_in)
|
| 93 |
+
h = self.conv_m2(h)
|
| 94 |
+
|
| 95 |
+
if hasattr(self, "conv_w1"):
|
| 96 |
+
x = self.conv_w1(x)
|
| 97 |
+
|
| 98 |
+
return h + x
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class MFNET_3D(nn.Module):
|
| 102 |
+
"""Original code: https://github.com/cypw/PyTorch-MFNet."""
|
| 103 |
+
|
| 104 |
+
def __init__(
|
| 105 |
+
self,
|
| 106 |
+
**_kwargs,
|
| 107 |
+
):
|
| 108 |
+
super().__init__()
|
| 109 |
+
|
| 110 |
+
groups = 16
|
| 111 |
+
k_sec = {2: 3, 3: 4, 4: 6, 5: 3}
|
| 112 |
+
|
| 113 |
+
# conv1 - x224 (x16)
|
| 114 |
+
conv1_num_out = 16
|
| 115 |
+
self.conv1 = nn.Sequential(
|
| 116 |
+
OrderedDict(
|
| 117 |
+
[
|
| 118 |
+
(
|
| 119 |
+
"conv",
|
| 120 |
+
nn.Conv3d(
|
| 121 |
+
3,
|
| 122 |
+
conv1_num_out,
|
| 123 |
+
kernel_size=(3, 5, 5),
|
| 124 |
+
padding=(1, 2, 2),
|
| 125 |
+
stride=(1, 2, 2),
|
| 126 |
+
bias=False,
|
| 127 |
+
),
|
| 128 |
+
),
|
| 129 |
+
("bn", nn.BatchNorm3d(conv1_num_out)),
|
| 130 |
+
("relu", nn.ReLU(inplace=True)),
|
| 131 |
+
]
|
| 132 |
+
)
|
| 133 |
+
)
|
| 134 |
+
self.maxpool = nn.MaxPool3d(
|
| 135 |
+
kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
# conv2 - x56 (x8)
|
| 139 |
+
num_mid = 96
|
| 140 |
+
conv2_num_out = 96
|
| 141 |
+
self.conv2 = nn.Sequential(
|
| 142 |
+
OrderedDict(
|
| 143 |
+
[
|
| 144 |
+
(
|
| 145 |
+
"B%02d" % i,
|
| 146 |
+
MF_UNIT(
|
| 147 |
+
num_in=conv1_num_out if i == 1 else conv2_num_out,
|
| 148 |
+
num_mid=num_mid,
|
| 149 |
+
num_out=conv2_num_out,
|
| 150 |
+
stride=(2, 1, 1) if i == 1 else (1, 1, 1),
|
| 151 |
+
g=groups,
|
| 152 |
+
first_block=(i == 1),
|
| 153 |
+
),
|
| 154 |
+
)
|
| 155 |
+
for i in range(1, k_sec[2] + 1)
|
| 156 |
+
]
|
| 157 |
+
)
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
# conv3 - x28 (x8)
|
| 161 |
+
num_mid *= 2
|
| 162 |
+
conv3_num_out = 2 * conv2_num_out
|
| 163 |
+
self.conv3 = nn.Sequential(
|
| 164 |
+
OrderedDict(
|
| 165 |
+
[
|
| 166 |
+
(
|
| 167 |
+
"B%02d" % i,
|
| 168 |
+
MF_UNIT(
|
| 169 |
+
num_in=conv2_num_out if i == 1 else conv3_num_out,
|
| 170 |
+
num_mid=num_mid,
|
| 171 |
+
num_out=conv3_num_out,
|
| 172 |
+
stride=(1, 2, 2) if i == 1 else (1, 1, 1),
|
| 173 |
+
g=groups,
|
| 174 |
+
first_block=(i == 1),
|
| 175 |
+
),
|
| 176 |
+
)
|
| 177 |
+
for i in range(1, k_sec[3] + 1)
|
| 178 |
+
]
|
| 179 |
+
)
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
# conv4 - x14 (x8)
|
| 183 |
+
num_mid *= 2
|
| 184 |
+
conv4_num_out = 2 * conv3_num_out
|
| 185 |
+
self.conv4 = nn.Sequential(
|
| 186 |
+
OrderedDict(
|
| 187 |
+
[
|
| 188 |
+
(
|
| 189 |
+
"B%02d" % i,
|
| 190 |
+
MF_UNIT(
|
| 191 |
+
num_in=conv3_num_out if i == 1 else conv4_num_out,
|
| 192 |
+
num_mid=num_mid,
|
| 193 |
+
num_out=conv4_num_out,
|
| 194 |
+
stride=(1, 2, 2) if i == 1 else (1, 1, 1),
|
| 195 |
+
g=groups,
|
| 196 |
+
first_block=(i == 1),
|
| 197 |
+
),
|
| 198 |
+
)
|
| 199 |
+
for i in range(1, k_sec[4] + 1)
|
| 200 |
+
]
|
| 201 |
+
)
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
# conv5 - x7 (x8)
|
| 205 |
+
num_mid *= 2
|
| 206 |
+
conv5_num_out = 2 * conv4_num_out
|
| 207 |
+
self.conv5 = nn.Sequential(
|
| 208 |
+
OrderedDict(
|
| 209 |
+
[
|
| 210 |
+
(
|
| 211 |
+
"B%02d" % i,
|
| 212 |
+
MF_UNIT(
|
| 213 |
+
num_in=conv4_num_out if i == 1 else conv5_num_out,
|
| 214 |
+
num_mid=num_mid,
|
| 215 |
+
num_out=conv5_num_out,
|
| 216 |
+
stride=(1, 2, 2) if i == 1 else (1, 1, 1),
|
| 217 |
+
g=groups,
|
| 218 |
+
first_block=(i == 1),
|
| 219 |
+
),
|
| 220 |
+
)
|
| 221 |
+
for i in range(1, k_sec[5] + 1)
|
| 222 |
+
]
|
| 223 |
+
)
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
# final
|
| 227 |
+
self.tail = nn.Sequential(
|
| 228 |
+
OrderedDict(
|
| 229 |
+
[("bn", nn.BatchNorm3d(conv5_num_out)), ("relu", nn.ReLU(inplace=True))]
|
| 230 |
+
)
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
self.globalpool = nn.Sequential(
|
| 234 |
+
OrderedDict(
|
| 235 |
+
[
|
| 236 |
+
("avg", nn.AvgPool3d(kernel_size=(1, 7, 7), stride=(1, 1, 1))),
|
| 237 |
+
("dropout", nn.Dropout(p=0.5)), # only for fine-tuning
|
| 238 |
+
]
|
| 239 |
+
)
|
| 240 |
+
)
|
| 241 |
+
# self.classifier = nn.Linear(conv5_num_out, num_classes)
|
| 242 |
+
|
| 243 |
+
def forward(self, x):
|
| 244 |
+
# assert x.shape[2] == 16
|
| 245 |
+
|
| 246 |
+
h = self.conv1(x) # x224 -> x112
|
| 247 |
+
h = self.maxpool(h) # x112 -> x56
|
| 248 |
+
|
| 249 |
+
h = self.conv2(h) # x56 -> x56
|
| 250 |
+
h = self.conv3(h) # x56 -> x28
|
| 251 |
+
h = self.conv4(h) # x28 -> x14
|
| 252 |
+
h = self.conv5(h) # x14 -> x7
|
| 253 |
+
|
| 254 |
+
h = self.tail(h)
|
| 255 |
+
h = self.globalpool(h)
|
| 256 |
+
|
| 257 |
+
h = h.view(h.shape[0], -1)
|
| 258 |
+
# h = self.classifier(h)
|
| 259 |
+
# h = h.view(h.shape[0], -1)
|
| 260 |
+
return h
|
| 261 |
+
|
| 262 |
+
def load_state(self, state_dict):
|
| 263 |
+
# customized partialy load function
|
| 264 |
+
checkpoint = torch.load(state_dict, map_location=torch.device("cpu"))
|
| 265 |
+
state_dict = checkpoint["state_dict"]
|
| 266 |
+
net_state_keys = list(self.state_dict().keys())
|
| 267 |
+
for name, param in state_dict.items():
|
| 268 |
+
name = name.replace("module.", "")
|
| 269 |
+
if name in self.state_dict().keys():
|
| 270 |
+
dst_param_shape = self.state_dict()[name].shape
|
| 271 |
+
if param.shape == dst_param_shape:
|
| 272 |
+
self.state_dict()[name].copy_(param.view(dst_param_shape))
|
| 273 |
+
net_state_keys.remove(name)
|
| 274 |
+
# indicating missed keys
|
| 275 |
+
if net_state_keys:
|
| 276 |
+
logging.warning(f">> Failed to load: {net_state_keys}")
|
| 277 |
+
|
| 278 |
+
return self
|
web_demo/network/TorchUtils.py
ADDED
|
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Written by Eitan Kosman."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import time
|
| 6 |
+
from typing import List, Optional, Union
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch import Tensor, nn
|
| 10 |
+
from torch.optim import Optimizer
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
|
| 13 |
+
from utils.callbacks import Callback
|
| 14 |
+
from utils.types import Device
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
+
from network.anomaly_detector_model import AnomalyDetector
|
| 18 |
+
|
| 19 |
+
# Use safe_globals context
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_torch_device() -> Device:
|
| 24 |
+
"""
|
| 25 |
+
Retrieves the device to run torch models, with preferability to GPU (denoted as cuda by torch)
|
| 26 |
+
Returns: Device to run the models
|
| 27 |
+
"""
|
| 28 |
+
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def load_model(model_path: str) -> nn.Module:
|
| 32 |
+
"""Loads a Pytorch model (CPU compatible, PyTorch >=2.6)."""
|
| 33 |
+
logging.info(f"Load the model from: {model_path}")
|
| 34 |
+
|
| 35 |
+
from network.anomaly_detector_model import AnomalyDetector
|
| 36 |
+
|
| 37 |
+
# Wrap torch.load with safe_globals and weights_only=False
|
| 38 |
+
with torch.serialization.safe_globals([AnomalyDetector]):
|
| 39 |
+
model = torch.load(model_path, map_location="cpu", weights_only=False)
|
| 40 |
+
|
| 41 |
+
logging.info(model)
|
| 42 |
+
return model
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class TorchModel(nn.Module):
|
| 47 |
+
"""Wrapper class for a torch model to make it comfortable to train and load
|
| 48 |
+
models."""
|
| 49 |
+
|
| 50 |
+
def __init__(self, model: nn.Module) -> None:
|
| 51 |
+
super().__init__()
|
| 52 |
+
self.device = get_torch_device()
|
| 53 |
+
self.iteration = 0
|
| 54 |
+
self.model = model
|
| 55 |
+
self.is_data_parallel = False
|
| 56 |
+
self.callbacks = []
|
| 57 |
+
|
| 58 |
+
def register_callback(self, callback_fn: Callback) -> None:
|
| 59 |
+
"""
|
| 60 |
+
Register a callback to be called after each evaluation run
|
| 61 |
+
Args:
|
| 62 |
+
callback_fn: a callable that accepts 2 inputs (output, target)
|
| 63 |
+
- output is the model's output
|
| 64 |
+
- target is the values of the target variable
|
| 65 |
+
"""
|
| 66 |
+
self.callbacks.append(callback_fn)
|
| 67 |
+
|
| 68 |
+
def data_parallel(self):
|
| 69 |
+
"""Transfers the model to data parallel mode."""
|
| 70 |
+
self.is_data_parallel = True
|
| 71 |
+
if not isinstance(self.model, torch.nn.DataParallel):
|
| 72 |
+
self.model = torch.nn.DataParallel(self.model, device_ids=[0, 1])
|
| 73 |
+
|
| 74 |
+
return self
|
| 75 |
+
|
| 76 |
+
@classmethod
|
| 77 |
+
def load_model(cls, model_path: str):
|
| 78 |
+
"""
|
| 79 |
+
Loads a pickled model
|
| 80 |
+
Args:
|
| 81 |
+
model_path: path to the pickled model
|
| 82 |
+
|
| 83 |
+
Returns: TorchModel class instance wrapping the provided model
|
| 84 |
+
"""
|
| 85 |
+
return cls(load_model(model_path))
|
| 86 |
+
|
| 87 |
+
def notify_callbacks(self, notification, *args, **kwargs) -> None:
|
| 88 |
+
"""Calls all callbacks registered with this class.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
notification: The type of notification to be called.
|
| 92 |
+
"""
|
| 93 |
+
for callback in self.callbacks:
|
| 94 |
+
try:
|
| 95 |
+
method = getattr(callback, notification)
|
| 96 |
+
method(*args, **kwargs)
|
| 97 |
+
except (AttributeError, TypeError) as e:
|
| 98 |
+
logging.error(
|
| 99 |
+
f"callback {callback.__class__.__name__} doesn't fully implement the required interface {e}" # pylint: disable=line-too-long
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
def fit(
|
| 103 |
+
self,
|
| 104 |
+
train_iter: DataLoader,
|
| 105 |
+
criterion: nn.Module,
|
| 106 |
+
optimizer: Optimizer,
|
| 107 |
+
eval_iter: Optional[DataLoader] = None,
|
| 108 |
+
epochs: int = 10,
|
| 109 |
+
network_model_path_base: Optional[str] = None,
|
| 110 |
+
save_every: Optional[int] = None,
|
| 111 |
+
evaluate_every: Optional[int] = None,
|
| 112 |
+
) -> None:
|
| 113 |
+
"""
|
| 114 |
+
|
| 115 |
+
Args:
|
| 116 |
+
train_iter: iterator for training
|
| 117 |
+
criterion: loss function
|
| 118 |
+
optimizer: optimizer for the algorithm
|
| 119 |
+
eval_iter: iterator for evaluation
|
| 120 |
+
epochs: amount of epochs
|
| 121 |
+
network_model_path_base: where to save the models
|
| 122 |
+
save_every: saving model checkpoints every specified amount of epochs
|
| 123 |
+
evaluate_every: perform evaluation every specified amount of epochs.
|
| 124 |
+
If the evaluation is expensive, you probably want to
|
| 125 |
+
choose a high value for this
|
| 126 |
+
"""
|
| 127 |
+
criterion = criterion.to(self.device)
|
| 128 |
+
self.notify_callbacks("on_training_start", epochs)
|
| 129 |
+
|
| 130 |
+
for epoch in range(epochs):
|
| 131 |
+
train_loss = self.do_epoch(
|
| 132 |
+
criterion=criterion,
|
| 133 |
+
optimizer=optimizer,
|
| 134 |
+
data_iter=train_iter,
|
| 135 |
+
epoch=epoch,
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
if save_every and network_model_path_base and epoch % save_every == 0:
|
| 139 |
+
logging.info(f"Save the model after epoch {epoch}")
|
| 140 |
+
self.save(os.path.join(network_model_path_base, f"epoch_{epoch}.pt"))
|
| 141 |
+
|
| 142 |
+
val_loss = None
|
| 143 |
+
if eval_iter and evaluate_every and epoch % evaluate_every == 0:
|
| 144 |
+
logging.info(f"Evaluating after epoch {epoch}")
|
| 145 |
+
val_loss = self.evaluate(
|
| 146 |
+
criterion=criterion,
|
| 147 |
+
data_iter=eval_iter,
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
self.notify_callbacks("on_training_iteration_end", train_loss, val_loss)
|
| 151 |
+
|
| 152 |
+
self.notify_callbacks("on_training_end", self.model)
|
| 153 |
+
# Save the last model anyway...
|
| 154 |
+
if network_model_path_base:
|
| 155 |
+
self.save(os.path.join(network_model_path_base, f"epoch_{epoch + 1}.pt"))
|
| 156 |
+
|
| 157 |
+
def evaluate(self, criterion: nn.Module, data_iter: DataLoader) -> float:
|
| 158 |
+
"""
|
| 159 |
+
Evaluates the model
|
| 160 |
+
Args:
|
| 161 |
+
criterion: Loss function for calculating the evaluation
|
| 162 |
+
data_iter: torch data iterator
|
| 163 |
+
"""
|
| 164 |
+
self.eval()
|
| 165 |
+
self.notify_callbacks("on_evaluation_start", len(data_iter))
|
| 166 |
+
total_loss = 0
|
| 167 |
+
|
| 168 |
+
with torch.no_grad():
|
| 169 |
+
for iteration, (batch, targets) in enumerate(data_iter):
|
| 170 |
+
batch = self.data_to_device(batch, self.device)
|
| 171 |
+
targets = self.data_to_device(targets, self.device)
|
| 172 |
+
|
| 173 |
+
outputs = self.model(batch)
|
| 174 |
+
loss = criterion(outputs, targets)
|
| 175 |
+
|
| 176 |
+
self.notify_callbacks(
|
| 177 |
+
"on_evaluation_step",
|
| 178 |
+
iteration,
|
| 179 |
+
outputs.detach().cpu(),
|
| 180 |
+
targets.detach().cpu(),
|
| 181 |
+
loss.item(),
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
total_loss += loss.item()
|
| 185 |
+
|
| 186 |
+
loss = total_loss / len(data_iter)
|
| 187 |
+
self.notify_callbacks("on_evaluation_end")
|
| 188 |
+
return loss
|
| 189 |
+
|
| 190 |
+
def do_epoch(
|
| 191 |
+
self,
|
| 192 |
+
criterion: nn.Module,
|
| 193 |
+
optimizer: Optimizer,
|
| 194 |
+
data_iter: DataLoader,
|
| 195 |
+
epoch: int,
|
| 196 |
+
) -> float:
|
| 197 |
+
"""Perform a whole epoch.
|
| 198 |
+
|
| 199 |
+
Args:
|
| 200 |
+
criterion (nn.Module): Loss function to be used.
|
| 201 |
+
optimizer (Optimizer): Optimizer to use for minimizing the loss function.
|
| 202 |
+
data_iter (DataLoader): Loader for data samples used for training the model.
|
| 203 |
+
epoch (int): The epoch number.
|
| 204 |
+
|
| 205 |
+
Returns:
|
| 206 |
+
float: Average training loss calculated during the epoch.
|
| 207 |
+
"""
|
| 208 |
+
total_loss = 0
|
| 209 |
+
total_time = 0.0
|
| 210 |
+
self.train()
|
| 211 |
+
self.notify_callbacks("on_epoch_start", epoch, len(data_iter))
|
| 212 |
+
for iteration, (batch, targets) in enumerate(data_iter):
|
| 213 |
+
self.iteration += 1
|
| 214 |
+
start_time = time.time()
|
| 215 |
+
batch = self.data_to_device(batch, self.device)
|
| 216 |
+
targets = self.data_to_device(targets, self.device)
|
| 217 |
+
|
| 218 |
+
outputs = self.model(batch)
|
| 219 |
+
|
| 220 |
+
loss = criterion(outputs, targets)
|
| 221 |
+
|
| 222 |
+
# Backward and optimize
|
| 223 |
+
optimizer.zero_grad()
|
| 224 |
+
loss.backward()
|
| 225 |
+
optimizer.step()
|
| 226 |
+
|
| 227 |
+
total_loss += loss.item()
|
| 228 |
+
|
| 229 |
+
end_time = time.time()
|
| 230 |
+
|
| 231 |
+
total_time += end_time - start_time
|
| 232 |
+
|
| 233 |
+
self.notify_callbacks(
|
| 234 |
+
"on_epoch_step",
|
| 235 |
+
self.iteration,
|
| 236 |
+
iteration,
|
| 237 |
+
loss.item(),
|
| 238 |
+
)
|
| 239 |
+
self.iteration += 1
|
| 240 |
+
|
| 241 |
+
loss = total_loss / len(data_iter)
|
| 242 |
+
|
| 243 |
+
self.notify_callbacks("on_epoch_end", loss)
|
| 244 |
+
return loss
|
| 245 |
+
|
| 246 |
+
def data_to_device(
|
| 247 |
+
self, data: Union[Tensor, List[Tensor]], device: Device
|
| 248 |
+
) -> Union[Tensor, List[Tensor]]:
|
| 249 |
+
"""
|
| 250 |
+
Transfers a tensor data to a device
|
| 251 |
+
Args:
|
| 252 |
+
data: torch tensor
|
| 253 |
+
device: target device
|
| 254 |
+
"""
|
| 255 |
+
if isinstance(data, list):
|
| 256 |
+
data = [d.to(device) for d in data]
|
| 257 |
+
elif isinstance(data, tuple):
|
| 258 |
+
data = tuple([d.to(device) for d in data])
|
| 259 |
+
else:
|
| 260 |
+
data = data.to(device)
|
| 261 |
+
|
| 262 |
+
return data
|
| 263 |
+
|
| 264 |
+
def save(self, model_path: str) -> None:
|
| 265 |
+
"""Saves the model to the given path.
|
| 266 |
+
|
| 267 |
+
If currently using data parallel, the method
|
| 268 |
+
will save the original model and not the data parallel instance of it
|
| 269 |
+
Args:
|
| 270 |
+
model_path: target path to save the model to
|
| 271 |
+
"""
|
| 272 |
+
if self.is_data_parallel:
|
| 273 |
+
torch.save(self.model.module, model_path)
|
| 274 |
+
else:
|
| 275 |
+
torch.save(self.model, model_path)
|
| 276 |
+
|
| 277 |
+
def get_model(self) -> nn.Module:
|
| 278 |
+
if self.is_data_parallel:
|
| 279 |
+
return self.model.module
|
| 280 |
+
|
| 281 |
+
return self.model
|
| 282 |
+
|
| 283 |
+
def forward(self, *args, **kwargs):
|
| 284 |
+
return self.model(*args, **kwargs)
|
web_demo/network/__init__.py
ADDED
|
File without changes
|
web_demo/network/__pycache__/MFNET.cpython-311.pyc
ADDED
|
Binary file (10.3 kB). View file
|
|
|
web_demo/network/__pycache__/TorchUtils.cpython-311.pyc
ADDED
|
Binary file (14.3 kB). View file
|
|
|
web_demo/network/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (176 Bytes). View file
|
|
|
web_demo/network/__pycache__/anomaly_detector_model.cpython-311.pyc
ADDED
|
Binary file (9.39 kB). View file
|
|
|
web_demo/network/__pycache__/c3d.cpython-311.pyc
ADDED
|
Binary file (6.81 kB). View file
|
|
|
web_demo/network/__pycache__/resnet.cpython-311.pyc
ADDED
|
Binary file (11.9 kB). View file
|
|
|
web_demo/network/anomaly_detector_model.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This module contains an implementation of anomaly detector for videos."""
|
| 2 |
+
|
| 3 |
+
from typing import Callable
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import Tensor, nn
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class AnomalyDetector(nn.Module):
|
| 10 |
+
"""Anomaly detection model for videos."""
|
| 11 |
+
|
| 12 |
+
def __init__(self, input_dim=4096) -> None:
|
| 13 |
+
super().__init__()
|
| 14 |
+
self.fc1 = nn.Linear(input_dim, 512)
|
| 15 |
+
self.relu1 = nn.ReLU()
|
| 16 |
+
self.dropout1 = nn.Dropout(0.6)
|
| 17 |
+
|
| 18 |
+
self.fc2 = nn.Linear(512, 32)
|
| 19 |
+
self.dropout2 = nn.Dropout(0.6)
|
| 20 |
+
|
| 21 |
+
self.fc3 = nn.Linear(32, 1)
|
| 22 |
+
self.sig = nn.Sigmoid()
|
| 23 |
+
|
| 24 |
+
# In the original keras code they use "glorot_normal"
|
| 25 |
+
# As I understand, this is the same as xavier normal in Pytorch
|
| 26 |
+
nn.init.xavier_normal_(self.fc1.weight)
|
| 27 |
+
nn.init.xavier_normal_(self.fc2.weight)
|
| 28 |
+
nn.init.xavier_normal_(self.fc3.weight)
|
| 29 |
+
|
| 30 |
+
@property
|
| 31 |
+
def input_dim(self) -> int:
|
| 32 |
+
return self.fc1.weight.shape[1]
|
| 33 |
+
|
| 34 |
+
def forward(self, x: Tensor) -> Tensor: # pylint: disable=arguments-differ
|
| 35 |
+
x = self.dropout1(self.relu1(self.fc1(x)))
|
| 36 |
+
x = self.dropout2(self.fc2(x))
|
| 37 |
+
x = self.sig(self.fc3(x))
|
| 38 |
+
return x
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def custom_objective(y_pred: Tensor, y_true: Tensor) -> Tensor:
|
| 42 |
+
"""Calculate loss function with regularization for anomaly detection.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
y_pred (Tensor): A tensor containing the predictions of the model.
|
| 46 |
+
y_true (Tensor): A tensor containing the ground truth.
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
Tensor: A single dimension tensor containing the calculated loss.
|
| 50 |
+
"""
|
| 51 |
+
# y_pred (batch_size, 32, 1)
|
| 52 |
+
# y_true (batch_size)
|
| 53 |
+
lambdas = 8e-5
|
| 54 |
+
|
| 55 |
+
normal_vids_indices = torch.where(y_true == 0)
|
| 56 |
+
anomal_vids_indices = torch.where(y_true == 1)
|
| 57 |
+
|
| 58 |
+
normal_segments_scores = y_pred[normal_vids_indices].squeeze(-1) # (batch/2, 32, 1)
|
| 59 |
+
anomal_segments_scores = y_pred[anomal_vids_indices].squeeze(-1) # (batch/2, 32, 1)
|
| 60 |
+
|
| 61 |
+
# get the max score for each video
|
| 62 |
+
normal_segments_scores_maxes = normal_segments_scores.max(dim=-1)[0]
|
| 63 |
+
anomal_segments_scores_maxes = anomal_segments_scores.max(dim=-1)[0]
|
| 64 |
+
|
| 65 |
+
hinge_loss = 1 - anomal_segments_scores_maxes + normal_segments_scores_maxes
|
| 66 |
+
hinge_loss = torch.max(hinge_loss, torch.zeros_like(hinge_loss))
|
| 67 |
+
|
| 68 |
+
# Smoothness of anomalous video
|
| 69 |
+
smoothed_scores = anomal_segments_scores[:, 1:] - anomal_segments_scores[:, :-1]
|
| 70 |
+
smoothed_scores_sum_squared = smoothed_scores.pow(2).sum(dim=-1)
|
| 71 |
+
|
| 72 |
+
# Sparsity of anomalous video
|
| 73 |
+
sparsity_loss = anomal_segments_scores.sum(dim=-1)
|
| 74 |
+
|
| 75 |
+
final_loss = (
|
| 76 |
+
hinge_loss + lambdas * smoothed_scores_sum_squared + lambdas * sparsity_loss
|
| 77 |
+
).mean()
|
| 78 |
+
return final_loss
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class RegularizedLoss(torch.nn.Module):
|
| 82 |
+
"""Regularizes a loss function."""
|
| 83 |
+
|
| 84 |
+
def __init__(
|
| 85 |
+
self,
|
| 86 |
+
model: AnomalyDetector,
|
| 87 |
+
original_objective: Callable,
|
| 88 |
+
lambdas: float = 0.001,
|
| 89 |
+
) -> None:
|
| 90 |
+
super().__init__()
|
| 91 |
+
self.lambdas = lambdas
|
| 92 |
+
self.model = model
|
| 93 |
+
self.objective = original_objective
|
| 94 |
+
|
| 95 |
+
def forward(self, y_pred: Tensor, y_true: Tensor): # pylint: disable=arguments-differ
|
| 96 |
+
# loss
|
| 97 |
+
# Our loss is defined with respect to l2 regularization, as used in the original keras code
|
| 98 |
+
fc1_params = torch.cat(tuple([x.view(-1) for x in self.model.fc1.parameters()]))
|
| 99 |
+
fc2_params = torch.cat(tuple([x.view(-1) for x in self.model.fc2.parameters()]))
|
| 100 |
+
fc3_params = torch.cat(tuple([x.view(-1) for x in self.model.fc3.parameters()]))
|
| 101 |
+
|
| 102 |
+
l1_regularization = self.lambdas * torch.norm(fc1_params, p=2)
|
| 103 |
+
l2_regularization = self.lambdas * torch.norm(fc2_params, p=2)
|
| 104 |
+
l3_regularization = self.lambdas * torch.norm(fc3_params, p=2)
|
| 105 |
+
|
| 106 |
+
return (
|
| 107 |
+
self.objective(y_pred, y_true)
|
| 108 |
+
+ l1_regularization
|
| 109 |
+
+ l2_regularization
|
| 110 |
+
+ l3_regularization
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# ----------------------------------------------------------------------------------------------------------------------
|
| 117 |
+
class AnomalyClassifier(nn.Module):
|
| 118 |
+
"""
|
| 119 |
+
Multi-class anomaly classifier
|
| 120 |
+
Supports 13 categories: Normal + 12 anomaly classes
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
def __init__(self, input_dim=512, num_classes=13):
|
| 124 |
+
super(AnomalyClassifier, self).__init__()
|
| 125 |
+
self.fc1 = nn.Linear(input_dim, 256)
|
| 126 |
+
self.relu1 = nn.ReLU()
|
| 127 |
+
self.dropout1 = nn.Dropout(0.5)
|
| 128 |
+
|
| 129 |
+
self.fc2 = nn.Linear(256, 64)
|
| 130 |
+
self.relu2 = nn.ReLU()
|
| 131 |
+
self.dropout2 = nn.Dropout(0.5)
|
| 132 |
+
|
| 133 |
+
self.fc3 = nn.Linear(64, num_classes) # ✅ 13 outputs
|
| 134 |
+
|
| 135 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 136 |
+
"""
|
| 137 |
+
x: (B, input_dim) feature vectors
|
| 138 |
+
returns: (B, num_classes) logits
|
| 139 |
+
"""
|
| 140 |
+
x = self.dropout1(self.relu1(self.fc1(x)))
|
| 141 |
+
x = self.dropout2(self.relu2(self.fc2(x)))
|
| 142 |
+
return self.fc3(x)
|
web_demo/network/c3d.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" "This module contains an implementation of C3D model for video
|
| 2 |
+
processing."""
|
| 3 |
+
|
| 4 |
+
import itertools
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch import Tensor, nn
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class C3D(nn.Module):
|
| 11 |
+
"""The C3D network."""
|
| 12 |
+
|
| 13 |
+
def __init__(self, pretrained=None):
|
| 14 |
+
super().__init__()
|
| 15 |
+
|
| 16 |
+
self.pretrained = pretrained
|
| 17 |
+
|
| 18 |
+
self.conv1 = nn.Conv3d(3, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1))
|
| 19 |
+
self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))
|
| 20 |
+
|
| 21 |
+
self.conv2 = nn.Conv3d(64, 128, kernel_size=(3, 3, 3), padding=(1, 1, 1))
|
| 22 |
+
self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
|
| 23 |
+
|
| 24 |
+
self.conv3a = nn.Conv3d(128, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1))
|
| 25 |
+
self.conv3b = nn.Conv3d(256, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1))
|
| 26 |
+
self.pool3 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
|
| 27 |
+
|
| 28 |
+
self.conv4a = nn.Conv3d(256, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
|
| 29 |
+
self.conv4b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
|
| 30 |
+
self.pool4 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
|
| 31 |
+
|
| 32 |
+
self.conv5a = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
|
| 33 |
+
self.conv5b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
|
| 34 |
+
self.pool5 = nn.MaxPool3d(
|
| 35 |
+
kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 1, 1)
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
self.fc6 = nn.Linear(8192, 4096)
|
| 39 |
+
self.relu = nn.ReLU()
|
| 40 |
+
self.__init_weight()
|
| 41 |
+
|
| 42 |
+
if pretrained:
|
| 43 |
+
self.__load_pretrained_weights()
|
| 44 |
+
|
| 45 |
+
def forward(self, x: Tensor):
|
| 46 |
+
x = self.relu(self.conv1(x))
|
| 47 |
+
x = self.pool1(x)
|
| 48 |
+
x = self.relu(self.conv2(x))
|
| 49 |
+
x = self.pool2(x)
|
| 50 |
+
x = self.relu(self.conv3a(x))
|
| 51 |
+
x = self.relu(self.conv3b(x))
|
| 52 |
+
x = self.pool3(x)
|
| 53 |
+
x = self.relu(self.conv4a(x))
|
| 54 |
+
x = self.relu(self.conv4b(x))
|
| 55 |
+
x = self.pool4(x)
|
| 56 |
+
x = self.relu(self.conv5a(x))
|
| 57 |
+
x = self.relu(self.conv5b(x))
|
| 58 |
+
x = self.pool5(x)
|
| 59 |
+
# x = x.view(-1, 8192)
|
| 60 |
+
x = x.view(x.size(0), -1) # changed
|
| 61 |
+
x = self.relu(self.fc6(x))
|
| 62 |
+
|
| 63 |
+
return x
|
| 64 |
+
|
| 65 |
+
def __load_pretrained_weights(self):
|
| 66 |
+
"""Initialiaze network."""
|
| 67 |
+
corresp_name = [
|
| 68 |
+
# Conv1
|
| 69 |
+
"conv1.weight",
|
| 70 |
+
"conv1.bias",
|
| 71 |
+
# Conv2
|
| 72 |
+
"conv2.weight",
|
| 73 |
+
"conv2.bias",
|
| 74 |
+
# Conv3a
|
| 75 |
+
"conv3a.weight",
|
| 76 |
+
"conv3a.bias",
|
| 77 |
+
# Conv3b
|
| 78 |
+
"conv3b.weight",
|
| 79 |
+
"conv3b.bias",
|
| 80 |
+
# Conv4a
|
| 81 |
+
"conv4a.weight",
|
| 82 |
+
"conv4a.bias",
|
| 83 |
+
# Conv4b
|
| 84 |
+
"conv4b.weight",
|
| 85 |
+
"conv4b.bias",
|
| 86 |
+
# Conv5a
|
| 87 |
+
"conv5a.weight",
|
| 88 |
+
"conv5a.bias",
|
| 89 |
+
# Conv5b
|
| 90 |
+
"conv5b.weight",
|
| 91 |
+
"conv5b.bias",
|
| 92 |
+
# fc6
|
| 93 |
+
"fc6.weight",
|
| 94 |
+
"fc6.bias",
|
| 95 |
+
]
|
| 96 |
+
|
| 97 |
+
ignored_weights = [
|
| 98 |
+
f"{layer}.{type_}"
|
| 99 |
+
for layer, type_ in itertools.product(["fc7", "fc8"], ["bias", "weight"])
|
| 100 |
+
]
|
| 101 |
+
|
| 102 |
+
p_dict = torch.load(self.pretrained)
|
| 103 |
+
s_dict = self.state_dict()
|
| 104 |
+
for name in p_dict:
|
| 105 |
+
if name not in corresp_name:
|
| 106 |
+
if name in ignored_weights:
|
| 107 |
+
continue
|
| 108 |
+
print("no corresponding::", name)
|
| 109 |
+
continue
|
| 110 |
+
s_dict[name] = p_dict[name]
|
| 111 |
+
self.load_state_dict(s_dict)
|
| 112 |
+
|
| 113 |
+
def __init_weight(self):
|
| 114 |
+
"""Initialize weights of the model."""
|
| 115 |
+
for m in self.modules():
|
| 116 |
+
if isinstance(m, nn.Conv3d):
|
| 117 |
+
torch.nn.init.kaiming_normal_(m.weight)
|
| 118 |
+
elif isinstance(m, nn.BatchNorm3d):
|
| 119 |
+
m.weight.data.fill_(1)
|
| 120 |
+
m.bias.data.zero_()
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
if __name__ == "__main__":
|
| 124 |
+
inputs = torch.ones((1, 3, 16, 112, 112))
|
| 125 |
+
net = C3D(pretrained=False)
|
| 126 |
+
|
| 127 |
+
outputs = net.forward(inputs)
|
| 128 |
+
print(outputs.size())
|
| 129 |
+
|
web_demo/network/resnet.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" "This module contains an implementation of ResNet model for video
|
| 2 |
+
processing."""
|
| 3 |
+
|
| 4 |
+
from functools import partial
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch import nn
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_inplanes():
|
| 12 |
+
return [64, 128, 256, 512]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def conv3x3x3(in_planes, out_planes, stride=1):
|
| 16 |
+
return nn.Conv3d(
|
| 17 |
+
in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def conv1x1x1(in_planes, out_planes, stride=1):
|
| 22 |
+
return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class BasicBlock(nn.Module):
|
| 26 |
+
expansion = 1
|
| 27 |
+
|
| 28 |
+
def __init__(self, in_planes, planes, stride=1, downsample=None):
|
| 29 |
+
super().__init__()
|
| 30 |
+
|
| 31 |
+
self.conv1 = conv3x3x3(in_planes, planes, stride)
|
| 32 |
+
self.bn1 = nn.BatchNorm3d(planes)
|
| 33 |
+
self.relu = nn.ReLU(inplace=True)
|
| 34 |
+
self.conv2 = conv3x3x3(planes, planes)
|
| 35 |
+
self.bn2 = nn.BatchNorm3d(planes)
|
| 36 |
+
self.downsample = downsample
|
| 37 |
+
self.stride = stride
|
| 38 |
+
|
| 39 |
+
def forward(self, x):
|
| 40 |
+
residual = x
|
| 41 |
+
|
| 42 |
+
out = self.conv1(x)
|
| 43 |
+
out = self.bn1(out)
|
| 44 |
+
out = self.relu(out)
|
| 45 |
+
|
| 46 |
+
out = self.conv2(out)
|
| 47 |
+
out = self.bn2(out)
|
| 48 |
+
|
| 49 |
+
if self.downsample is not None:
|
| 50 |
+
residual = self.downsample(x)
|
| 51 |
+
|
| 52 |
+
out += residual
|
| 53 |
+
out = self.relu(out)
|
| 54 |
+
|
| 55 |
+
return out
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class Bottleneck(nn.Module):
|
| 59 |
+
expansion = 4
|
| 60 |
+
|
| 61 |
+
def __init__(self, in_planes, planes, stride=1, downsample=None):
|
| 62 |
+
super().__init__()
|
| 63 |
+
|
| 64 |
+
self.conv1 = conv1x1x1(in_planes, planes)
|
| 65 |
+
self.bn1 = nn.BatchNorm3d(planes)
|
| 66 |
+
self.conv2 = conv3x3x3(planes, planes, stride)
|
| 67 |
+
self.bn2 = nn.BatchNorm3d(planes)
|
| 68 |
+
self.conv3 = conv1x1x1(planes, planes * self.expansion)
|
| 69 |
+
self.bn3 = nn.BatchNorm3d(planes * self.expansion)
|
| 70 |
+
self.relu = nn.ReLU(inplace=True)
|
| 71 |
+
self.downsample = downsample
|
| 72 |
+
self.stride = stride
|
| 73 |
+
|
| 74 |
+
def forward(self, x):
|
| 75 |
+
residual = x
|
| 76 |
+
|
| 77 |
+
out = self.conv1(x)
|
| 78 |
+
out = self.bn1(out)
|
| 79 |
+
out = self.relu(out)
|
| 80 |
+
|
| 81 |
+
out = self.conv2(out)
|
| 82 |
+
out = self.bn2(out)
|
| 83 |
+
out = self.relu(out)
|
| 84 |
+
|
| 85 |
+
out = self.conv3(out)
|
| 86 |
+
out = self.bn3(out)
|
| 87 |
+
|
| 88 |
+
if self.downsample is not None:
|
| 89 |
+
residual = self.downsample(x)
|
| 90 |
+
|
| 91 |
+
out += residual
|
| 92 |
+
out = self.relu(out)
|
| 93 |
+
|
| 94 |
+
return out
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class ResNet(nn.Module):
|
| 98 |
+
def __init__(
|
| 99 |
+
self,
|
| 100 |
+
block,
|
| 101 |
+
layers,
|
| 102 |
+
block_inplanes,
|
| 103 |
+
n_input_channels=3,
|
| 104 |
+
conv1_t_size=7,
|
| 105 |
+
conv1_t_stride=1,
|
| 106 |
+
no_max_pool=False,
|
| 107 |
+
shortcut_type="B",
|
| 108 |
+
widen_factor=1.0,
|
| 109 |
+
n_classes=1039,
|
| 110 |
+
):
|
| 111 |
+
super().__init__()
|
| 112 |
+
|
| 113 |
+
block_inplanes = [int(x * widen_factor) for x in block_inplanes]
|
| 114 |
+
|
| 115 |
+
self.in_planes = block_inplanes[0]
|
| 116 |
+
self.no_max_pool = no_max_pool
|
| 117 |
+
|
| 118 |
+
self.conv1 = nn.Conv3d(
|
| 119 |
+
n_input_channels,
|
| 120 |
+
self.in_planes,
|
| 121 |
+
kernel_size=(conv1_t_size, 7, 7),
|
| 122 |
+
stride=(conv1_t_stride, 2, 2),
|
| 123 |
+
padding=(conv1_t_size // 2, 3, 3),
|
| 124 |
+
bias=False,
|
| 125 |
+
)
|
| 126 |
+
self.bn1 = nn.BatchNorm3d(self.in_planes)
|
| 127 |
+
self.relu = nn.ReLU(inplace=True)
|
| 128 |
+
self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
|
| 129 |
+
self.layer1 = self._make_layer(
|
| 130 |
+
block, block_inplanes[0], layers[0], shortcut_type
|
| 131 |
+
)
|
| 132 |
+
self.layer2 = self._make_layer(
|
| 133 |
+
block, block_inplanes[1], layers[1], shortcut_type, stride=2
|
| 134 |
+
)
|
| 135 |
+
self.layer3 = self._make_layer(
|
| 136 |
+
block, block_inplanes[2], layers[2], shortcut_type, stride=2
|
| 137 |
+
)
|
| 138 |
+
self.layer4 = self._make_layer(
|
| 139 |
+
block, block_inplanes[3], layers[3], shortcut_type, stride=2
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
|
| 143 |
+
# self.fc = nn.Linear(block_inplanes[3] * block.expansion, n_classes)
|
| 144 |
+
|
| 145 |
+
for m in self.modules():
|
| 146 |
+
if isinstance(m, nn.Conv3d):
|
| 147 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
| 148 |
+
elif isinstance(m, nn.BatchNorm3d):
|
| 149 |
+
nn.init.constant_(m.weight, 1)
|
| 150 |
+
nn.init.constant_(m.bias, 0)
|
| 151 |
+
|
| 152 |
+
def _downsample_basic_block(self, x, planes, stride):
|
| 153 |
+
out = F.avg_pool3d(x, kernel_size=1, stride=stride)
|
| 154 |
+
zero_pads = torch.zeros(
|
| 155 |
+
out.size(0), planes - out.size(1), out.size(2), out.size(3), out.size(4)
|
| 156 |
+
)
|
| 157 |
+
if isinstance(out.data, torch.cuda.FloatTensor):
|
| 158 |
+
zero_pads = zero_pads.cuda()
|
| 159 |
+
|
| 160 |
+
out = torch.cat([out.data, zero_pads], dim=1)
|
| 161 |
+
|
| 162 |
+
return out
|
| 163 |
+
|
| 164 |
+
def _make_layer(self, block, planes, blocks, shortcut_type, stride=1):
|
| 165 |
+
downsample = None
|
| 166 |
+
if stride != 1 or self.in_planes != planes * block.expansion:
|
| 167 |
+
if shortcut_type == "A":
|
| 168 |
+
downsample = partial(
|
| 169 |
+
self._downsample_basic_block,
|
| 170 |
+
planes=planes * block.expansion,
|
| 171 |
+
stride=stride,
|
| 172 |
+
)
|
| 173 |
+
else:
|
| 174 |
+
downsample = nn.Sequential(
|
| 175 |
+
conv1x1x1(self.in_planes, planes * block.expansion, stride),
|
| 176 |
+
nn.BatchNorm3d(planes * block.expansion),
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
layers = []
|
| 180 |
+
layers.append(
|
| 181 |
+
block(
|
| 182 |
+
in_planes=self.in_planes,
|
| 183 |
+
planes=planes,
|
| 184 |
+
stride=stride,
|
| 185 |
+
downsample=downsample,
|
| 186 |
+
)
|
| 187 |
+
)
|
| 188 |
+
self.in_planes = planes * block.expansion
|
| 189 |
+
for _ in range(1, blocks):
|
| 190 |
+
layers.append(block(self.in_planes, planes))
|
| 191 |
+
|
| 192 |
+
return nn.Sequential(*layers)
|
| 193 |
+
|
| 194 |
+
def forward(self, x):
|
| 195 |
+
x = self.conv1(x)
|
| 196 |
+
x = self.bn1(x)
|
| 197 |
+
x = self.relu(x)
|
| 198 |
+
if not self.no_max_pool:
|
| 199 |
+
x = self.maxpool(x)
|
| 200 |
+
|
| 201 |
+
x = self.layer1(x)
|
| 202 |
+
x = self.layer2(x)
|
| 203 |
+
x = self.layer3(x)
|
| 204 |
+
x = self.layer4(x)
|
| 205 |
+
|
| 206 |
+
x = self.avgpool(x)
|
| 207 |
+
|
| 208 |
+
x = x.view(x.size(0), -1)
|
| 209 |
+
# x = self.fc(x)
|
| 210 |
+
|
| 211 |
+
return x
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def generate_model(model_depth, **kwargs):
|
| 215 |
+
assert model_depth in [10, 18, 34, 50, 101, 152, 200]
|
| 216 |
+
|
| 217 |
+
if model_depth == 10:
|
| 218 |
+
model = ResNet(BasicBlock, [1, 1, 1, 1], get_inplanes(), **kwargs)
|
| 219 |
+
elif model_depth == 18:
|
| 220 |
+
model = ResNet(BasicBlock, [2, 2, 2, 2], get_inplanes(), **kwargs)
|
| 221 |
+
elif model_depth == 34:
|
| 222 |
+
model = ResNet(BasicBlock, [3, 4, 6, 3], get_inplanes(), **kwargs)
|
| 223 |
+
elif model_depth == 50:
|
| 224 |
+
model = ResNet(Bottleneck, [3, 4, 6, 3], get_inplanes(), **kwargs)
|
| 225 |
+
elif model_depth == 101:
|
| 226 |
+
model = ResNet(Bottleneck, [3, 4, 23, 3], get_inplanes(), **kwargs)
|
| 227 |
+
elif model_depth == 152:
|
| 228 |
+
model = ResNet(Bottleneck, [3, 8, 36, 3], get_inplanes(), **kwargs)
|
| 229 |
+
elif model_depth == 200:
|
| 230 |
+
model = ResNet(Bottleneck, [3, 24, 36, 3], get_inplanes(), **kwargs)
|
| 231 |
+
|
| 232 |
+
return model
|
web_demo/requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
flask
|
| 2 |
+
flask-socketio
|
| 3 |
+
eventlet
|
| 4 |
+
torch
|
| 5 |
+
numpy
|
| 6 |
+
opencv-python
|
| 7 |
+
matplotlib
|
web_demo/static/css/style.css
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
body {
|
| 2 |
+
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
|
| 3 |
+
background-color: #121212;
|
| 4 |
+
color: #e0e0e0;
|
| 5 |
+
margin: 0;
|
| 6 |
+
padding: 20px;
|
| 7 |
+
display: flex;
|
| 8 |
+
justify-content: center;
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
.container {
|
| 12 |
+
display: flex;
|
| 13 |
+
width: 100%;
|
| 14 |
+
max-width: 1600px;
|
| 15 |
+
gap: 20px;
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
/* CHANGED: Main content takes more space, sidebar takes less */
|
| 19 |
+
.main-content {
|
| 20 |
+
flex: 4; /* Increased from 3 */
|
| 21 |
+
display: flex;
|
| 22 |
+
flex-direction: column;
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
.sidebar {
|
| 26 |
+
flex: 1; /* Stays at 1, making it proportionally smaller */
|
| 27 |
+
background-color: #1e1e1e;
|
| 28 |
+
padding: 20px;
|
| 29 |
+
border-radius: 8px;
|
| 30 |
+
height: fit-content;
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
.header {
|
| 34 |
+
display: flex;
|
| 35 |
+
justify-content: space-between;
|
| 36 |
+
align-items: center;
|
| 37 |
+
margin-bottom: 10px;
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
h1 { border-bottom: none; padding-bottom: 0; }
|
| 41 |
+
h2, h3 { color: #ffffff; border-bottom: 2px solid #333; padding-bottom: 10px; }
|
| 42 |
+
|
| 43 |
+
/* CHANGED: Grid ratio adjusted to make the graph wider */
|
| 44 |
+
.dashboard-grid {
|
| 45 |
+
display: grid;
|
| 46 |
+
grid-template-columns: 1.8fr 1.5fr; /* Video area vs Graph area */
|
| 47 |
+
gap: 20px;
|
| 48 |
+
align-items: flex-start;
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
.video-area {
|
| 52 |
+
display: flex;
|
| 53 |
+
flex-direction: column;
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
.video-wrapper {
|
| 57 |
+
width: 100%;
|
| 58 |
+
margin-bottom: 10px; /* Space between video and status label */
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
#videoPlayer {
|
| 62 |
+
background-color: #000;
|
| 63 |
+
border-radius: 8px;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
/* CHANGED: Status label is now positioned under the video */
|
| 67 |
+
#statusLabel {
|
| 68 |
+
margin-top: 0; /* Resets previous margin */
|
| 69 |
+
font-style: italic;
|
| 70 |
+
color: #f44336;
|
| 71 |
+
text-align: center; /* Center the text under the video */
|
| 72 |
+
min-height: 24px; /* Prevents layout shifts */
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
/* The chart and yolo containers are now styled independently */
|
| 76 |
+
.chart-container {
|
| 77 |
+
background-color: #1e1e1e;
|
| 78 |
+
padding: 20px;
|
| 79 |
+
border-radius: 8px;
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
.yolo-container {
|
| 83 |
+
background-color: #1e1e1e;
|
| 84 |
+
padding: 20px;
|
| 85 |
+
border-radius: 8px;
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
#yoloTextLabel { font-size: 1.2em; font-weight: bold; color: #4CAF50; min-height: 25px; }
|
| 89 |
+
#yoloImageFrame { width: 100%; height: auto; border-radius: 4px; background-color: #333; min-height: 150px; margin-top: 10px; }
|
| 90 |
+
|
| 91 |
+
/* Styles for controls in the sidebar */
|
| 92 |
+
.custom-select {
|
| 93 |
+
width: 100%;
|
| 94 |
+
padding: 12px 15px;
|
| 95 |
+
background-color: #3a3a3a;
|
| 96 |
+
color: #e0e0e0;
|
| 97 |
+
border: 1px solid #bb86fc;
|
| 98 |
+
border-radius: 4px;
|
| 99 |
+
font-size: 1em;
|
| 100 |
+
cursor: pointer;
|
| 101 |
+
}
|
| 102 |
+
.custom-select:hover { background-color: #4a4a4a; }
|
| 103 |
+
|
| 104 |
+
.separator { border: none; border-top: 1px solid #333; margin: 20px 0; }
|
| 105 |
+
.upload-section { display: flex; flex-direction: column; gap: 10px; }
|
| 106 |
+
#videoUpload { color: #e0e0e0; }
|
| 107 |
+
#videoUpload::file-selector-button { font-weight: bold; color: #bb86fc; background-color: #3a3a3a; padding: 8px 12px; border: 1px solid #bb86fc; border-radius: 4px; cursor: pointer; transition: background-color 0.2s; }
|
| 108 |
+
#videoUpload::file-selector-button:hover { background-color: #4a4a4a; }
|
| 109 |
+
#uploadButton { padding: 10px 20px; font-size: 16px; font-weight: bold; color: white; background-color: #03dac6; border: none; border-radius: 5px; cursor: pointer; transition: background-color 0.2s; }
|
| 110 |
+
#uploadButton:hover { background-color: #018786; }
|
| 111 |
+
#resetButton { padding: 10px 20px; font-size: 16px; font-weight: bold; color: white; background-color: #f44336; border: none; border-radius: 5px; cursor: pointer; transition: background-color 0.2s; }
|
| 112 |
+
#resetButton:hover { background-color: #d32f2f; }
|
web_demo/static/js/main.js
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
document.addEventListener('DOMContentLoaded', () => {
|
| 2 |
+
const socket = io();
|
| 3 |
+
|
| 4 |
+
const videoPlayer = document.getElementById('videoPlayer');
|
| 5 |
+
const yoloTextLabel = document.getElementById('yoloTextLabel');
|
| 6 |
+
const yoloImageFrame = document.getElementById('yoloImageFrame');
|
| 7 |
+
const statusLabel = document.getElementById('statusLabel');
|
| 8 |
+
const resetButton = document.getElementById('resetButton');
|
| 9 |
+
const videoUploadInput = document.getElementById('videoUpload');
|
| 10 |
+
const uploadButton = document.getElementById('uploadButton');
|
| 11 |
+
|
| 12 |
+
// CHANGED: Get the new dropdown selector
|
| 13 |
+
const anomalySelector = document.getElementById('anomalySelector');
|
| 14 |
+
|
| 15 |
+
let chart;
|
| 16 |
+
|
| 17 |
+
function initializeChart() {
|
| 18 |
+
const ctx = document.getElementById('anomalyChart').getContext('2d');
|
| 19 |
+
if (chart) { chart.destroy(); }
|
| 20 |
+
chart = new Chart(ctx, {
|
| 21 |
+
type: 'line', data: { labels: [], datasets: [{ label: 'Anomaly Score', data: [], borderColor: 'rgba(255, 99, 132, 1)', backgroundColor: 'rgba(255, 99, 132, 0.2)', borderWidth: 2, tension: 0.4, pointRadius: 0 }] }, options: { scales: { y: { beginAtZero: true, max: 1.0, ticks: { color: '#e0e0e0' }}, x: { ticks: { color: '#e0e0e0' }}}, plugins: { legend: { labels: { color: '#e0e0e0' }}}}
|
| 22 |
+
});
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
function resetUI() {
|
| 26 |
+
videoPlayer.pause();
|
| 27 |
+
videoPlayer.removeAttribute('src');
|
| 28 |
+
videoPlayer.load();
|
| 29 |
+
initializeChart();
|
| 30 |
+
yoloTextLabel.textContent = 'Waiting for anomaly...';
|
| 31 |
+
yoloImageFrame.src = '';
|
| 32 |
+
statusLabel.textContent = 'System reset. Select a video to begin.';
|
| 33 |
+
videoUploadInput.value = '';
|
| 34 |
+
anomalySelector.selectedIndex = 0; // Reset dropdown to the default option
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
// --- WebSocket Event Listeners (unchanged) ---
|
| 38 |
+
socket.on('connect', () => { statusLabel.textContent = 'Connected. Please select a video to start processing.'; });
|
| 39 |
+
socket.on('update_graph', (data) => {
|
| 40 |
+
const { score } = data;
|
| 41 |
+
if (!chart) return;
|
| 42 |
+
const newLabel = chart.data.labels.length + 1;
|
| 43 |
+
chart.data.labels.push(newLabel);
|
| 44 |
+
chart.data.datasets[0].data.push(score);
|
| 45 |
+
if (chart.data.labels.length > 100) { chart.data.labels.shift(); chart.data.datasets[0].data.shift(); }
|
| 46 |
+
chart.update();
|
| 47 |
+
});
|
| 48 |
+
socket.on('update_yolo_text', (data) => { yoloTextLabel.textContent = data.text; });
|
| 49 |
+
socket.on('update_yolo_image', (data) => { yoloImageFrame.src = `data:image/jpeg;base64,${data.image_data}`; });
|
| 50 |
+
socket.on('update_status', (data) => { statusLabel.textContent = data.status; });
|
| 51 |
+
socket.on('processing_error', (data) => { statusLabel.textContent = `Error: ${data.error}`; });
|
| 52 |
+
socket.on('processing_finished', (data) => { statusLabel.textContent = data.message; });
|
| 53 |
+
socket.on('system_reset_confirm', () => { resetUI(); });
|
| 54 |
+
|
| 55 |
+
// --- User Interaction ---
|
| 56 |
+
|
| 57 |
+
// CHANGED: Replaced the old event listener for links with one for the dropdown
|
| 58 |
+
anomalySelector.addEventListener('change', (event) => {
|
| 59 |
+
const anomalyName = event.target.value;
|
| 60 |
+
if (!anomalyName) return; // Do nothing if the default option is selected
|
| 61 |
+
|
| 62 |
+
resetUI();
|
| 63 |
+
statusLabel.textContent = `Requesting to process ${anomalyName}...`;
|
| 64 |
+
|
| 65 |
+
videoPlayer.src = `/video_stream/demo/${anomalyName}`;
|
| 66 |
+
videoPlayer.play();
|
| 67 |
+
|
| 68 |
+
socket.emit('start_processing', { 'source': 'demo', 'filename': anomalyName });
|
| 69 |
+
});
|
| 70 |
+
|
| 71 |
+
resetButton.addEventListener('click', () => { socket.emit('reset_system'); });
|
| 72 |
+
|
| 73 |
+
// Upload button logic (unchanged)
|
| 74 |
+
uploadButton.addEventListener('click', () => {
|
| 75 |
+
const file = videoUploadInput.files[0];
|
| 76 |
+
if (!file) {
|
| 77 |
+
alert('Please select a video file first!');
|
| 78 |
+
return;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
resetUI();
|
| 82 |
+
statusLabel.textContent = 'Uploading video...';
|
| 83 |
+
|
| 84 |
+
const formData = new FormData();
|
| 85 |
+
formData.append('video', file);
|
| 86 |
+
|
| 87 |
+
fetch('/upload', { method: 'POST', body: formData })
|
| 88 |
+
.then(response => response.json())
|
| 89 |
+
.then(data => {
|
| 90 |
+
if (data.success) {
|
| 91 |
+
const uploadedFilename = data.filename;
|
| 92 |
+
statusLabel.textContent = `Upload successful. Starting analysis...`;
|
| 93 |
+
videoPlayer.src = `/video_stream/upload/${uploadedFilename}`;
|
| 94 |
+
videoPlayer.play();
|
| 95 |
+
socket.emit('start_processing', { 'source': 'upload', 'filename': uploadedFilename });
|
| 96 |
+
} else {
|
| 97 |
+
statusLabel.textContent = `Error: ${data.error}`;
|
| 98 |
+
alert(`Upload failed: ${data.error}`);
|
| 99 |
+
}
|
| 100 |
+
})
|
| 101 |
+
.catch(error => {
|
| 102 |
+
statusLabel.textContent = 'An error occurred during upload.';
|
| 103 |
+
console.error('Upload error:', error);
|
| 104 |
+
});
|
| 105 |
+
});
|
| 106 |
+
|
| 107 |
+
initializeChart();
|
| 108 |
+
});
|
web_demo/static/script.js
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
const videoPlayer = document.getElementById("videoPlayer");
|
| 2 |
+
const yoloResult = document.getElementById("yoloResult");
|
| 3 |
+
|
| 4 |
+
// Dummy chart for anomaly graph
|
| 5 |
+
const ctx = document.getElementById("anomalyGraph").getContext("2d");
|
| 6 |
+
const graph = new Chart(ctx, {
|
| 7 |
+
type: "line",
|
| 8 |
+
data: {
|
| 9 |
+
labels: [],
|
| 10 |
+
datasets: [{
|
| 11 |
+
label: "Anomaly Score",
|
| 12 |
+
data: [],
|
| 13 |
+
borderColor: "red",
|
| 14 |
+
borderWidth: 2
|
| 15 |
+
}]
|
| 16 |
+
},
|
| 17 |
+
options: {
|
| 18 |
+
responsive: true,
|
| 19 |
+
scales: {
|
| 20 |
+
y: { min: 0, max: 1 }
|
| 21 |
+
}
|
| 22 |
+
}
|
| 23 |
+
});
|
| 24 |
+
|
| 25 |
+
async function playDemo(name) {
|
| 26 |
+
const response = await fetch("/get_video", {
|
| 27 |
+
method: "POST",
|
| 28 |
+
headers: {"Content-Type": "application/json"},
|
| 29 |
+
body: JSON.stringify({ name })
|
| 30 |
+
});
|
| 31 |
+
|
| 32 |
+
const data = await response.json();
|
| 33 |
+
if (data.error) {
|
| 34 |
+
alert(data.error);
|
| 35 |
+
return;
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
// Load video
|
| 39 |
+
videoPlayer.src = "file:///" + data.path;
|
| 40 |
+
yoloResult.innerText = `Playing demo: ${name}`;
|
| 41 |
+
}
|
web_demo/static/videos/Abuse.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:425744aa3472e424d52d7ce97bf6d0bdd445ad62ad1be110095d2027a31550cc
|
| 3 |
+
size 6250495
|
web_demo/static/videos/Arrest.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:faf0f08b1ee989545ad1de2edecdb56a24e65914194b8083f47d10481926c0e1
|
| 3 |
+
size 11929804
|
web_demo/static/videos/Arson.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:902f3138fa8b839abd08bcd3e434e84756742fdf0c60bcc0769cd7106b1ac3a2
|
| 3 |
+
size 12694369
|
web_demo/static/videos/Assault.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3b83cf948fef884ede2b86a2d3fe68de779b9c81301a5c653fbb329bfc243274
|
| 3 |
+
size 21066405
|
web_demo/static/videos/Burglary.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cd17094bfd2e5b73bcce767c434f14b715744eb3338fb80f1a213c1a337ce65d
|
| 3 |
+
size 9857751
|
web_demo/static/videos/Explosion.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b462f9241ab7521e98071b18e8956c5a921336140b4da68ddbf56a5684e87fb6
|
| 3 |
+
size 5162883
|
web_demo/static/videos/Fighting.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a135cc99b9b7d1f314375cc5e29b6a38aa1131544bf0d9ca133a95644668abf6
|
| 3 |
+
size 5519077
|
web_demo/static/videos/Normal.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e7a4881043c8e9deefe11c65ed8663a281c8366a5baa91f091d67b98eb638018
|
| 3 |
+
size 7205089
|
web_demo/static/videos/RoadAccidents.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0e6ccd7bac80120cfeac9a5ef3e726da29864fb8cfd218ea0ed42d696ce553ab
|
| 3 |
+
size 14490312
|
web_demo/static/videos/Robbery.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ce7983bbb834708b8316c72cb916b9cab0105e2f283c7f8e636d38b36ddd6b48
|
| 3 |
+
size 26631485
|
web_demo/static/videos/Shooting.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b125ed267b82f514820cc568c7c820a0f04cd531500bd242003c8efd2f9bdcdf
|
| 3 |
+
size 2198741
|
web_demo/static/videos/Shoplifting.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:717d68d3671d3f7638f80cc7db2e682599fceee21f15385431c569a1480d42ab
|
| 3 |
+
size 22406639
|
web_demo/static/videos/Stealing.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:97ebf655ad4192fdfef01ec91c435f85d6e773257fe72a1458eacf5abdd2e04b
|
| 3 |
+
size 27565440
|
web_demo/static/videos/Vandalism.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:407508a2a3587caac3b3e4b165983f494692301e400ed4c4bbed504c47ba9e56
|
| 3 |
+
size 2851411
|
web_demo/templates/index.html
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 6 |
+
<title>Real-Time Anomaly Detection</title>
|
| 7 |
+
<link rel="stylesheet" href="{{ url_for('static', filename='css/style.css') }}">
|
| 8 |
+
</head>
|
| 9 |
+
<body>
|
| 10 |
+
<div class="container">
|
| 11 |
+
<main class="main-content">
|
| 12 |
+
<div class="header">
|
| 13 |
+
<h1>Anomaly Detection Dashboard</h1>
|
| 14 |
+
<button id="resetButton">Reset</button>
|
| 15 |
+
</div>
|
| 16 |
+
|
| 17 |
+
<div class="dashboard-grid">
|
| 18 |
+
|
| 19 |
+
<div class="video-area">
|
| 20 |
+
<div class="video-wrapper">
|
| 21 |
+
<video id="videoPlayer" width="100%" controls muted>
|
| 22 |
+
Your browser does not support the video tag.
|
| 23 |
+
</video>
|
| 24 |
+
</div>
|
| 25 |
+
<p id="statusLabel">Select a video to begin.</p>
|
| 26 |
+
</div>
|
| 27 |
+
|
| 28 |
+
<div class="chart-container">
|
| 29 |
+
<h3>Live Anomaly Score</h3>
|
| 30 |
+
<canvas id="anomalyChart"></canvas>
|
| 31 |
+
</div>
|
| 32 |
+
|
| 33 |
+
</div>
|
| 34 |
+
</main>
|
| 35 |
+
|
| 36 |
+
<aside class="sidebar">
|
| 37 |
+
<h2>Demo Videos</h2>
|
| 38 |
+
<select id="anomalySelector" class="custom-select">
|
| 39 |
+
<option value="" disabled selected>Select a Demo Video...</option>
|
| 40 |
+
{% for name in anomaly_names %}
|
| 41 |
+
<option value="{{ name }}">{{ name }}</option>
|
| 42 |
+
{% endfor %}
|
| 43 |
+
</select>
|
| 44 |
+
|
| 45 |
+
<hr class="separator">
|
| 46 |
+
<h2>Upload Your Own</h2>
|
| 47 |
+
<div class="upload-section">
|
| 48 |
+
<input type="file" id="videoUpload" accept="video/mp4, video/mov, video/avi">
|
| 49 |
+
<button id="uploadButton">Analyze Uploaded Video</button>
|
| 50 |
+
</div>
|
| 51 |
+
|
| 52 |
+
<hr class="separator">
|
| 53 |
+
|
| 54 |
+
<div class="yolo-container">
|
| 55 |
+
<h3>YOLO Detection Result</h3>
|
| 56 |
+
<p id="yoloTextLabel">Waiting for anomaly...</p>
|
| 57 |
+
<img id="yoloImageFrame" src="" alt="YOLO Frame Preview">
|
| 58 |
+
</div>
|
| 59 |
+
</aside>
|
| 60 |
+
</div>
|
| 61 |
+
|
| 62 |
+
<script src="https://cdn.socket.io/4.7.5/socket.io.min.js"></script>
|
| 63 |
+
<script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
|
| 64 |
+
<script src="{{ url_for('static', filename='js/main.js') }}"></script>
|
| 65 |
+
</body>
|
| 66 |
+
</html>
|
web_demo/utils/__init__.py
ADDED
|
File without changes
|
web_demo/utils/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (174 Bytes). View file
|
|
|
web_demo/utils/__pycache__/callbacks.cpython-311.pyc
ADDED
|
Binary file (11.1 kB). View file
|
|
|
web_demo/utils/__pycache__/functional_video.cpython-311.pyc
ADDED
|
Binary file (5.81 kB). View file
|
|
|
web_demo/utils/__pycache__/load_model.cpython-311.pyc
ADDED
|
Binary file (5.48 kB). View file
|
|
|
web_demo/utils/__pycache__/stack.cpython-311.pyc
ADDED
|
Binary file (2.37 kB). View file
|
|
|
web_demo/utils/__pycache__/transforms_video.cpython-311.pyc
ADDED
|
Binary file (8.39 kB). View file
|
|
|
web_demo/utils/__pycache__/types.cpython-311.pyc
ADDED
|
Binary file (629 Bytes). View file
|
|
|
web_demo/utils/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (3.69 kB). View file
|
|
|
web_demo/utils/callbacks.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This module contains callbacks to be used along with `TorchModel`."""
|
| 2 |
+
|
| 3 |
+
import datetime
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
from abc import ABC, abstractmethod
|
| 8 |
+
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Callback(ABC):
|
| 13 |
+
@abstractmethod
|
| 14 |
+
def on_training_start(self, epochs) -> None:
|
| 15 |
+
pass
|
| 16 |
+
|
| 17 |
+
@abstractmethod
|
| 18 |
+
def on_training_end(self, model) -> None:
|
| 19 |
+
pass
|
| 20 |
+
|
| 21 |
+
@abstractmethod
|
| 22 |
+
def on_epoch_start(self, epoch_num, epoch_iterations) -> None:
|
| 23 |
+
pass
|
| 24 |
+
|
| 25 |
+
@abstractmethod
|
| 26 |
+
def on_epoch_step(self, global_iteration, epoch_iteration, loss) -> None:
|
| 27 |
+
pass
|
| 28 |
+
|
| 29 |
+
@abstractmethod
|
| 30 |
+
def on_epoch_end(self, loss) -> None:
|
| 31 |
+
pass
|
| 32 |
+
|
| 33 |
+
@abstractmethod
|
| 34 |
+
def on_evaluation_start(self, val_iterations) -> None:
|
| 35 |
+
pass
|
| 36 |
+
|
| 37 |
+
@abstractmethod
|
| 38 |
+
def on_evaluation_step(self, iteration, model_outputs, targets, loss) -> None:
|
| 39 |
+
pass
|
| 40 |
+
|
| 41 |
+
@abstractmethod
|
| 42 |
+
def on_evaluation_end(self) -> None:
|
| 43 |
+
pass
|
| 44 |
+
|
| 45 |
+
@abstractmethod
|
| 46 |
+
def on_training_iteration_end(self, train_loss, val_loss) -> None:
|
| 47 |
+
pass
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class DefaultModelCallback(Callback):
|
| 51 |
+
"""A callback that simply logs the loss for epochs during training and
|
| 52 |
+
evaluation."""
|
| 53 |
+
|
| 54 |
+
def __init__(self, log_every=10, visualization_dir=None) -> None:
|
| 55 |
+
"""
|
| 56 |
+
Args:
|
| 57 |
+
log_every (iterations): logging intervals
|
| 58 |
+
"""
|
| 59 |
+
super().__init__()
|
| 60 |
+
self.visualization_dir = visualization_dir
|
| 61 |
+
self._log_every = log_every
|
| 62 |
+
self._epochs = 0
|
| 63 |
+
self._epoch = 0
|
| 64 |
+
self._epoch_iterations = 0
|
| 65 |
+
self._val_iterations = 0
|
| 66 |
+
self._start_time = 0.0
|
| 67 |
+
self._train_losses = []
|
| 68 |
+
self._val_loss = []
|
| 69 |
+
|
| 70 |
+
def on_training_start(self, epochs) -> None:
|
| 71 |
+
logging.info(f"Training for {epochs} epochs")
|
| 72 |
+
self._epochs = epochs
|
| 73 |
+
self._train_losses = []
|
| 74 |
+
self._val_loss = []
|
| 75 |
+
|
| 76 |
+
def on_training_end(self, model) -> None:
|
| 77 |
+
if self.visualization_dir is not None:
|
| 78 |
+
plt.figure()
|
| 79 |
+
plt.xlabel("Epoch")
|
| 80 |
+
plt.ylabel("Loss")
|
| 81 |
+
|
| 82 |
+
plt.plot(
|
| 83 |
+
range(1, self._epochs + 1), self._train_losses, label="Training loss"
|
| 84 |
+
)
|
| 85 |
+
if self._val_loss:
|
| 86 |
+
plt.plot(
|
| 87 |
+
range(1, self._epochs + 1), self._val_loss, label="Validation loss"
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
plt.savefig(os.path.join(self.visualization_dir, "loss.png"))
|
| 91 |
+
plt.close()
|
| 92 |
+
|
| 93 |
+
def on_epoch_start(self, epoch_num: int, epoch_iterations: int) -> None:
|
| 94 |
+
self._epoch = epoch_num
|
| 95 |
+
self._epoch_iterations = epoch_iterations
|
| 96 |
+
self._start_time = time.time()
|
| 97 |
+
|
| 98 |
+
def on_epoch_step(
|
| 99 |
+
self, global_iteration: int, epoch_iteration: int, loss: float
|
| 100 |
+
) -> None:
|
| 101 |
+
if epoch_iteration % self._log_every == 0:
|
| 102 |
+
average_time = round(
|
| 103 |
+
(time.time() - self._start_time) / (epoch_iteration + 1), 3
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
loss_string = f"loss: {loss}"
|
| 107 |
+
|
| 108 |
+
# pylint: disable=line-too-long
|
| 109 |
+
logging.info(
|
| 110 |
+
f"Epoch {self._epoch}/{self._epochs} Iteration {epoch_iteration}/{self._epoch_iterations} {loss_string} Time: {average_time} seconds/iteration"
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
def on_epoch_end(self, loss) -> None:
|
| 114 |
+
self._train_losses.append(loss)
|
| 115 |
+
|
| 116 |
+
def on_evaluation_start(self, val_iterations) -> None:
|
| 117 |
+
self._val_iterations = val_iterations
|
| 118 |
+
|
| 119 |
+
def on_evaluation_step(self, iteration, model_outputs, targets, loss) -> None:
|
| 120 |
+
if iteration % self._log_every == 0:
|
| 121 |
+
logging.info(f"Iteration {iteration}/{self._val_iterations}")
|
| 122 |
+
|
| 123 |
+
def on_evaluation_end(self) -> None:
|
| 124 |
+
pass
|
| 125 |
+
|
| 126 |
+
def on_training_iteration_end(self, train_loss, val_loss) -> None:
|
| 127 |
+
# pylint: disable=line-too-long
|
| 128 |
+
train_loss_string = f"Train loss: {train_loss}"
|
| 129 |
+
if val_loss:
|
| 130 |
+
val_loss_string = f"Validation loss: {val_loss}"
|
| 131 |
+
logging.info(
|
| 132 |
+
f"""
|
| 133 |
+
============================================================================================================================
|
| 134 |
+
Epoch {self._epoch}/{self._epochs} {train_loss_string} {val_loss_string} time: {datetime.timedelta(seconds=time.time() - self._start_time)}
|
| 135 |
+
============================================================================================================================
|
| 136 |
+
"""
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
else:
|
| 140 |
+
logging.info(
|
| 141 |
+
f"""
|
| 142 |
+
============================================================================================================================
|
| 143 |
+
Epoch {self._epoch}/{self._epochs} {train_loss_string} time: {datetime.timedelta(seconds=time.time() - self._start_time)}
|
| 144 |
+
============================================================================================================================
|
| 145 |
+
"""
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class TensorBoardCallback(Callback):
|
| 150 |
+
"""A callback that simply logs the loss for epochs during training and
|
| 151 |
+
evaluation."""
|
| 152 |
+
|
| 153 |
+
def __init__(self, tb_writer) -> None:
|
| 154 |
+
"""
|
| 155 |
+
Args:
|
| 156 |
+
tb_writer: tensorboard logger instance
|
| 157 |
+
"""
|
| 158 |
+
super().__init__()
|
| 159 |
+
self.tb_writer = tb_writer
|
| 160 |
+
self.epoch = 0
|
| 161 |
+
|
| 162 |
+
def on_training_start(self, epochs) -> None:
|
| 163 |
+
pass
|
| 164 |
+
|
| 165 |
+
def on_training_end(self, model) -> None:
|
| 166 |
+
pass
|
| 167 |
+
|
| 168 |
+
def on_epoch_start(self, epoch_num, epoch_iterations) -> None:
|
| 169 |
+
self.epoch = epoch_num
|
| 170 |
+
|
| 171 |
+
def on_epoch_step(self, global_iteration, epoch_iteration, loss) -> None:
|
| 172 |
+
self.tb_writer.add_scalars(
|
| 173 |
+
"Train loss (iterations)", {"Loss": loss}, global_iteration
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
def on_epoch_end(self, loss) -> None:
|
| 177 |
+
pass
|
| 178 |
+
|
| 179 |
+
def on_evaluation_start(self, val_iterations) -> None:
|
| 180 |
+
pass
|
| 181 |
+
|
| 182 |
+
def on_evaluation_step(self, iteration, model_outputs, targets, loss) -> None:
|
| 183 |
+
pass
|
| 184 |
+
|
| 185 |
+
def on_evaluation_end(self) -> None:
|
| 186 |
+
pass
|
| 187 |
+
|
| 188 |
+
def on_training_iteration_end(self, train_loss, val_loss) -> None:
|
| 189 |
+
if train_loss is not None:
|
| 190 |
+
self.tb_writer.add_scalars(
|
| 191 |
+
"Epoch loss", {"Loss (train)": train_loss}, self.epoch
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
if val_loss is not None:
|
| 195 |
+
self.tb_writer.add_scalars(
|
| 196 |
+
"Epoch loss", {"Loss (validation)": val_loss}, self.epoch
|
| 197 |
+
)
|
web_demo/utils/functional_video.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def _is_tensor_video_clip(clip):
|
| 5 |
+
if not torch.is_tensor(clip):
|
| 6 |
+
raise TypeError(f"clip should be Tesnor. Got {type(clip)}")
|
| 7 |
+
|
| 8 |
+
if not clip.ndimension() == 4:
|
| 9 |
+
raise ValueError(f"clip should be 4D. Got {clip.dim()}D")
|
| 10 |
+
|
| 11 |
+
return True
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def crop(clip, i, j, h, w):
|
| 15 |
+
"""
|
| 16 |
+
Args:
|
| 17 |
+
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
|
| 18 |
+
"""
|
| 19 |
+
assert len(clip.size()) == 4, "clip should be a 4D tensor"
|
| 20 |
+
return clip[..., i : i + h, j : j + w]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def resize(clip, target_size, interpolation_mode):
|
| 24 |
+
assert len(target_size) == 2, "target size should be tuple (height, width)"
|
| 25 |
+
# print(target_size)
|
| 26 |
+
return torch.nn.functional.interpolate(
|
| 27 |
+
clip, size=target_size, mode=interpolation_mode, align_corners=False
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
|
| 32 |
+
"""
|
| 33 |
+
Do spatial cropping and resizing to the video clip
|
| 34 |
+
Args:
|
| 35 |
+
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
|
| 36 |
+
i (int): i in (i,j) i.e coordinates of the upper left corner.
|
| 37 |
+
j (int): j in (i,j) i.e coordinates of the upper left corner.
|
| 38 |
+
h (int): Height of the cropped region.
|
| 39 |
+
w (int): Width of the cropped region.
|
| 40 |
+
size (tuple(int, int)): height and width of resized clip
|
| 41 |
+
Returns:
|
| 42 |
+
clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W)
|
| 43 |
+
"""
|
| 44 |
+
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
|
| 45 |
+
clip = crop(clip, i, j, h, w)
|
| 46 |
+
clip = resize(clip, size, interpolation_mode)
|
| 47 |
+
return clip
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def center_crop(clip, crop_size):
|
| 51 |
+
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
|
| 52 |
+
h, w = clip.size(-2), clip.size(-1)
|
| 53 |
+
th, tw = crop_size
|
| 54 |
+
assert h >= th and w >= tw, "height and width must be no smaller than crop_size"
|
| 55 |
+
|
| 56 |
+
i = int(round((h - th) / 2.0))
|
| 57 |
+
j = int(round((w - tw) / 2.0))
|
| 58 |
+
return crop(clip, i, j, th, tw)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def to_tensor(clip):
|
| 62 |
+
"""
|
| 63 |
+
Convert tensor data type from uint8 to float, divide value by 255.0 and
|
| 64 |
+
permute the dimenions of clip tensor
|
| 65 |
+
Args:
|
| 66 |
+
clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
|
| 67 |
+
Return:
|
| 68 |
+
clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
|
| 69 |
+
"""
|
| 70 |
+
_is_tensor_video_clip(clip)
|
| 71 |
+
if not clip.dtype == torch.uint8:
|
| 72 |
+
raise TypeError(
|
| 73 |
+
f"clip tensor should have data type uint8. Got {str(clip.dtype)}"
|
| 74 |
+
)
|
| 75 |
+
return clip.float().permute(3, 0, 1, 2) / 255.0
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def normalize(clip, mean, std, inplace=False):
|
| 79 |
+
"""
|
| 80 |
+
Args:
|
| 81 |
+
clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
|
| 82 |
+
mean (tuple): pixel RGB mean. Size is (3)
|
| 83 |
+
std (tuple): pixel standard deviation. Size is (3)
|
| 84 |
+
Returns:
|
| 85 |
+
normalized clip (torch.tensor): Size is (C, T, H, W)
|
| 86 |
+
"""
|
| 87 |
+
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
|
| 88 |
+
if not inplace:
|
| 89 |
+
clip = clip.clone()
|
| 90 |
+
mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
|
| 91 |
+
std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
|
| 92 |
+
clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
|
| 93 |
+
return clip
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def hflip(clip):
|
| 97 |
+
"""
|
| 98 |
+
Args:
|
| 99 |
+
clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
|
| 100 |
+
Returns:
|
| 101 |
+
flipped clip (torch.tensor): Size is (C, T, H, W)
|
| 102 |
+
"""
|
| 103 |
+
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
|
| 104 |
+
return clip.flip(-1)
|
web_demo/utils/load_model.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This module contains functions for loading models."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from os import path
|
| 5 |
+
from typing import Tuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from network.anomaly_detector_model import AnomalyDetector
|
| 10 |
+
from network.c3d import C3D
|
| 11 |
+
from network.MFNET import MFNET_3D
|
| 12 |
+
from network.resnet import generate_model
|
| 13 |
+
from network.TorchUtils import TorchModel
|
| 14 |
+
from utils.types import Device, FeatureExtractor
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def load_feature_extractor(
|
| 18 |
+
features_method: str, feature_extractor_path: str, device: Device
|
| 19 |
+
) -> FeatureExtractor:
|
| 20 |
+
"""Load feature extractor from given path.
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
features_method (str): The feature extractor model type to use. Either c3d | mfnet | r3d101 | r3d152.
|
| 24 |
+
feature_extractor_path (str): Path to the feature extractor model.
|
| 25 |
+
device (Union[torch.device, str]): Device to use for the model.
|
| 26 |
+
|
| 27 |
+
Raises:
|
| 28 |
+
FileNotFoundError: The path to the model does not exist.
|
| 29 |
+
NotImplementedError: The provided feature extractor method is not implemented.
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
FeatureExtractor
|
| 33 |
+
"""
|
| 34 |
+
if not path.exists(feature_extractor_path):
|
| 35 |
+
raise FileNotFoundError(
|
| 36 |
+
f"Couldn't find feature extractor {feature_extractor_path}.\n"
|
| 37 |
+
+ r"If you are using resnet, download it first from:\n"
|
| 38 |
+
+ r"r3d101: https://drive.google.com/file/d/1p80RJsghFIKBSLKgtRG94LE38OGY5h4y/view?usp=share_link"
|
| 39 |
+
+ "\n"
|
| 40 |
+
+ r"r3d152: https://drive.google.com/file/d/1irIdC_v7wa-sBpTiBlsMlS7BYNdj4Gr7/view?usp=share_link"
|
| 41 |
+
)
|
| 42 |
+
logging.info(f"Loading feature extractor from {feature_extractor_path}")
|
| 43 |
+
|
| 44 |
+
model: FeatureExtractor
|
| 45 |
+
|
| 46 |
+
if features_method == "c3d":
|
| 47 |
+
model = C3D(pretrained=feature_extractor_path)
|
| 48 |
+
elif features_method == "mfnet":
|
| 49 |
+
model = MFNET_3D()
|
| 50 |
+
model.load_state(state_dict=feature_extractor_path)
|
| 51 |
+
elif features_method == "r3d101":
|
| 52 |
+
model = generate_model(model_depth=101)
|
| 53 |
+
param_dict = torch.load(feature_extractor_path)["state_dict"]
|
| 54 |
+
param_dict.pop("fc.weight")
|
| 55 |
+
param_dict.pop("fc.bias")
|
| 56 |
+
model.load_state_dict(param_dict)
|
| 57 |
+
elif features_method == "r3d152":
|
| 58 |
+
model = generate_model(model_depth=152)
|
| 59 |
+
param_dict = torch.load(feature_extractor_path)["state_dict"]
|
| 60 |
+
param_dict.pop("fc.weight")
|
| 61 |
+
param_dict.pop("fc.bias")
|
| 62 |
+
model.load_state_dict(param_dict)
|
| 63 |
+
else:
|
| 64 |
+
raise NotImplementedError(
|
| 65 |
+
f"Features extraction method {features_method} not implemented"
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
return model.to(device).eval()
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def load_anomaly_detector(ad_model_path: str, device: Device) -> AnomalyDetector:
|
| 72 |
+
"""Load anomaly detection model from given path.
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
ad_model_path (str): Path to the anomaly detection model.
|
| 76 |
+
device (Device): Device to use for the model.
|
| 77 |
+
|
| 78 |
+
Raises:
|
| 79 |
+
FileNotFoundError: The path to the model does not exist.
|
| 80 |
+
|
| 81 |
+
Returns:
|
| 82 |
+
AnomalyDetector
|
| 83 |
+
"""
|
| 84 |
+
if not path.exists(ad_model_path):
|
| 85 |
+
raise FileNotFoundError(f"Couldn't find anomaly detector {ad_model_path}.")
|
| 86 |
+
logging.info(f"Loading anomaly detector from {ad_model_path}")
|
| 87 |
+
|
| 88 |
+
anomaly_detector = TorchModel.load_model(ad_model_path).to(device)
|
| 89 |
+
return anomaly_detector.eval()
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def load_models(
|
| 93 |
+
feature_extractor_path: str,
|
| 94 |
+
ad_model_path: str,
|
| 95 |
+
features_method: str = "c3d",
|
| 96 |
+
device: Device = "cuda",
|
| 97 |
+
) -> Tuple[AnomalyDetector, FeatureExtractor]:
|
| 98 |
+
"""Loads both feature extractor and anomaly detector from the given paths.
|
| 99 |
+
|
| 100 |
+
Args:
|
| 101 |
+
feature_extractor_path (str): Path of the features extractor weights to load.
|
| 102 |
+
ad_model_path (str): Path of the anomaly detector weights to load.
|
| 103 |
+
features_method (str, optional): Name of the model to use for features extraction.
|
| 104 |
+
Defaults to "c3d".
|
| 105 |
+
device (str, optional): Device to use for the models. Defaults to "cuda".
|
| 106 |
+
|
| 107 |
+
Returns:
|
| 108 |
+
Tuple[nn.Module, nn.Module]
|
| 109 |
+
"""
|
| 110 |
+
feature_extractor = load_feature_extractor(
|
| 111 |
+
features_method, feature_extractor_path, device
|
| 112 |
+
)
|
| 113 |
+
anomaly_detector = load_anomaly_detector(ad_model_path, device)
|
| 114 |
+
return anomaly_detector, feature_extractor
|
web_demo/utils/stack.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This module contains an implementation of a stack that fits an online
|
| 2 |
+
container for video clips."""
|
| 3 |
+
|
| 4 |
+
import threading
|
| 5 |
+
from typing import Any, List
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Stack:
|
| 9 |
+
"""Create a stack object with a given maximum size."""
|
| 10 |
+
|
| 11 |
+
def __init__(self, max_size: int) -> None:
|
| 12 |
+
self._stack = []
|
| 13 |
+
self._max_size = max_size
|
| 14 |
+
self._lock = threading.Lock()
|
| 15 |
+
|
| 16 |
+
def put(self, item: Any) -> None:
|
| 17 |
+
"""Put an item into the stack."""
|
| 18 |
+
with self._lock:
|
| 19 |
+
self._stack.append(item)
|
| 20 |
+
if len(self._stack) > self._max_size:
|
| 21 |
+
del self._stack[0]
|
| 22 |
+
|
| 23 |
+
def get(self, size: int = -1) -> List[Any]:
|
| 24 |
+
"""Get an item from the stack."""
|
| 25 |
+
if size == -1:
|
| 26 |
+
size = self._max_size
|
| 27 |
+
return self._stack[-size:]
|
| 28 |
+
|
| 29 |
+
def __len__(self) -> int:
|
| 30 |
+
return len(self._stack)
|
| 31 |
+
|
| 32 |
+
def full(self) -> bool:
|
| 33 |
+
return len(self._stack) == self._max_size
|