| |
| """ |
| Action Prediction via Verb-Category Classification. |
| |
| Instead of generating free-form text (which fails with ~2000 unique labels / ~1600 samples), |
| we classify the next action into ~20 verb categories extracted from text annotations. |
| |
| Architecture: Transformer encoder (proven in exp1 with F1=0.771 on scene recognition). |
| """ |
|
|
| import os |
| import sys |
| import json |
| import time |
| import math |
| import re |
| import random |
| import argparse |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
| from sklearn.metrics import accuracy_score, f1_score, classification_report |
|
|
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| from data.dataset import ( |
| DATASET_DIR, MODALITY_FILES, TRAIN_VOLS, VAL_VOLS, TEST_VOLS, |
| load_modality_array, |
| ) |
|
|
# Root directory of the annotation JSON files; they are looked up as
# <ANNOTATION_DIR>/<volunteer>/<scenario>.json.  The "${PULSE_ROOT}"
# placeholder is expanded from the environment when that variable is set;
# os.path.expandvars leaves it untouched otherwise, so the previous literal
# value is preserved in that case.
ANNOTATION_DIR = os.path.expandvars("${PULSE_ROOT}")
|
|
|
|
| |
| |
| |
|
|
# Ordered (regex pattern, fine label) pairs.  text_to_action_class() tries
# them top to bottom with re.search and returns the label of the FIRST
# pattern that matches, so the order of the entries is significant: specific
# verbs come first, generic "将 ... <verb>" constructions and catch-alls last.
VERB_MAP_RULES = [
    # -> 抓取 (grasp / pick up / take out)
    ('抓取', '抓取'), ('拿起', '抓取'), ('拿出', '抓取'),
    ('从.*取出', '抓取'), ('从.*抓取', '抓取'), ('从.*提取', '抓取'),
    ('从.*取下', '抓取'), ('从.*抽出', '抓取'), ('从.*拔出', '抓取'),
    ('双手抓', '抓取'), ('双手协.*抓', '抓取'), ('分别抓', '抓取'),
    ('伸手', '抓取'),
    # -> 放置 (place / put back / discard)
    ('放置', '放置'), ('放回', '放置'), ('放入', '放置'),
    ('丢弃', '放置'), ('归还', '放置'),
    # -> 移动 (move / carry / push / pull / hand over)
    ('移动', '移动'), ('搬运', '移动'), ('移开', '移动'),
    ('推入', '移动'), ('推动', '移动'), ('拉开', '移动'), ('拉出', '移动'),
    ('搬移', '移动'), ('转移', '移动'), ('递送', '移动'),
    ('交接', '移动'), ('传递', '移动'), ('滑动', '移动'),
    ('分别持握.*移', '移动'),
    # -> 调整 (adjust / align / reposition)
    ('调整', '调整'), ('对齐', '调整'), ('微调', '调整'),
    ('重新', '调整'), ('摆正', '调整'), ('归位', '调整'),
    # -> 折叠 (fold)
    ('折叠', '折叠'), ('二次折叠', '折叠'), ('对折', '折叠'),
    # -> 展开 (unfold / open / tear open)
    ('展开', '展开'), ('打开', '展开'), ('揭开', '展开'),
    ('拆开', '展开'), ('撕开', '展开'), ('掀开', '展开'),
    # -> 擦拭 (wipe / smooth / clean)
    ('擦拭', '擦拭'), ('抚平', '擦拭'), ('清洁', '擦拭'), ('清理', '擦拭'),
    # -> 旋转 (rotate / screw / unscrew)
    ('旋转', '旋转'), ('旋紧', '旋转'), ('旋开', '旋转'),
    ('拧开', '旋转'), ('拧紧', '旋转'),
    # -> 提起 (lift / raise)
    ('提起', '提起'), ('抬起', '提起'), ('举起', '提起'), ('翻起', '提起'),
    # -> 倾倒 (pour / fill / scoop)
    ('倾倒', '倾倒'), ('装填', '倾倒'), ('倒入', '倾倒'), ('倒出', '倾倒'),
    ('舀取', '倾倒'), ('注入', '倾倒'), ('从.*舀', '倾倒'),
    # -> 整理 (tidy / stack / arrange)
    ('整理', '整理'), ('堆叠', '整理'), ('排列', '整理'),
    ('收纳', '整理'), ('码放', '整理'),
    # -> 检查 (inspect / confirm / hold / observe)
    ('检查', '检查'), ('确认', '检查'), ('查看', '检查'),
    ('保持', '检查'), ('观察', '检查'),
    # -> 按压 (press / compact / flatten)
    ('按压', '按压'), ('压实', '按压'), ('压平', '按压'),
    # -> 盖合 (cover / close / seal)
    ('盖上', '盖合'), ('关闭', '盖合'), ('密封', '盖合'), ('合上', '盖合'),
    ('封口', '盖合'), ('封箱', '盖合'),
    # -> 分离 (separate)
    ('分离', '分离'), ('分开', '分离'),
    # -> 粘贴 (stick / fasten / reinforce)
    ('粘贴', '粘贴'), ('固定', '粘贴'), ('贴上', '粘贴'), ('加固', '粘贴'),
    # -> 释放 (release)
    ('释放', '释放'),
    # -> 操作 (use / operate / stir / cut)
    ('使用', '操作'), ('操作', '操作'), ('搅拌', '操作'),
    ('切割', '操作'), ('切断', '操作'), ('剪断', '操作'), ('修剪', '操作'),
    # -> 翻转 (flip / turn over)
    ('翻转', '翻转'), ('翻面', '翻转'),
    # -> 其他 (other): preparation / completion phrases
    ('准备', '其他'), ('完成', '其他'), ('最终', '其他'),
    # Generic "将 ... <verb character>" constructions, reached only when no
    # specific verb above matched.
    ('将.*放', '放置'), ('将.*装', '倾倒'), ('将.*倒', '倾倒'),
    ('将.*移', '移动'), ('将.*折', '折叠'), ('将.*盖', '盖合'),
    ('将.*展', '展开'), ('将.*提', '提起'), ('将.*拉', '移动'),
    ('将.*推', '移动'), ('将.*擦', '擦拭'), ('将.*抓', '抓取'),
    ('将.*旋', '旋转'), ('将.*拧', '旋转'), ('将.*整', '整理'),
    ('将.*调', '调整'), ('将.*对', '调整'), ('将.*贴', '粘贴'),
    ('将.*翻', '翻转'), ('将.*压', '按压'), ('将.*插', '操作'),
    ('将.*切', '操作'), ('将.*固', '粘贴'), ('将.*封', '盖合'),
    # Last-resort catch-alls on single marker words.
    ('将', '操作'),
    ('双手', '操作'), ('再次', '调整'),
]
|
|
# Fine-grained label set (20 entries).  After init_classes(coarse=False) the
# position of a name in this list is its class index.
ACTION_CLASSES_FINE = [
    '抓取', '放置', '移动', '调整', '擦拭', '折叠', '旋转',
    '操作', '盖合', '整理', '展开', '倾倒', '检查', '提起',
    '释放', '粘贴', '分离', '按压', '翻转', '其他',
]


