#!/usr/bin/env python3
"""
Combine sensor-only NN predictions with a transition matrix at inference time.

    P(y | x, prev) ∝ P_nn(y | x)^α × P_trans(y | prev)^β

α and β are tuned on the validation set (weighted F1) and the tuned
combination is reported on the test set, next to NN-only, transition-only and
an additive mixture w·P_nn + (1 − w)·P_trans baseline.
"""
import os
import sys
import json
import re
import numpy as np
import torch
import torch.nn as nn
from collections import Counter
from sklearn.metrics import accuracy_score, f1_score

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.dataset import DATASET_DIR, TRAIN_VOLS, VAL_VOLS, TEST_VOLS
import tasks.train_pred_cls as train_pred_cls
from tasks.train_pred_cls import (
    ActionPredDataset, TransformerClassifier, ACTION_CLASSES_COARSE, init_classes
)

# Initialize global classes.
init_classes(coarse=True)
# Read the class list through the module *after* init_classes(): a
# `from ... import` name is bound at import time and would miss a rebinding of
# ACTION_CLASSES_COARSE inside init_classes (if it does one -- not visible
# here). If init_classes updates the list in place, both refer to the same object.
COARSE_CLASSES = train_pred_cls.ACTION_CLASSES_COARSE

ANNOTATION_DIR = "${PULSE_ROOT}"

# Candidate exponents (alpha, beta) for the multiplicative combination.
_EXPONENT_GRID = (0.0, 0.3, 0.5, 0.7, 1.0, 1.5, 2.0)
# Candidate NN weights w for the additive combination.
_MIX_WEIGHT_GRID = (0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)


def get_predictions(model, dataset, device):
    """Run ``model`` over ``dataset`` and collect softmax outputs.

    The model is called as ``model(features, mask)`` without any
    previous-action input, i.e. sensor-only.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        dataset: dataset whose items provide 'features', 'mask', 'label'
            and 'prev_label'.
        device: torch device to run inference on.

    Returns:
        ``(probs, labels, prev_labels)``: ``probs`` has shape
        (N, num_classes); ``labels`` and ``prev_labels`` are length-N arrays.
    """
    model.eval()
    loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=False)
    all_probs, all_labels, all_prev = [], [], []
    with torch.no_grad():
        for batch in loader:
            features = batch['features'].to(device)
            mask = batch['mask'].to(device)
            logits = model(features, mask)  # sensor-only: no prev_action
            all_probs.append(torch.softmax(logits, dim=1).cpu().numpy())
            # Convert each batch at once instead of collecting 0-d tensors.
            all_labels.append(torch.as_tensor(batch['label']).cpu().numpy())
            all_prev.append(torch.as_tensor(batch['prev_label']).cpu().numpy())
    return (np.concatenate(all_probs),
            np.concatenate(all_labels),
            np.concatenate(all_prev))


def compute_transition_matrix(dataset, num_classes):
    """Estimate P(current | prev) from the (prev_label, label) pairs of ``dataset``.

    Returns:
        A (num_classes, num_classes) array whose row ``p`` is the empirical
        distribution of ``label`` given ``prev_label == p``. Rows of previous
        classes that never occur stay all-zero.
    """
    counts = np.zeros((num_classes, num_classes))
    # Index-based access: only __len__/__getitem__ are relied upon.
    for i in range(len(dataset)):
        sample = dataset[i]
        # NOTE(review): a negative prev_label (e.g. a "no previous action"
        # sentinel) would silently wrap to the last row -- confirm the dataset
        # never emits one.
        counts[sample['prev_label'], sample['label']] += 1
    row_sums = counts.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0] = 1  # avoid 0/0 for unseen previous classes
    return counts / row_sums


def combined_predict(nn_probs, trans_matrix, prev_labels, alpha, beta):
    """Predict classes from the product P_nn(y|x)^alpha * P_trans(y|prev)^beta.

    Args:
        nn_probs: (N, C) softmax outputs of the sensor-only model.
        trans_matrix: (C, C) matrix from ``compute_transition_matrix``.
        prev_labels: length-N integer indices of the previous action.
        alpha: exponent weighting the NN term.
        beta: exponent weighting the transition term.

    Returns:
        Length-N array of predicted class indices (row-wise argmax).
    """
    nn_probs = np.asarray(nn_probs)
    trans_probs = trans_matrix[np.asarray(prev_labels, dtype=np.intp)]  # (N, C)
    # The exponents act as per-source weights (log-linear pooling).
    combined = (nn_probs ** alpha) * (trans_probs ** beta)
    # A row can be entirely zero when the previous class never occurred in
    # training (all-zero transition row) or when tiny probabilities underflow.
    # Fall back to the transition prior when it has mass, otherwise to the NN
    # probabilities -- an all-zero row would always predict class 0.
    dead = combined.sum(axis=1) <= 0
    if dead.any():
        trans_has_mass = trans_probs.sum(axis=1) > 0
        fallback = np.where(trans_has_mass[:, None], trans_probs, nn_probs)
        combined[dead] = fallback[dead]
    return np.argmax(combined, axis=1)


def _additive_predict(nn_probs, trans_matrix, prev_labels, w):
    """Predict classes from the mixture w * P_nn(y|x) + (1 - w) * P_trans(y|prev)."""
    trans_probs = trans_matrix[np.asarray(prev_labels, dtype=np.intp)]
    return np.argmax(w * np.asarray(nn_probs) + (1 - w) * trans_probs, axis=1)


def _weighted_f1(labels, preds):
    """Weighted F1, scoring undefined per-class values as 0."""
    return f1_score(labels, preds, average='weighted', zero_division=0)


