Chordia / src /utils /stress_calculator.py
Corolin's picture
first commit
0a6452f
"""
压力计算器
Stress Calculator based on PAD changes
压力通过 PAD 状态变化动态计算,而不是作为模型输出。
理论依据:
- Stress_Change ∝ α·(-ΔP) + β·(ΔA) + γ·(-ΔD)
- 不开心(-P)、紧张(+A)、失去掌控(-D)都会增加压力
"""
import torch
import numpy as np
from typing import Union, Tuple
from loguru import logger
class StressCalculator:
"""
压力计算器
根据 PAD 状态变化量计算压力变化量
"""
def __init__(
self,
alpha: float = 1.0, # Pleasure 权重(不开心增加压力)
beta: float = 0.8, # Arousal 权重(紧张增加压力)
gamma: float = 0.6 # Dominance 权重(失去掌控增加压力)
):
"""
初始化压力计算器
Args:
alpha: Pleasure 维度的权重系数
beta: Arousal 维度的权重系数
gamma: Dominance 维度的权重系数
"""
self.alpha = alpha
self.beta = beta
self.gamma = gamma
logger.info(f"压力计算器初始化: α={alpha}, β={beta}, γ={gamma}")
def compute_stress_change(
self,
delta_pad: Union[torch.Tensor, np.ndarray]
) -> Union[torch.Tensor, np.ndarray]:
"""
根据 PAD 变化量计算压力变化量
公式:Stress_Change = α·(-ΔP) + β·(ΔA) + γ·(-ΔD)
Args:
delta_pad: PAD 变化量,形状为 (batch_size, 3) 或 (3,)
顺序为 [ΔP, ΔA, ΔD]
Returns:
压力变化量,形状为 (batch_size, 1) 或 (1,)
"""
is_tensor = isinstance(delta_pad, torch.Tensor)
if is_tensor:
# PyTorch 张量计算
if delta_pad.dim() == 1:
delta_pad = delta_pad.unsqueeze(0)
delta_p, delta_a, delta_d = delta_pad[:, 0], delta_pad[:, 1], delta_pad[:, 2]
# 压力计算公式
stress_change = self.alpha * (-delta_p) + self.beta * (delta_a) + self.gamma * (-delta_d)
return stress_change.unsqueeze(1) # (batch_size, 1)
else:
# NumPy 数组计算
if delta_pad.ndim == 1:
delta_pad = delta_pad.reshape(1, -1)
delta_p, delta_a, delta_d = delta_pad[:, 0], delta_pad[:, 1], delta_pad[:, 2]
# 压力计算公式
stress_change = self.alpha * (-delta_p) + self.beta * (delta_a) + self.gamma * (-delta_d)
return stress_change.reshape(-1, 1) # (batch_size, 1)
def compute_stress_from_base_stress(
self,
delta_pad: Union[torch.Tensor, np.ndarray],
base_stress: Union[torch.Tensor, np.ndarray]
) -> Union[torch.Tensor, np.ndarray]:
"""
从基准压力和变化量计算新压力
Args:
delta_pad: PAD 变化量,形状为 (batch_size, 3)
base_stress: 基准压力,形状为 (batch_size, 1) 或 (1,)
Returns:
新压力,形状与 base_stress 相同
"""
stress_change = self.compute_stress_change(delta_pad)
if isinstance(stress_change, torch.Tensor):
if isinstance(base_stress, torch.Tensor):
if base_stress.dim() == 1:
base_stress = base_stress.unsqueeze(1)
return base_stress + stress_change
else:
if isinstance(base_stress, np.ndarray):
if base_stress.ndim == 1:
base_stress = base_stress.reshape(-1, 1)
return base_stress + stress_change
def get_formula_description(self) -> str:
"""返回压力计算公式描述"""
return (
f"Stress_Change = {self.alpha}·(-ΔP) + {self.beta}·(ΔA) + {self.gamma}·(-ΔD)\n"
f"解释:不开心(-ΔP)、紧张(+ΔA)、失去掌控(-ΔD)都会增加压力"
)
# 创建默认的压力计算器实例
_default_calculator = None
def get_default_calculator() -> StressCalculator:
"""获取默认的压力计算器(单例模式)"""
global _default_calculator
if _default_calculator is None:
_default_calculator = StressCalculator(
alpha=1.0,
beta=0.8,
gamma=0.6
)
return _default_calculator
def compute_stress_from_pad_change(
delta_pad: Union[torch.Tensor, np.ndarray],
alpha: float = 1.0,
beta: float = 0.8,
gamma: float = 0.6
) -> Union[torch.Tensor, np.ndarray]:
"""
便捷函数:从 PAD 变化量计算压力变化量
Args:
delta_pad: PAD 变化量,形状为 (batch_size, 3)
alpha: Pleasure 权重
beta: Arousal 权重
gamma: Dominance 权重
Returns:
压力变化量,形状为 (batch_size, 1)
"""
calculator = StressCalculator(alpha, beta, gamma)
return calculator.compute_stress_change(delta_pad)
def compute_stress(
current_pad: Union[torch.Tensor, np.ndarray],
base_pad: Union[torch.Tensor, np.ndarray],
alpha: float = 1.0,
beta: float = 0.8,
gamma: float = 0.6
) -> Union[torch.Tensor, np.ndarray]:
"""
便捷函数:从当前 PAD 和基准 PAD 计算压力
Args:
current_pad: 当前 PAD,形状为 (batch_size, 3)
base_pad: 基准 PAD,形状为 (batch_size, 3)
alpha: Pleasure 权重
beta: Arousal 权重
gamma: Dominance 权重
Returns:
压力值,形状为 (batch_size, 1)
"""
delta_pad = current_pad - base_pad
return compute_stress_from_pad_change(delta_pad, alpha, beta, gamma)
if __name__ == "__main__":
# 测试代码
print("压力计算器测试")
print("=" * 60)
# 创建测试数据
delta_pad = np.array([
[0.1, 0.2, -0.1], # 开心、稍微紧张、失去一些掌控 → 压力应该减少
[-0.2, 0.3, -0.2], # 不开心、紧张、失去掌控 → 压力应该增加
[0.0, 0.0, 0.0], # 没有变化 → 压力不变
])
calculator = StressCalculator(alpha=1.0, beta=0.8, gamma=0.6)
print("公式:", calculator.get_formula_description())
print("\n测试数据:")
print(delta_pad)
stress_change = calculator.compute_stress_change(delta_pad)
print("\n计算出的压力变化量:")
print(stress_change)
print("\n解释:")
for i, (dp, da, dd, ds) in enumerate(delta_pad):
print(f" 样本{i+1}: ΔP={dp:.2f}, ΔA={da:.2f}, ΔD={dd:.2f} → ΔStress={ds[0]:.4f}")