# Coarse label set (8 entries), selected by init_classes(coarse=True).
ACTION_CLASSES_COARSE = [
    '抓取', '放置', '移动', '调整', '擦拭', '折叠', '旋转', '其他',
]
# Maps every fine label onto a coarse label; consulted by
# text_to_action_class(..., coarse=True).
FINE_TO_COARSE = {
    '抓取': '抓取', '放置': '放置', '移动': '移动',
    '调整': '调整', '整理': '调整',
    '擦拭': '擦拭',
    '折叠': '折叠', '展开': '折叠',
    '旋转': '旋转', '盖合': '旋转',
    '操作': '其他', '倾倒': '其他', '检查': '其他', '提起': '其他',
    '释放': '其他', '粘贴': '其他', '分离': '其他', '按压': '其他',
    '翻转': '其他', '其他': '其他',
}


# Active label set.  All three stay None until init_classes() rebinds them.
ACTION_CLASSES = None        # list of class names in index order
NUM_ACTION_CLASSES = None    # len(ACTION_CLASSES)
ACTION_TO_IDX = None         # class name -> index in ACTION_CLASSES
|
|
|
|
def init_classes(coarse=False):
    """Select the active label set and publish it through module globals.

    Args:
        coarse: When true use ``ACTION_CLASSES_COARSE``, otherwise
            ``ACTION_CLASSES_FINE``.

    Rebinds ``ACTION_CLASSES``, ``NUM_ACTION_CLASSES`` and ``ACTION_TO_IDX``
    (class name -> position in the selected list).
    """
    global ACTION_CLASSES, NUM_ACTION_CLASSES, ACTION_TO_IDX
    ACTION_CLASSES = ACTION_CLASSES_COARSE if coarse else ACTION_CLASSES_FINE
    NUM_ACTION_CLASSES = len(ACTION_CLASSES)
    ACTION_TO_IDX = dict(zip(ACTION_CLASSES, range(NUM_ACTION_CLASSES)))
