File size: 6,708 Bytes
a090915 e2af51e a090915 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
import torch
from config import DETECT_MODEL, POSE_MODEL, CONF_THRESHOLD
from utils.gpu import GPUConfigurator
from preprocessing.preprocessor import FramePreprocessor
from data_extraction.interaction_analyzer import InteractionAnalyzer
from data_extraction.person_tracker import PersonTracker
from utils.visualizer import Visualizer
import numpy as np
from ultralytics import YOLO
class VideoFeatureExtractor:
    """Extract per-frame detections, poses, tracked persons and interactions.

    Each frame is preprocessed, run through a detection model and (when the
    detector found anything) a pose model, and the results are combined into
    a plain dictionary. ``prev_poses`` carries the previous frame's poses so
    that motion features can be computed between consecutive frames.
    """

    def __init__(self):
        """Load both YOLO models onto the configured device and build helpers."""
        self.gpu_config = GPUConfigurator()
        self.device = self.gpu_config.device
        self.detection_model = YOLO(DETECT_MODEL).to(self.device)
        self.pose_model = YOLO(POSE_MODEL).to(self.device)
        self.preprocessor = FramePreprocessor()
        self.interaction_analyzer = InteractionAnalyzer()
        self.person_tracker = PersonTracker()
        self.visualizer = Visualizer()
        self.conf_threshold = CONF_THRESHOLD
        # Start from the same clean state that reset() establishes.
        self.person_tracker.reset()
        self.prev_poses = None

    def extract_features(self, frame, frame_idx, fps=30.0):
        """Extract features from a frame.

        Args:
            frame: Raw frame handed to the preprocessor and the visualizer.
            frame_idx: Index of the frame within the video; used for the
                timestamp and for periodic CUDA cache clearing.
            fps: Frame rate used to turn ``frame_idx`` into a timestamp in
                seconds. Defaults to 30, the value that was previously
                hard-coded.

        Returns:
            ``(frame_data, annotated_frame)`` on success. ``(None, frame)``
            when preprocessing yields no frame or any unexpected error occurs
            while processing (the error is printed, not raised).

        Raises:
            ValueError: If ``fps`` is not positive.
        """
        if fps <= 0:
            raise ValueError(f"fps must be positive, got {fps!r}")

        try:
            processed_frame, scale_info = self.preprocessor.preprocess_frame(frame)
            if processed_frame is None:
                return None, frame

            # Move the last axis first and add a leading batch axis
            # (presumably HWC -> 1xCxHxW; confirm against FramePreprocessor).
            frame_tensor = (
                torch.from_numpy(processed_frame)
                .permute(2, 0, 1)
                .unsqueeze(0)
                .to(self.device)
            )

            # CUDA-specific work is only meaningful when the models actually
            # live on a CUDA device; elsewhere autocast would just warn on
            # every frame before disabling itself.
            on_cuda = torch.device(self.device).type == "cuda"
            if on_cuda and frame_idx % 5 == 0:
                torch.cuda.empty_cache()

            with (
                torch.no_grad(),
                torch.amp.autocast(
                    device_type="cuda", dtype=torch.float16, enabled=on_cuda
                ),
            ):
                det_results = self.detection_model(
                    frame_tensor, conf=self.conf_threshold, verbose=False
                )
                # The pose model is skipped when the detector found nothing.
                pose_results = (
                    self.pose_model(
                        frame_tensor, conf=self.conf_threshold, verbose=False
                    )
                    if len(det_results[0].boxes) > 0
                    else []
                )

            resized_size = scale_info.get("resized_size", (0, 0))
            frame_data = {
                "frame_index": frame_idx,
                "timestamp": frame_idx / fps,
                "persons": [],
                "objects": [],
                "interactions": [],
                "resized_width": resized_size[1],
                "resized_height": resized_size[0],
            }

            # Process detections
            person_boxes, frame_data["objects"] = self._collect_detections(
                det_results
            )

            # Track persons
            tracked_persons = self.person_tracker.assign_person_ids(person_boxes)

            # Process poses
            current_poses = self._collect_poses(pose_results)

            # Match persons to poses
            frame_data["persons"] = self._match_persons(
                person_boxes, current_poses, tracked_persons
            )

            # Calculate motion features (zeros unless both the previous and
            # the current frame produced poses).
            motion_features = {
                "average_speed": 0,
                "motion_intensity": 0,
                "sudden_movements": 0,
            }
            if self.prev_poses and current_poses:
                try:
                    motion_features = (
                        self.interaction_analyzer.calculate_motion_features(
                            self.prev_poses, current_poses
                        )
                    )
                except Exception as e:
                    print(f"Motion calculation error: {e}")
            frame_data["motion_features"] = motion_features
            self.prev_poses = current_poses

            # Create interactions
            frame_data["interactions"] = (
                self.interaction_analyzer.calculate_interactions(
                    person_boxes, current_poses, tracked_persons
                )
            )

            # Draw detections, poses and track IDs onto the original frame
            annotated_frame = self.visualizer.draw_detections(
                frame, det_results, pose_results, scale_info, tracked_persons
            )
            return frame_data, annotated_frame
        except Exception as e:
            # Best-effort boundary: one bad frame must not abort the video.
            print(f"Frame {frame_idx} failed completely: {e}")
            return None, frame

    @staticmethod
    def _collect_detections(det_results):
        """Split detector output into person boxes and other-object records."""
        person_boxes = []
        objects = []
        for result in det_results:
            for box in result.boxes:
                try:
                    cls = result.names[int(box.cls[0])]
                    box_coords = box.xyxy[0].cpu().numpy().tolist()
                    if cls == "person":
                        person_boxes.append(box_coords)
                    else:
                        objects.append(
                            {
                                "class": cls,
                                "confidence": float(box.conf[0]),
                                "box": box_coords,
                            }
                        )
                except Exception as e:
                    print(f"Detection processing error: {e}")
                    continue
        return person_boxes, objects

    @staticmethod
    def _collect_poses(pose_results):
        """Flatten pose-model output into a list of keypoint lists."""
        poses = []
        if not pose_results:
            return poses
        for result in pose_results:
            if not result.keypoints:
                continue
            for kpts in result.keypoints:
                try:
                    poses.append(kpts.data[0].cpu().numpy().tolist())
                except Exception as e:
                    print(f"Pose processing error: {e}")
                    continue
        return poses

    @staticmethod
    def _match_persons(person_boxes, current_poses, tracked_persons):
        """Build person records by pairing each box with a pose and a track ID."""
        persons = []
        for i, box in enumerate(person_boxes):
            try:
                # NOTE(review): boxes and poses are paired purely by index,
                # which assumes both models list people in the same order --
                # confirm, or match on box overlap instead.
                if i >= len(current_poses) or current_poses[i] is None:
                    continue
                pose = current_poses[i]

                # Find the person ID whose tracked box equals this box.
                person_id = next(
                    (
                        pid
                        for pid, tracked_box in tracked_persons.items()
                        if np.array_equal(box, tracked_box)
                    ),
                    None,
                )
                if person_id is None:
                    continue

                persons.append(
                    {
                        "person_idx": i,
                        "person_id": person_id,
                        "box": box,
                        "center": [(box[0] + box[2]) / 2, (box[1] + box[3]) / 2],
                        "keypoints": pose,
                    }
                )
            except Exception as e:
                print(f"Skipping person {i} due to error: {e}")
                continue
        return persons

    def reset(self):
        """Reset state for a new video."""
        self.person_tracker.reset()
        self.prev_poses = None
|