|
|
import torch |
|
|
from config import DETECT_MODEL, POSE_MODEL, CONF_THRESHOLD |
|
|
from utils.gpu import GPUConfigurator |
|
|
from preprocessing.preprocessor import FramePreprocessor |
|
|
from data_extraction.interaction_analyzer import InteractionAnalyzer |
|
|
from data_extraction.person_tracker import PersonTracker |
|
|
from utils.visualizer import Visualizer |
|
|
import numpy as np |
|
|
from ultralytics import YOLO |
|
|
|
|
|
|
|
|
class VideoFeatureExtractor:
    """Runs YOLO detection + pose estimation on video frames and packages
    per-frame features: persons (with tracked IDs and keypoints), other
    objects, person-person interactions, and frame-to-frame motion."""

    def __init__(self):
        self.gpu_config = GPUConfigurator()
        self.device = self.gpu_config.device

        # Both models are moved to the configured device up front.
        self.detection_model = YOLO(DETECT_MODEL).to(self.device)
        self.pose_model = YOLO(POSE_MODEL).to(self.device)

        self.preprocessor = FramePreprocessor()
        self.interaction_analyzer = InteractionAnalyzer()
        self.person_tracker = PersonTracker()
        self.visualizer = Visualizer()

        self.conf_threshold = CONF_THRESHOLD
        # Poses from the previous frame, consumed by the motion-feature step.
        self.prev_poses = None
        # FIX: __init__ previously duplicated reset()'s body verbatim;
        # delegate instead so the two cannot drift apart.
        self.reset()

    def extract_features(self, frame, frame_idx):
        """Extract features from a frame.

        Args:
            frame: raw frame as delivered by the video reader (numpy array,
                consumed by the preprocessor and the visualizer).
            frame_idx: 0-based frame index; also drives the timestamp
                (assumes 30 fps -- TODO confirm against the actual source).

        Returns:
            (frame_data, annotated_frame) on success, or (None, frame) when
            preprocessing returns nothing or the whole pipeline fails.
        """
        try:
            processed_frame, scale_info = self.preprocessor.preprocess_frame(frame)
            if processed_frame is None:
                return None, frame

            # HWC numpy -> NCHW batch-of-1 tensor on the inference device.
            frame_tensor = (
                torch.from_numpy(processed_frame)
                .permute(2, 0, 1)
                .unsqueeze(0)
                .to(self.device)
            )

            # Periodically release cached CUDA memory (guarded: no-op hosts
            # without CUDA never touch the CUDA runtime).
            if frame_idx % 5 == 0 and torch.cuda.is_available():
                torch.cuda.empty_cache()

            # FIX: autocast previously hardcoded device_type="cuda", which
            # breaks when GPUConfigurator selects a CPU device. Derive the
            # device type from self.device instead ("cuda:0" -> "cuda").
            device_type = str(self.device).split(":")[0]
            with (
                torch.no_grad(),
                torch.amp.autocast(device_type=device_type, dtype=torch.float16),
            ):
                det_results = self.detection_model(
                    frame_tensor, conf=self.conf_threshold, verbose=False
                )
                # Skip pose inference entirely when nothing was detected.
                pose_results = (
                    self.pose_model(
                        frame_tensor, conf=self.conf_threshold, verbose=False
                    )
                    if len(det_results[0].boxes) > 0
                    else []
                )

            frame_data = {
                "frame_index": frame_idx,
                "timestamp": frame_idx / 30,  # assumes 30 fps -- see docstring
                "persons": [],
                "objects": [],
                "interactions": [],
                "resized_width": scale_info.get("resized_size", (0, 0))[1],
                "resized_height": scale_info.get("resized_size", (0, 0))[0],
            }

            # --- Split detections into person boxes vs. other objects ---
            person_boxes = []
            for result in det_results:
                for box in result.boxes:
                    try:
                        cls = result.names[int(box.cls[0])]
                        box_coords = box.xyxy[0].cpu().numpy().tolist()
                        if cls == "person":
                            person_boxes.append(box_coords)
                        else:
                            frame_data["objects"].append(
                                {
                                    "class": cls,
                                    "confidence": float(box.conf[0]),
                                    "box": box_coords,
                                }
                            )
                    except Exception as e:
                        print(f"Detection processing error: {e}")
                        continue

            # Stable cross-frame IDs for each person box.
            tracked_persons = self.person_tracker.assign_person_ids(person_boxes)

            # --- Collect per-person keypoint lists ---
            current_poses = []
            if pose_results:
                for result in pose_results:
                    if result.keypoints:
                        for kpts in result.keypoints:
                            try:
                                pose_data = kpts.data[0].cpu().numpy().tolist()
                                current_poses.append(pose_data)
                            except Exception as e:
                                print(f"Pose processing error: {e}")
                                continue

            # --- Pair boxes with poses (positional) and tracked IDs ---
            # NOTE(review): boxes and poses are matched purely by index;
            # this presumes detector and pose model emit the same ordering.
            frame_data["persons"] = []
            for i, box in enumerate(person_boxes):
                try:
                    pose = current_poses[i] if i < len(current_poses) else None
                    if pose is None:
                        continue

                    # Recover the tracker-assigned ID by exact box match.
                    person_id = None
                    for pid, tracked_box in tracked_persons.items():
                        if np.array_equal(box, tracked_box):
                            person_id = pid
                            break

                    if person_id is None:
                        continue

                    frame_data["persons"].append(
                        {
                            "person_idx": i,
                            "person_id": person_id,
                            "box": box,
                            "center": [(box[0] + box[2]) / 2, (box[1] + box[3]) / 2],
                            "keypoints": pose,
                        }
                    )

                except Exception as e:
                    print(f"Skipping person {i} due to error: {e}")
                    continue

            # --- Frame-to-frame motion features (best effort) ---
            motion_features = {
                "average_speed": 0,
                "motion_intensity": 0,
                "sudden_movements": 0,
            }

            if self.prev_poses and current_poses:
                try:
                    motion_features = (
                        self.interaction_analyzer.calculate_motion_features(
                            self.prev_poses, current_poses
                        )
                    )
                except Exception as e:
                    print(f"Motion calculation error: {e}")

            frame_data["motion_features"] = motion_features
            self.prev_poses = current_poses

            # --- Person-person interactions ---
            frame_data["interactions"] = (
                self.interaction_analyzer.calculate_interactions(
                    person_boxes, current_poses, tracked_persons
                )
            )

            # Draw boxes/poses/IDs back onto the ORIGINAL frame for output.
            annotated_frame = self.visualizer.draw_detections(
                frame, det_results, pose_results, scale_info, tracked_persons
            )

            return frame_data, annotated_frame

        except Exception as e:
            # Top-level per-frame boundary: never let one bad frame kill the run.
            print(f"Frame {frame_idx} failed completely: {e}")
            return None, frame

    def reset(self):
        """Reset state for a new video."""
        self.person_tracker.reset()
        self.prev_poses = None
|
|
|