|
|
|
|
def text_to_action_class(text, coarse=False):
    """Map an annotation sentence to a verb-category label.

    The patterns in ``VERB_MAP_RULES`` are tried in order with ``re.search``
    and the first one that matches decides the fine label; when none matches
    the fine label is '其他'.

    Args:
        text: Annotation text to classify.
        coarse: When true, translate the fine label through
            ``FINE_TO_COARSE`` (falling back to '其他').

    Returns:
        The fine or coarse label name.
    """
    matched = next(
        (label for pattern, label in VERB_MAP_RULES
         if re.search(pattern, text)),
        '其他',
    )
    if not coarse:
        return matched
    return FINE_TO_COARSE.get(matched, '其他')
|
|
|
|
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
|
|
|
|
def parse_timestamp(ts_str):
    """Convert an ``MM:SS`` or ``HH:MM:SS`` string to whole seconds.

    Surrounding whitespace is ignored.  A string with any other number of
    colon-separated fields yields 0.
    """
    fields = ts_str.strip().split(':')
    if len(fields) not in (2, 3):
        return 0
    seconds = 0
    # Base-60 fold: each further field shifts the running total by one unit.
    for field in fields:
        seconds = seconds * 60 + int(field)
    return seconds
|
|
|
|
| |
| |
| |
|
|
class ActionPredDataset(Dataset):
    """Fixed-length sensor windows labelled with verb-category indices.

    Recordings are read from ``DATASET_DIR/<volunteer>/<scenario>/`` and the
    matching annotations from ``ANNOTATION_DIR/<volunteer>/<scenario>.json``.
    A scenario is skipped when it has no ``alignment_metadata.json``, does not
    list every requested modality, or has no annotation file.

    Sampling modes:

    * ``'prediction'`` - for every segment after the first, the window is the
      (up to) ``window_frames`` frames that end where the segment starts, and
      the label is that segment's class.
    * any other value (``'recognition'``) - the window is the segment itself,
      centre-cropped when it is longer than ``window_frames``.

    Windows shorter than ``window_frames`` are left-padded with zero rows;
    ``mask`` is 0 on padding and 1 on real frames.  ``init_classes()`` must
    have been called before constructing a dataset.

    Args:
        volunteers: Volunteer directory names to scan.
        modalities: Keys of ``MODALITY_FILES`` to load; the arrays are
            truncated to a common length and concatenated on axis 1.
        window_sec: Window length in seconds.
        downsample: Keep every ``downsample``-th frame.
        sampling_rate: Frames per second of the stored arrays (before
            downsampling).
        stats: Optional ``(mean, std)`` pair used for normalisation.  When
            omitted, both are computed over every recording loaded here.
        coarse: Forwarded to ``text_to_action_class``.
        mode: ``'prediction'`` or ``'recognition'`` (see above).

    Raises:
        RuntimeError: If ``init_classes()`` has not been called.
    """

    # "<start> - <end>" where each side is M:S or H:M:S.
    _TS_RANGE = re.compile(r'(\d+:\d+(?::\d+)?)\s*-\s*(\d+:\d+(?::\d+)?)')

    def __init__(self, volunteers, modalities,
                 window_sec=15.0, downsample=5, sampling_rate=100, stats=None,
                 coarse=False, mode='prediction'):
        if ACTION_TO_IDX is None:
            raise RuntimeError(
                "init_classes() must be called before building an "
                "ActionPredDataset")

        self._feat_dim = None
        self.mode = mode
        window_frames = int(window_sec * sampling_rate / downsample)
        self.window_frames = window_frames

        raw_samples = []
        # Whole recordings are only needed to derive normalisation statistics,
        # so they are not retained when the caller already supplies them.
        stats_pool = [] if stats is None else None
        # Prediction needs a preceding segment; recognition needs just one.
        min_segments = 2 if mode == 'prediction' else 1

        for vol in volunteers:
            vol_dir = os.path.join(DATASET_DIR, vol)
            if not os.path.isdir(vol_dir):
                continue
            for scenario in sorted(os.listdir(vol_dir)):
                scenario_dir = os.path.join(vol_dir, scenario)
                if not os.path.isdir(scenario_dir):
                    continue
                meta_path = os.path.join(scenario_dir, 'alignment_metadata.json')
                if not os.path.exists(meta_path):
                    continue
                # Explicit UTF-8: decoding must not depend on the locale.
                with open(meta_path, encoding='utf-8') as f:
                    meta = json.load(f)
                if not set(modalities).issubset(set(meta['modalities'])):
                    continue

                features = self._load_features(scenario_dir, modalities, downsample)
                if self._feat_dim is None:
                    self._feat_dim = features.shape[1]
                # Recordings contribute to the statistics even if they turn
                # out to have no annotation file below.
                if stats_pool is not None:
                    stats_pool.append(features)

                ann_path = os.path.join(ANNOTATION_DIR, vol, f"{scenario}.json")
                if not os.path.exists(ann_path):
                    continue
                with open(ann_path, encoding='utf-8') as f:
                    ann = json.load(f)
                segments = self._parse_segments(
                    ann, sampling_rate, downsample, coarse)
                if len(segments) < min_segments:
                    continue

                if mode == 'prediction':
                    windows = self._prediction_windows(
                        features, segments, window_frames)
                else:
                    windows = self._recognition_windows(
                        features, segments, window_frames)
                raw_samples.extend(windows)

        if stats is not None:
            self.mean, self.std = stats
        elif stats_pool:
            cat = np.concatenate(stats_pool, axis=0).astype(np.float64)
            self.mean = np.mean(cat, axis=0, keepdims=True)
            self.std = np.std(cat, axis=0, keepdims=True)
            # Avoid dividing (near-)constant channels by ~0.
            self.std[self.std < 1e-8] = 1.0
        else:
            d = self._feat_dim or 1
            self.mean = np.zeros((1, d))
            self.std = np.ones((1, d))

        self.data = []
        self.labels = []
        self.texts = []
        self.masks = []
        self.prev_labels = []
        for x, mask, label, text, prev_label in raw_samples:
            # Normalisation is applied to the whole window, padding rows
            # included; consumers are expected to honour ``mask``.
            self.data.append(((x - self.mean) / self.std).astype(np.float32))
            self.masks.append(mask)
            self.labels.append(label)
            self.texts.append(text)
            self.prev_labels.append(prev_label)

        from collections import Counter
        dist = Counter(self.labels)
        print(f" {len(self.data)} samples, feat_dim={self._feat_dim}, "
              f"window={window_frames}f ({window_sec}s), "
              f"classes={len(dist)}", flush=True)
        for cls_name in ACTION_CLASSES:
            idx = ACTION_TO_IDX[cls_name]
            print(f" {cls_name}: {dist.get(idx, 0)}", flush=True)

    @staticmethod
    def _load_features(scenario_dir, modalities, downsample):
        """Load, length-align, concatenate and downsample the modality arrays."""
        parts = [
            load_modality_array(
                os.path.join(scenario_dir, MODALITY_FILES[mod]), mod)
            for mod in modalities
        ]
        min_len = min(p.shape[0] for p in parts)
        features = np.concatenate([p[:min_len] for p in parts], axis=1)
        return features[::downsample]

    @classmethod
    def _parse_segments(cls, ann, sampling_rate, downsample, coarse):
        """Return ``(start_frame, end_frame, label_idx, text)`` per valid segment."""
        segments = []
        for seg in ann.get('segments', []):
            m = cls._TS_RANGE.match(seg['timestamp'])
            if not m:
                continue
            start_sec = parse_timestamp(m.group(1))
            end_sec = parse_timestamp(m.group(2))
            start_frame = int(start_sec * sampling_rate / downsample)
            end_frame = int(end_sec * sampling_rate / downsample)
            action_cls = text_to_action_class(seg['task'], coarse=coarse)
            segments.append(
                (start_frame, end_frame, ACTION_TO_IDX[action_cls], seg['task']))
        return segments

    @staticmethod
    def _left_pad(window, window_frames):
        """Zero-pad *window* at the front to ``window_frames`` rows.

        Returns the float32 window and a float32 mask that is 0 on the padded
        rows and 1 elsewhere.
        """
        mask = np.ones(window_frames, dtype=np.float32)
        deficit = window_frames - window.shape[0]
        if deficit > 0:
            pad = np.zeros((deficit, window.shape[1]), dtype=window.dtype)
            window = np.concatenate([pad, window], axis=0)
            mask[:deficit] = 0.0
        return window.astype(np.float32), mask

    @classmethod
    def _prediction_windows(cls, features, segments, window_frames):
        """Yield samples made of the frames preceding each segment start."""
        total = features.shape[0]
        for prev, cur in zip(segments, segments[1:]):
            boundary = cur[0]
            if boundary > total:
                break
            window = features[max(0, boundary - window_frames):boundary]
            if window.shape[0] == 0:
                continue
            window, mask = cls._left_pad(window, window_frames)
            yield window, mask, cur[2], cur[3], prev[2]

    @classmethod
    def _recognition_windows(cls, features, segments, window_frames):
        """Yield samples made of the frames inside each segment."""
        total = features.shape[0]
        for i, (seg_start, seg_end, label, text) in enumerate(segments):
            seg_end = min(seg_end, total)
            if seg_start >= seg_end:
                continue
            window = features[seg_start:seg_end]
            excess = window.shape[0] - window_frames
            if excess > 0:
                # Keep the centre of segments longer than the window.
                offset = excess // 2
                window = window[offset:offset + window_frames]
            window, mask = cls._left_pad(window, window_frames)
            # The first segment has no predecessor; reuse its own label.
            prev_label = segments[i - 1][2] if i > 0 else label
            yield window, mask, label, text, prev_label

    def get_stats(self):
        """Return the ``(mean, std)`` pair used to normalise this dataset."""
        return (self.mean, self.std)

    @property
    def feat_dim(self):
        """Feature dimension of the first loaded recording (None if none)."""
        return self._feat_dim

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return {
            'features': torch.from_numpy(self.data[idx]),
            'mask': torch.from_numpy(self.masks[idx]),
            'label': self.labels[idx],
            'prev_label': self.prev_labels[idx],
        }
