| |
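"""
Combine sensor-only NN predictions with a label-transition matrix at inference time.

Two fusion rules are evaluated:
  multiplicative:  P(y|x,prev) ∝ P_nn(y|x)^α × P_trans(y|prev)^β
  additive:        P(y|x,prev) ∝ w·P_nn(y|x) + (1−w)·P_trans(y|prev)

The transition matrix is estimated on the training split; α, β and w are tuned
on the validation set (weighted F1) and the resulting scores are reported on the
test set.
"""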
|
|
| import os |
| import sys |
| import json |
| import re |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from collections import Counter |
| from sklearn.metrics import accuracy_score, f1_score |
|
|
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| from data.dataset import DATASET_DIR, TRAIN_VOLS, VAL_VOLS, TEST_VOLS |
| from tasks.train_pred_cls import ( |
| ActionPredDataset, TransformerClassifier, |
| ACTION_CLASSES_COARSE, init_classes |
| ) |
| |
import tasks.train_pred_cls as _train_pred_cls

# Populate the coarse label set, then read it back from the module. The name
# imported above was bound *before* init_classes ran; reading the module
# attribute afterwards is correct whether init_classes mutates the object in
# place or rebinds the module-level name.
init_classes(coarse=True)
COARSE_CLASSES = _train_pred_cls.ACTION_CLASSES_COARSE


# "${PULSE_ROOT}" is substituted from the environment when set; otherwise
# os.path.expandvars leaves the string unchanged.
# NOTE(review): not referenced elsewhere in the visible code.
ANNOTATION_DIR = os.path.expandvars("${PULSE_ROOT}")
|
|
|
|
def get_predictions(model, dataset, device):
    """Run ``model`` over ``dataset`` and collect softmax outputs and labels.

    Args:
        model: classifier called as ``model(features, mask)``; the softmax is
            taken over dim 1 of its output, so dim 1 is expected to index classes.
        dataset: map-style dataset whose items are dicts with at least
            ``'features'``, ``'mask'``, ``'label'`` and ``'prev_label'``.
        device: torch device that features and masks are moved to.

    Returns:
        ``(probs, labels, prev_labels)``: the concatenated softmax rows plus
        integer arrays of labels and previous labels, all in dataset order
        (the loader does not shuffle).

    Raises:
        ValueError: if ``dataset`` yields no samples.
    """
    model.eval()  # inference mode (e.g. disables dropout)
    loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=False)
    prob_chunks, label_chunks, prev_chunks = [], [], []
    with torch.no_grad():
        for batch in loader:
            features = batch['features'].to(device)
            mask = batch['mask'].to(device)
            logits = model(features, mask)
            prob_chunks.append(torch.softmax(logits, dim=1).cpu().numpy())
            # Convert each collated label batch once, instead of extending a
            # Python list with per-sample 0-d tensors.
            label_chunks.append(torch.as_tensor(batch['label']).cpu().numpy())
            prev_chunks.append(torch.as_tensor(batch['prev_label']).cpu().numpy())
    if not prob_chunks:
        raise ValueError("get_predictions: dataset is empty")
    return (
        np.concatenate(prob_chunks),
        np.concatenate(label_chunks).astype(np.int64),
        np.concatenate(prev_chunks).astype(np.int64),
    )
|
|
|
|
def compute_transition_matrix(dataset, num_classes, smoothing=0.0):
    """Estimate the label-transition matrix P(current | prev) from a dataset.

    Args:
        dataset: indexable dataset (``len`` + ``__getitem__``) whose items are
            dicts with integer-like ``'prev_label'`` and ``'label'`` entries.
        num_classes: number of classes C; labels index rows/columns directly.
        smoothing: additive (Laplace) pseudo-count added to every cell before
            normalising. The default 0.0 gives the raw maximum-likelihood estimate.

    Returns:
        A (C, C) float array whose row ``prev`` is the distribution of the
        current label given previous label ``prev``. A row whose previous label
        never occurs (and that receives no smoothing mass) is uniform rather
        than all-zero, so every row is a valid probability distribution.
    """
    counts = np.zeros((num_classes, num_classes))
    # Index explicitly: a map-style dataset is not guaranteed to stop iteration
    # cleanly with a plain ``for sample in dataset``.
    for i in range(len(dataset)):
        sample = dataset[i]
        counts[int(sample['prev_label']), int(sample['label'])] += 1
    counts += smoothing
    row_sums = counts.sum(axis=1, keepdims=True)
    # Previously these rows stayed all-zero, which made downstream argmax
    # silently pick class 0; a uniform row carries no transition preference.
    unseen = row_sums[:, 0] == 0
    counts[unseen] = 1.0
    row_sums[unseen] = num_classes
    return counts / row_sums
|
|
|
|
def combined_predict(nn_probs, trans_matrix, prev_labels, alpha, beta):
    """Predict classes from the multiplicative fusion of NN and transition scores.

    For sample i, the score of class y is
        nn_probs[i, y] ** alpha * trans_matrix[prev_labels[i], y] ** beta
    and the prediction is the argmax over y. Normalising the scores does not
    change the argmax, so it is skipped.

    If a sample's fused scores have no positive mass (e.g. the transition row
    is zero wherever the NN assigns probability), the transition row alone
    decides. If that row is itself all zero, the NN probabilities are used,
    rather than defaulting to class 0.

    Args:
        nn_probs: (N, C) array of NN class probabilities.
        trans_matrix: transition matrix whose row ``prev`` holds length-C
            probabilities of the current class given ``prev``.
        prev_labels: length-N sequence of integer previous labels.
        alpha: exponent applied to the NN probabilities.
        beta: exponent applied to the transition probabilities.

    Returns:
        Length-N integer array of predicted class indices.
    """
    nn_probs = np.asarray(nn_probs, dtype=float)
    prev_idx = np.asarray(prev_labels, dtype=np.intp)
    trans_probs = np.asarray(trans_matrix, dtype=float)[prev_idx]
    scores = nn_probs ** alpha * trans_probs ** beta
    # ``~(sum > 0)`` also flags NaN sums, matching the original ``p_sum > 0`` test.
    degenerate = ~(scores.sum(axis=1) > 0)
    scores[degenerate] = trans_probs[degenerate]
    no_trans_mass = degenerate & ~(trans_probs.sum(axis=1) > 0)
    scores[no_trans_mass] = nn_probs[no_trans_mass]
    return np.argmax(scores, axis=1)
|
|
|
|
def _weighted_f1(labels, preds):
    """Weighted F1 with ``zero_division=0``, the metric used for tuning and reporting."""
    return f1_score(labels, preds, average='weighted', zero_division=0)


def _additive_predict(nn_probs, trans_matrix, prev_labels, w):
    """Return the per-sample argmax of ``w * P_nn + (1 - w) * P_trans(. | prev)``."""
    trans_probs = np.asarray(trans_matrix)[np.asarray(prev_labels, dtype=np.intp)]
    return np.argmax(w * np.asarray(nn_probs) + (1 - w) * trans_probs, axis=1)


def _tune_multiplicative(val_probs, val_labels, val_prev, trans_matrix):
    """Grid-search (alpha, beta) for ``combined_predict`` by validation weighted F1.

    The (0, 0) pair is skipped; on ties the first pair in grid order is kept.
    """
    grid = [0.0, 0.3, 0.5, 0.7, 1.0, 1.5, 2.0]
    best_f1, best_params = -1, (1.0, 1.0)
    for alpha in grid:
        for beta in grid:
            if alpha == 0 and beta == 0:
                continue
            preds = combined_predict(val_probs, trans_matrix, val_prev, alpha, beta)
            f1w = _weighted_f1(val_labels, preds)
            if f1w > best_f1:
                best_f1, best_params = f1w, (alpha, beta)
    return best_params


def _tune_additive(val_probs, val_labels, val_prev, trans_matrix):
    """Grid-search the additive mixing weight w by validation weighted F1 (first wins ties)."""
    best_f1, best_w = -1, 0.5
    for w in [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]:
        f1w = _weighted_f1(val_labels, _additive_predict(val_probs, trans_matrix, val_prev, w))
        if f1w > best_f1:
            best_f1, best_w = f1w, w
    return best_w


def main():
    """Compare NN-only, transition-only and fused predictions for each configured run.

    For every (results_dir, modalities) entry that has a saved ``results.json``
    and ``model_best.pt``, the script:
      1. rebuilds the train/val/test datasets and the model;
      2. estimates the transition matrix on the training split;
      3. tunes the fusion hyper-parameters on validation;
      4. prints test-set weighted F1 for each method.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # (results sub-directory, comma-separated modalities, label used in output)
    models_info = [
        ('recog2a', 'imu', 'Recog: IMU'),
        ('recog2a', 'mocap,emg,eyetrack', 'Recog: MEE'),
        ('recog2a', 'mocap,emg,imu', 'Recog: MEI'),
        ('recog_coarse', 'imu', 'Recog10s: IMU'),
        ('recog_coarse', 'mocap,emg,imu', 'Recog10s: MEI'),
    ]

    # The bare "${PULSE_ROOT}" literal was never expanded. expandvars
    # substitutes the environment variable when it is set and leaves the
    # string unchanged otherwise.
    base_dir = os.path.expandvars('${PULSE_ROOT}/results')
    num_classes = len(COARSE_CLASSES)

    for results_dir, modalities, desc in models_info:
        mod_str = modalities.replace(',', '-')

        model_dir = os.path.join(base_dir, results_dir, f'recog_cls_coarse_{mod_str}')
        if not os.path.exists(model_dir):
            print(f"  Skip {desc}: {model_dir} not found")
            continue
        results_file = os.path.join(model_dir, 'results.json')
        ckpt_file = os.path.join(model_dir, 'model_best.pt')
        missing = [p for p in (results_file, ckpt_file) if not os.path.exists(p)]
        if missing:
            print(f"  Skip {desc}: {missing[0]} not found")
            continue

        with open(results_file) as f:
            args_dict = json.load(f)['args']

        # Rebuild the splits; val/test reuse the stats computed on the training split.
        mods = modalities.split(',')
        ds_kwargs = dict(
            window_sec=args_dict['window_sec'],
            downsample=args_dict['downsample'],
            coarse=True,
            mode='recognition',
        )
        train_ds = ActionPredDataset(TRAIN_VOLS, mods, **ds_kwargs)
        stats = train_ds.get_stats()
        val_ds = ActionPredDataset(VAL_VOLS, mods, stats=stats, **ds_kwargs)
        test_ds = ActionPredDataset(TEST_VOLS, mods, stats=stats, **ds_kwargs)

        # nhead / num_layers fall back to the previously hard-coded 4 / 2 when the
        # run's args do not record them. NOTE(review): key names assumed.
        model = TransformerClassifier(
            train_ds.feat_dim, num_classes,
            d_model=args_dict['hidden_dim'],
            nhead=args_dict.get('nhead', 4),
            num_layers=args_dict.get('num_layers', 2),
            dropout=args_dict['dropout'],
            use_prev_action=False,
        ).to(device)
        model.load_state_dict(torch.load(ckpt_file, map_location=device, weights_only=True))

        val_probs, val_labels, val_prev = get_predictions(model, val_ds, device)
        test_probs, test_labels, test_prev = get_predictions(model, test_ds, device)

        # Transition statistics come from the training split only.
        trans_matrix = compute_transition_matrix(train_ds, num_classes)
        # NOTE(review): `prev_label` is read from the dataset, presumably the
        # ground-truth previous action. If so, every transition-based score
        # below assumes an oracle previous label. Confirm before comparing to
        # a deployable pipeline.

        nn_f1w = _weighted_f1(test_labels, np.argmax(test_probs, axis=1))
        trans_preds = np.argmax(trans_matrix[np.asarray(test_prev, dtype=np.intp)], axis=1)
        trans_f1w = _weighted_f1(test_labels, trans_preds)

        alpha, beta = _tune_multiplicative(val_probs, val_labels, val_prev, trans_matrix)
        combined_preds = combined_predict(test_probs, trans_matrix, test_prev, alpha, beta)
        comb_f1w = _weighted_f1(test_labels, combined_preds)
        comb_acc = accuracy_score(test_labels, combined_preds)

        best_w = _tune_additive(val_probs, val_labels, val_prev, trans_matrix)
        add_preds = _additive_predict(test_probs, trans_matrix, test_prev, best_w)
        add_f1w = _weighted_f1(test_labels, add_preds)

        print(f"\n{desc} ({mod_str}):")
        print(f"  NN only: F1w={nn_f1w:.3f}")
        print(f"  Trans only: F1w={trans_f1w:.3f}")
        print(f"  Multiplicative (α={alpha:.1f}, β={beta:.1f}): F1w={comb_f1w:.3f}, Acc={comb_acc:.3f}")
        print(f"  Additive (w={best_w:.1f}): F1w={add_f1w:.3f}")
|
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()
|
|