from itertools import groupby
from typing import List, Tuple

import numpy as np
import torch


class PoseClassifier:
    """Classify relative poses as forward, backward, left turn or right turn.

    Turning is decided from the yaw angle extracted from the rotation
    quaternion; forward/backward is decided from the x component of the
    translation, where a negative x displacement counts as forward movement.
    A turn takes priority over translation.

    Attributes:
        FORWARD, BACKWARD, LEFT_TURN, RIGHT_TURN: integer class ids (0..3).
        class_names: list mapping a class id to its name.
        yaw_threshold: |yaw| (radians) above which a pose counts as a turn.
        movement_threshold: |forward movement| above which the sign of the
            movement decides between forward and backward.
    """

    FORWARD = 0
    BACKWARD = 1
    LEFT_TURN = 2
    RIGHT_TURN = 3

    # Class-level defaults; __init__ may override them per instance.
    yaw_threshold = 0.05
    movement_threshold = 0.01

    def __init__(self, yaw_threshold: float = 0.05,
                 movement_threshold: float = 0.01):
        """Create a classifier.

        Args:
            yaw_threshold: turn detection threshold in radians.
            movement_threshold: forward/backward detection threshold, in the
                same unit as the translation.
        """
        self.class_names = ['forward', 'backward', 'left_turn', 'right_turn']
        self.yaw_threshold = yaw_threshold
        self.movement_threshold = movement_threshold

    def classify_pose_sequence(self, pose_sequence: torch.Tensor) -> torch.Tensor:
        """Classify every pose of a sequence.

        Args:
            pose_sequence: [num_frames, 7] tensor holding, per frame, the
                relative translation (3 values) followed by the relative
                quaternion (4 values, [w, x, y, z]). All poses are expected
                to be expressed relative to the reference frame.

        Returns:
            classifications: [num_frames] int64 tensor of class ids.
        """
        translations = pose_sequence[:, :3]
        rotations = pose_sequence[:, 3:7]

        labels = [
            self._classify_single_pose(translation, rotation)
            for translation, rotation in zip(translations, rotations)
        ]
        return torch.tensor(labels, dtype=torch.long)

    def _classify_single_pose(self, relative_translation: torch.Tensor,
                              relative_rotation: torch.Tensor) -> int:
        """Classify a single pose relative to the reference frame.

        Args:
            relative_translation: [3] translation relative to the reference.
            relative_rotation: [4] rotation quaternion [w, x, y, z] relative
                to the reference.

        Returns:
            One of FORWARD, BACKWARD, LEFT_TURN, RIGHT_TURN.
        """
        yaw_angle = self._quaternion_to_yaw(relative_rotation)
        # Negative x displacement is treated as moving forward.
        forward_movement = -relative_translation[0].item()

        # A significant yaw wins over any translation.
        if abs(yaw_angle) > self.yaw_threshold:
            return self.LEFT_TURN if yaw_angle > 0 else self.RIGHT_TURN

        # Only a clearly negative forward movement is "backward"; small or
        # zero movement falls back to FORWARD.
        if forward_movement < -self.movement_threshold:
            return self.BACKWARD
        return self.FORWARD

    def _quaternion_to_yaw(self, q: torch.Tensor) -> float:
        """Extract the yaw angle (rotation about the z axis) from a quaternion.

        Args:
            q: [4] quaternion [w, x, y, z]; it is normalized before use.

        Returns:
            yaw: yaw angle in radians. 0.0 is returned for a degenerate
            (near-zero or NaN norm) quaternion and when the computation fails.
        """
        # Deliberately best-effort: a malformed quaternion is reported and
        # treated as "no rotation" rather than aborting the classification.
        try:
            q_np = q.detach().cpu().numpy()

            norm = np.linalg.norm(q_np)
            if not norm > 1e-8:
                return 0.0

            w, x, y, z = q_np / norm
            yaw = np.arctan2(2.0 * (w * z + x * y), 1.0 - 2.0 * (y * y + z * z))
            return float(yaw)
        except Exception as e:
            print(f"Error computing yaw from quaternion: {e}")
            return 0.0

    def create_class_embedding(self, class_labels: torch.Tensor, embed_dim: int = 512,
                               seed: int = 0) -> torch.Tensor:
        """Create an embedding for each class label.

        The first four embedding dimensions hold a fixed direction code
        (forward/backward are +/- on axis 0, left/right are +/- on axis 1).
        For ``embed_dim > 4`` the remaining dimensions are filled through a
        small random projection. The projection is drawn from a locally
        seeded generator, so the same label always maps to the same embedding
        and the global RNG state is left untouched.

        Args:
            class_labels: [num_frames] class ids.
            embed_dim: embedding dimension.
            seed: seed of the random projection used when ``embed_dim > 4``.

        Returns:
            embeddings: [num_frames, embed_dim] float tensor on the device of
            ``class_labels``.
        """
        device = class_labels.device
        # scatter_ requires an int64 index.
        labels = class_labels.to(torch.long)

        direction_vectors = torch.tensor([
            [1.0, 0.0, 0.0, 0.0],
            [-1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, -1.0, 0.0, 0.0],
        ], dtype=torch.float32, device=device)
        num_classes = direction_vectors.shape[0]

        one_hot = torch.zeros(len(labels), num_classes, device=device)
        one_hot.scatter_(1, labels.unsqueeze(1), 1)

        base_embeddings = one_hot @ direction_vectors

        if embed_dim <= 4:
            return base_embeddings[:, :embed_dim]

        generator = torch.Generator().manual_seed(seed)
        expand_matrix = torch.randn(4, embed_dim, generator=generator) * 0.1
        # Keep the direction code intact in the first four dimensions.
        expand_matrix[:, :4] = torch.eye(4)
        return base_embeddings @ expand_matrix.to(device)

    def get_class_name(self, class_id: int) -> str:
        """Return the name of the class with id ``class_id``."""
        return self.class_names[class_id]

    def _segment_runs(self, labels: list) -> list:
        """Group consecutive equal labels into motion segment dicts."""
        segments = []
        start = 0
        for label, run in groupby(labels):
            length = sum(1 for _ in run)
            segments.append({
                'class': self.get_class_name(label),
                'start_frame': start,
                'end_frame': start + length - 1,
                'duration': length,
            })
            start += length
        return segments

    def analyze_pose_sequence(self, pose_sequence: torch.Tensor) -> dict:
        """Analyze a pose sequence and return summary statistics.

        Args:
            pose_sequence: [num_frames, 7] (translation + quaternion).

        Returns:
            analysis: dict with the keys
                'total_frames': number of frames,
                'class_distribution': class name -> frame count,
                'motion_segments': runs of consecutive frames sharing a
                    class, each with 'class', 'start_frame', 'end_frame'
                    (inclusive) and 'duration',
                'total_distance': straight-line distance between the first
                    and the last translation (net displacement, not the
                    accumulated path length),
                'classifications': the per-frame class id tensor.
        """
        classifications = self.classify_pose_sequence(pose_sequence)

        class_counts = torch.bincount(classifications, minlength=4)
        motion_segments = self._segment_runs(classifications.tolist())

        translations = pose_sequence[:, :3]
        if len(translations) > 1:
            total_distance = torch.norm(translations[-1] - translations[0]).item()
        else:
            total_distance = 0.0

        return {
            'total_frames': len(pose_sequence),
            'class_distribution': {
                self.get_class_name(class_id): count
                for class_id, count in enumerate(class_counts.tolist())
            },
            'motion_segments': motion_segments,
            'total_distance': total_distance,
            'classifications': classifications,
        }

    def debug_single_pose(self, relative_translation: torch.Tensor,
                          relative_rotation: torch.Tensor) -> dict:
        """Expose the intermediate values behind a single pose classification.

        Args:
            relative_translation: [3] relative translation.
            relative_rotation: [4] relative rotation quaternion.

        Returns:
            debug_info: dict with the inputs, the extracted yaw (radians and
            degrees), the forward movement, the thresholds in effect, the
            resulting class and the individual decision flags.
        """
        yaw_angle = self._quaternion_to_yaw(relative_rotation)
        forward_movement = -relative_translation[0].item()

        yaw_threshold = self.yaw_threshold
        movement_threshold = self.movement_threshold

        classification = self._classify_single_pose(relative_translation, relative_rotation)

        if yaw_angle > 0:
            yaw_direction = 'left'
        elif yaw_angle < 0:
            yaw_direction = 'right'
        else:
            yaw_direction = 'none'

        if forward_movement > 0:
            movement_direction = 'forward'
        elif forward_movement < 0:
            movement_direction = 'backward'
        else:
            movement_direction = 'none'

        return {
            'relative_translation': relative_translation.tolist(),
            'relative_rotation': relative_rotation.tolist(),
            'yaw_angle_rad': yaw_angle,
            'yaw_angle_deg': float(np.degrees(yaw_angle)),
            'forward_movement': forward_movement,
            'yaw_threshold': yaw_threshold,
            'movement_threshold': movement_threshold,
            'classification': self.get_class_name(classification),
            'classification_id': classification,
            'decision_process': {
                'abs_yaw_exceeds_threshold': abs(yaw_angle) > yaw_threshold,
                'abs_movement_exceeds_threshold': abs(forward_movement) > movement_threshold,
                'yaw_direction': yaw_direction,
                'movement_direction': movement_direction,
            },
        }