|
|
|
|
| |
| |
| |
|
|
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position codes to a batch-first input, then dropout.

    Args:
        d_model: Size of the last (feature) dimension of the input.  Odd
            values are supported.
        dropout: Dropout probability applied after the addition.
        max_len: Longest sequence length for which codes are precomputed.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() *
                        (-math.log(10000.0) / d_model))
        angles = pos * div
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column,
        # so the cosine half must be trimmed to fit.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + pe)`` for ``x`` indexed as (batch, time, d_model).

        Raises:
            ValueError: If the sequence is longer than ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(1)}")
        return self.dropout(x + self.pe[:, :seq_len])
|
|
|
|
class TransformerClassifier(nn.Module):
    """Transformer encoder with attention pooling and a linear classifier head.

    The input is projected to ``d_model``, position-encoded, passed through a
    ``nn.TransformerEncoder`` (batch-first) and reduced over time with a
    learned softmax attention.  Optionally an embedding of the previous action
    is concatenated to the pooled vector before classification.

    Args:
        input_dim: Feature dimension of the input sequence.
        num_classes: Number of output classes (also the size of the previous
            action vocabulary).
        d_model: Encoder width.
        nhead: Number of attention heads.
        num_layers: Number of encoder layers.
        dropout: Dropout probability used throughout.
        use_prev_action: Whether to condition on the previous action index.
    """

    def __init__(self, input_dim, num_classes, d_model=64, nhead=4,
                 num_layers=2, dropout=0.2, use_prev_action=False):
        super().__init__()
        self.use_prev_action = use_prev_action
        self.proj = nn.Linear(input_dim, d_model)
        self.pos = PositionalEncoding(d_model, dropout)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.attn_pool = nn.Linear(d_model, 1)

        if use_prev_action:
            self.action_embed = nn.Embedding(num_classes, d_model)
            cls_input_dim = d_model * 2
        else:
            cls_input_dim = d_model

        self.classifier = nn.Sequential(
            nn.LayerNorm(cls_input_dim),
            nn.Dropout(dropout),
            nn.Linear(cls_input_dim, num_classes),
        )
        self.output_dim = d_model

    def forward(self, x, mask=None, prev_action=None):
        """Return class logits for a batch of sequences.

        Args:
            x: Input indexed as (batch, time, input_dim).
            mask: Optional (batch, time) tensor; positions equal to 0 are
                treated as padding in both the encoder and the pooling.
            prev_action: Previous-action class indices, one per batch item.
                Required when the model was built with
                ``use_prev_action=True``; ignored otherwise.

        Raises:
            ValueError: If ``use_prev_action`` is set but ``prev_action`` is
                None (the classifier head expects the concatenated input).
        """
        x = self.pos(self.proj(x))
        padded = None if mask is None else (mask == 0)
        x = self.encoder(x, src_key_padding_mask=padded)

        attn_w = self.attn_pool(x).squeeze(-1)
        if padded is not None:
            # Use the dtype's own minimum so the fill value is representable
            # in reduced precision as well.
            attn_w = attn_w.masked_fill(padded, torch.finfo(attn_w.dtype).min)
        attn_w = torch.softmax(attn_w, dim=1)
        pooled = (x * attn_w.unsqueeze(-1)).sum(dim=1)

        if self.use_prev_action:
            if prev_action is None:
                raise ValueError(
                    "prev_action is required when use_prev_action=True")
            pooled = torch.cat([pooled, self.action_embed(prev_action)], dim=1)

        return self.classifier(pooled)