def _format_scores(labels, preds):
    """Render weighted F1 and accuracy of ``preds`` for the report."""
    f1w = _weighted_f1(labels, preds)
    acc = accuracy_score(labels, preds)
    return f"F1w={f1w:.3f}, Acc={acc:.3f}"


def _tune_multiplicative(val_probs, trans_matrix, val_prev, val_labels):
    """Grid-search (alpha, beta) maximizing validation weighted F1.

    (0, 0) is skipped because it ignores both sources. On ties the first pair
    in grid order wins.
    """
    best_f1, best_params = -1, (1.0, 1.0)
    for alpha in _EXPONENT_GRID:
        for beta in _EXPONENT_GRID:
            if alpha == 0 and beta == 0:
                continue
            preds = combined_predict(val_probs, trans_matrix, val_prev, alpha, beta)
            f1w = _weighted_f1(val_labels, preds)
            if f1w > best_f1:
                best_f1, best_params = f1w, (alpha, beta)
    return best_params


def _tune_additive(val_probs, trans_matrix, val_prev, val_labels):
    """Grid-search the NN weight w maximizing validation weighted F1 (first wins ties)."""
    best_f1, best_w = -1, 0.5
    for w in _MIX_WEIGHT_GRID:
        preds = _additive_predict(val_probs, trans_matrix, val_prev, w)
        f1w = _weighted_f1(val_labels, preds)
        if f1w > best_f1:
            best_f1, best_w = f1w, w
    return best_w


def _evaluate_model(results_dir, modalities, desc, base_dir, device):
    """Load one trained sensor-only model and print all combination results.

    Prints a skip message and returns when the model directory or its
    ``results.json`` is missing.
    """
    mod_str = modalities.replace(',', '-')
    # Pattern: <base_dir>/<results_dir>/recog_cls_coarse_{mod_str}
    model_dir = os.path.join(base_dir, results_dir, f'recog_cls_coarse_{mod_str}')
    if not os.path.exists(model_dir):
        print(f" Skip {desc}: {model_dir} not found")
        return
    results_file = os.path.join(model_dir, 'results.json')
    if not os.path.exists(results_file):
        print(f" Skip {desc}: {results_file} not found")
        return
    with open(results_file, encoding='utf-8') as f:
        args_dict = json.load(f)['args']

    # Recreate datasets; val/test reuse the stats computed on train.
    mods = modalities.split(',')
    ds_kwargs = dict(window_sec=args_dict['window_sec'],
                     downsample=args_dict['downsample'],
                     coarse=True, mode='recognition')
    train_ds = ActionPredDataset(TRAIN_VOLS, mods, **ds_kwargs)
    stats = train_ds.get_stats()
    val_ds = ActionPredDataset(VAL_VOLS, mods, stats=stats, **ds_kwargs)
    test_ds = ActionPredDataset(TEST_VOLS, mods, stats=stats, **ds_kwargs)
    num_classes = len(COARSE_CLASSES)

    # Sensor-only model (no prev_action input).
    # NOTE(review): nhead/num_layers are hard-coded and must match training.
    model = TransformerClassifier(
        train_ds.feat_dim, num_classes, d_model=args_dict['hidden_dim'],
        nhead=4, num_layers=2, dropout=args_dict['dropout'],
        use_prev_action=False,
    ).to(device)
    ckpt = torch.load(os.path.join(model_dir, 'model_best.pt'),
                      map_location=device, weights_only=True)
    model.load_state_dict(ckpt)

    val_probs, val_labels, val_prev = get_predictions(model, val_ds, device)
    test_probs, test_labels, test_prev = get_predictions(model, test_ds, device)

    # Transition prior estimated on train only.
    trans_matrix = compute_transition_matrix(train_ds, num_classes)

    # Baselines: NN only, transition only.
    nn_preds = np.argmax(test_probs, axis=1)
    trans_preds = np.argmax(trans_matrix[test_prev], axis=1)

    # Combinations, tuned on validation and evaluated on test.
    alpha, beta = _tune_multiplicative(val_probs, trans_matrix, val_prev, val_labels)
    comb_preds = combined_predict(test_probs, trans_matrix, test_prev, alpha, beta)
    best_w = _tune_additive(val_probs, trans_matrix, val_prev, val_labels)
    add_preds = _additive_predict(test_probs, trans_matrix, test_prev, best_w)

    print(f"\n{desc} ({mod_str}):")
    print(f" NN only: {_format_scores(test_labels, nn_preds)}")
    print(f" Trans only: {_format_scores(test_labels, trans_preds)}")
    print(f" Multiplicative (α={alpha:.1f}, β={beta:.1f}): "
          f"{_format_scores(test_labels, comb_preds)}")
    print(f" Additive (w={best_w:.1f}): {_format_scores(test_labels, add_preds)}")


def main():
    """Evaluate NN/transition combinations for each configured sensor-only model."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # (results_dir, modalities, description) -- sensor-only, no prev_action.
    models_info = [
        ('recog2a', 'imu', 'Recog: IMU'),
        ('recog2a', 'mocap,emg,eyetrack', 'Recog: MEE'),
        ('recog2a', 'mocap,emg,imu', 'Recog: MEI'),
        ('recog_coarse', 'imu', 'Recog10s: IMU'),
        ('recog_coarse', 'mocap,emg,imu', 'Recog10s: MEI'),
    ]
    base_dir = '${PULSE_ROOT}/results'
    for results_dir, modalities, desc in models_info:
        _evaluate_model(results_dir, modalities, desc, base_dir, device)


if __name__ == '__main__':
    main()