|
|
|
|
| |
| |
| |
|
|
def train_epoch(model, loader, optimizer, criterion, device,
                augment=False, noise_std=0.1, time_mask_ratio=0.1):
    """Run one optimisation pass over *loader*.

    Args:
        model: Module called as ``model(features, mask, prev_action=...)``.
        loader: Iterable of batches with keys ``'features'``, ``'mask'``,
            ``'label'`` and ``'prev_label'``.
        optimizer: Optimiser stepping ``model``'s parameters.
        criterion: Loss taking ``(logits, labels)``.
        device: Device the batch tensors are moved to.
        augment: Enable Gaussian noise and time masking.
        noise_std: Standard deviation of the additive noise (applied only
            where ``mask`` is non-zero).
        time_mask_ratio: Fraction of the window length that is zeroed out in
            one contiguous span per sample.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen; both are 0.0 for an
        empty loader.
    """
    model.train()
    total_loss, correct, total = 0, 0, 0
    for batch in loader:
        features = batch['features'].to(device)
        mask = batch['mask'].to(device)
        # as_tensor accepts both the tensors produced by the default collate
        # function and plain lists, without the copy-construct warning that
        # torch.tensor() raises on tensor input.
        labels = torch.as_tensor(batch['label'], dtype=torch.long).to(device)
        prev_action = torch.as_tensor(
            batch['prev_label'], dtype=torch.long).to(device)

        if augment:
            noise = torch.randn_like(features) * noise_std
            features = features + noise * mask.unsqueeze(-1)
            seq_len = features.shape[1]
            mask_len = int(seq_len * time_mask_ratio)
            if mask_len > 0:
                # One transfer for the whole batch instead of one per sample.
                valid_lens = mask.sum(dim=1).long().tolist()
                for i, valid_len in enumerate(valid_lens):
                    if valid_len > mask_len:
                        # Real frames sit at the end of the window (padding
                        # is at the front), so offset into that region.
                        begin = (seq_len - valid_len) + random.randint(
                            0, valid_len - mask_len)
                        features[i, begin:begin + mask_len, :] = 0.0

        optimizer.zero_grad()
        logits = model(features, mask, prev_action=prev_action)
        loss = criterion(logits, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        batch_size = features.size(0)
        total_loss += loss.item() * batch_size
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += batch_size
    return total_loss / max(total, 1), correct / max(total, 1)
|
|
|
|
@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Evaluate *model* on *loader* without gradient tracking.

    Args:
        model: Module called as ``model(features, mask, prev_action=...)``.
        loader: Iterable of batches with keys ``'features'``, ``'mask'``,
            ``'label'`` and ``'prev_label'``.
        criterion: Loss taking ``(logits, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(metrics, preds, labels)`` where ``metrics`` holds ``'loss'``
        (sample-weighted mean), ``'accuracy'``, ``'f1_macro'`` and
        ``'f1_weighted'``, and ``preds`` / ``labels`` are NumPy arrays in
        loader order.
    """
    model.eval()
    total_loss, n = 0, 0
    all_preds, all_labels = [], []
    for batch in loader:
        features = batch['features'].to(device)
        mask = batch['mask'].to(device)
        # as_tensor accepts already-collated tensors as well as lists and
        # avoids the copy-construct warning of torch.tensor() on tensors.
        labels = torch.as_tensor(batch['label'], dtype=torch.long).to(device)
        prev_action = torch.as_tensor(
            batch['prev_label'], dtype=torch.long).to(device)

        logits = model(features, mask, prev_action=prev_action)
        batch_size = features.size(0)
        total_loss += criterion(logits, labels).item() * batch_size
        n += batch_size

        all_preds.extend(logits.argmax(dim=1).cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    metrics = {
        'loss': total_loss / max(n, 1),
        'accuracy': accuracy_score(all_labels, all_preds),
        'f1_macro': f1_score(
            all_labels, all_preds, average='macro', zero_division=0),
        'f1_weighted': f1_score(
            all_labels, all_preds, average='weighted', zero_division=0),
    }
    return metrics, all_preds, all_labels
|
|
|
|
| |
| |
| |
|
|
def run_experiment(args):
    """Train, validate and test one verb-classification model.

    Builds the train/val/test datasets (val and test reuse the training
    normalisation statistics), trains a ``TransformerClassifier`` with
    inverse-frequency class weights, keeps the checkpoint with the best
    validation weighted-F1, evaluates it on the test split and writes
    ``results.json`` into ``<args.output_dir>/<experiment name>``.

    Args:
        args: Parsed command-line namespace as produced by ``main()``.

    Returns:
        The results dictionary that was written to disk, or None when the
        training split is empty.
    """
    set_seed(args.seed)
    init_classes(coarse=args.coarse)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    modalities = args.modalities.split(',')

    # Derived from the active label set so the banner cannot drift from it.
    granularity = f"{NUM_ACTION_CLASSES} {'coarse' if args.coarse else 'fine'}"
    task_name = "Recognition" if args.mode == 'recognition' else "Prediction"
    print(f"\n{'='*60}", flush=True)
    print(f"Action {task_name} — Verb Classification ({granularity} classes)", flush=True)
    print(f"Modalities: {modalities} | prev_action: {args.use_prev_action}", flush=True)
    print(f"Window: {args.window_sec}s | d_model: {args.hidden_dim} | "
          f"augment: {args.augment}", flush=True)
    print(f"{'='*60}", flush=True)

    ds_kwargs = dict(
        window_sec=args.window_sec, downsample=args.downsample,
        coarse=args.coarse, mode=args.mode)
    train_ds = ActionPredDataset(TRAIN_VOLS, modalities, **ds_kwargs)
    stats = train_ds.get_stats()
    val_ds = ActionPredDataset(VAL_VOLS, modalities, stats=stats, **ds_kwargs)
    test_ds = ActionPredDataset(TEST_VOLS, modalities, stats=stats, **ds_kwargs)

    if len(train_ds) == 0:
        print("ERROR: No training samples!", flush=True)
        return None

    train_loader = DataLoader(train_ds, batch_size=args.batch_size,
                              shuffle=True, drop_last=False)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False)

    model = TransformerClassifier(
        train_ds.feat_dim, NUM_ACTION_CLASSES,
        d_model=args.hidden_dim, nhead=4, num_layers=2, dropout=args.dropout,
        use_prev_action=args.use_prev_action,
    ).to(device)
    param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable params: {param_count:,}", flush=True)

    # Inverse-frequency class weights, rescaled to sum to the class count.
    # Classes absent from the training split keep weight 0.
    from collections import Counter
    label_dist = Counter(train_ds.labels)
    weights = torch.zeros(NUM_ACTION_CLASSES)
    for idx, cnt in label_dist.items():
        weights[idx] = 1.0 / max(cnt, 1)
    weights = weights / weights.sum() * NUM_ACTION_CLASSES
    criterion = nn.CrossEntropyLoss(
        weight=weights.to(device),
        label_smoothing=args.label_smoothing)

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=7, factor=0.5, min_lr=1e-6)

    mod_str = '-'.join(modalities)
    tag = "coarse" if args.coarse else "fine"
    prev_tag = "_prev" if args.use_prev_action else ""
    mode_tag = "recog" if args.mode == 'recognition' else "pred"
    extra_tag = f"_{args.tag}" if args.tag else ""
    exp_name = f"{mode_tag}_cls_{tag}{prev_tag}_{mod_str}{extra_tag}"
    out_dir = os.path.join(args.output_dir, exp_name)
    os.makedirs(out_dir, exist_ok=True)
    ckpt_path = os.path.join(out_dir, 'model_best.pt')

    best_val_f1 = -1
    best_epoch = 0
    patience_ctr = 0

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        tr_loss, tr_acc = train_epoch(
            model, train_loader, optimizer, criterion, device,
            augment=args.augment, noise_std=args.noise_std,
            time_mask_ratio=args.time_mask_ratio)

        val_m, _, _ = evaluate(model, val_loader, criterion, device)
        dt = time.time() - t0

        print(f" Epoch {epoch:3d} | TrLoss={tr_loss:.4f} TrAcc={tr_acc:.4f} | "
              f"Val: loss={val_m['loss']:.4f} acc={val_m['accuracy']:.4f} "
              f"F1m={val_m['f1_macro']:.4f} F1w={val_m['f1_weighted']:.4f} | "
              f"{dt:.1f}s", flush=True)

        scheduler.step(val_m['loss'])

        if val_m['f1_weighted'] > best_val_f1:
            best_val_f1 = val_m['f1_weighted']
            best_epoch = epoch
            patience_ctr = 0
            torch.save(model.state_dict(), ckpt_path)
        else:
            patience_ctr += 1
            if patience_ctr >= args.patience:
                print(f" Early stopping at epoch {epoch}", flush=True)
                break

    # best_epoch stays 0 when this run never wrote a checkpoint (no epochs,
    # or the validation F1 never compared greater).  Loading then would pick
    # up a stale file from an earlier run or fail outright.
    if best_epoch > 0:
        model.load_state_dict(torch.load(ckpt_path, weights_only=True))
    else:
        print("WARNING: no checkpoint was saved during this run; "
              "evaluating the final weights.", flush=True)
    test_m, test_preds, test_labels = evaluate(
        model, test_loader, criterion, device)

    print(f"\n--- Test (best epoch {best_epoch}) ---", flush=True)
    for k, v in test_m.items():
        print(f" {k}: {v:.4f}", flush=True)

    # Report only classes that occur in the test labels or predictions.
    present_classes = sorted(set(test_labels) | set(test_preds))
    target_names = [ACTION_CLASSES[i] for i in present_classes]
    report = classification_report(
        test_labels, test_preds,
        labels=present_classes, target_names=target_names,
        zero_division=0, output_dict=True)
    print("\nPer-class results:", flush=True)
    for cls_name in target_names:
        r = report[cls_name]
        print(f" {cls_name:<6}: P={r['precision']:.3f} R={r['recall']:.3f} "
              f"F1={r['f1-score']:.3f} N={r['support']}", flush=True)

    print("\nSample predictions:", flush=True)
    indices = random.sample(range(len(test_preds)), min(15, len(test_preds)))
    for i in indices:
        p_name = ACTION_CLASSES[test_preds[i]]
        r_name = ACTION_CLASSES[test_labels[i]]
        verdict = "OK" if test_preds[i] == test_labels[i] else "XX"
        # The test loader is not shuffled, so index i lines up with texts[i].
        orig_text = test_ds.texts[i] if i < len(test_ds.texts) else "?"
        print(f" [{verdict}] Pred={p_name:<6} Ref={r_name:<6} ({orig_text})", flush=True)

    results = {
        'experiment': exp_name,
        'modalities': modalities,
        'best_epoch': best_epoch,
        'test_metrics': {k: float(v) for k, v in test_m.items()},
        'trainable_params': param_count,
        'train_samples': len(train_ds),
        'val_samples': len(val_ds),
        'test_samples': len(test_ds),
        'num_classes': NUM_ACTION_CLASSES,
        'class_names': ACTION_CLASSES,
        'per_class_report': {k: v for k, v in report.items()
                             if k in target_names},
        'args': vars(args),
    }
    # ensure_ascii=False writes the Chinese class names verbatim, so the file
    # must be opened as UTF-8 regardless of the platform default encoding.
    with open(os.path.join(out_dir, 'results.json'), 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f" Saved to {out_dir}", flush=True)
    return results
|
|
|
|
def main():
    """Parse command-line options and run a single experiment."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--modalities', type=str, default='imu')
    parser.add_argument('--window_sec', type=float, default=15.0)
    parser.add_argument('--epochs', type=int, default=80)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=1e-4)
    parser.add_argument('--hidden_dim', type=int, default=64)
    parser.add_argument('--dropout', type=float, default=0.2)
    parser.add_argument('--downsample', type=int, default=5)
    parser.add_argument('--patience', type=int, default=20)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--augment', action='store_true')
    parser.add_argument('--noise_std', type=float, default=0.1)
    parser.add_argument('--time_mask_ratio', type=float, default=0.1)
    parser.add_argument('--label_smoothing', type=float, default=0.1)
    parser.add_argument('--mode', type=str, default='prediction',
                        choices=['prediction', 'recognition'],
                        help='prediction=next action, recognition=current action')
    parser.add_argument('--coarse', action='store_true',
                        help='Use 8 coarse classes instead of 20 fine classes')
    parser.add_argument('--use_prev_action', action='store_true',
                        help='Use previous action label as additional input')
    parser.add_argument('--output_dir', type=str,
                        default='${PULSE_ROOT}/results/pred_cls')
    parser.add_argument('--tag', type=str, default='',
                        help='Optional tag appended to experiment name')
    args = parser.parse_args()
    # The default contains a "${PULSE_ROOT}" placeholder that nothing else
    # expands; resolve it from the environment.  Unset variables are left
    # as-is by expandvars, which preserves the previous behaviour.
    args.output_dir = os.path.expandvars(args.output_dir)
    os.makedirs(args.output_dir, exist_ok=True)
    run_experiment(args)


if __name__ == '__main__':
    main()
|
|