#!/usr/bin/env python3
"""
Sensor-to-text action prediction with LoRA-tuned LLM.
Improvements over v1:
1. LoRA on LLM q_proj/v_proj — lets LLM learn to understand sensor tokens
2. Instruction prefix "描述接下来的动作:" ("Describe the next action:") — guides generation
3. Short generation limit (max 20 tokens) — prevents rambling
Architecture:
SensorEncoder → pool to K soft-prompt tokens → project to LLM space
→ [sensor_tokens] + [instruction] → LoRA-tuned Qwen2.5-0.5B → action text
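
Example invocation (flags mirror the argparse defaults in main()):
    python train_pred.py --modalities imu --lora_r 8 --epochs 50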
"""
import os
import sys
import json
import time
import math
import re
import random
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.dataset import (
DATASET_DIR, MODALITY_FILES, TRAIN_VOLS, VAL_VOLS, TEST_VOLS,
load_modality_array,
)
ANNOTATION_DIR = "${PULSE_ROOT}"
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def parse_timestamp(ts_str):
parts = ts_str.strip().split(':')
if len(parts) == 2:
return int(parts[0]) * 60 + int(parts[1])
elif len(parts) == 3:
return int(parts[0]) * 3600 + int(parts[1]) * 60 + int(parts[2])
return 0
# ============================================================
# LoRA
# ============================================================
class LoRALayer(nn.Module):
"""Low-Rank Adaptation wrapper for nn.Linear."""
def __init__(self, base_layer, r=8, alpha=16, dropout=0.1):
super().__init__()
self.base_layer = base_layer
for p in self.base_layer.parameters():
p.requires_grad = False
in_dim = base_layer.in_features
out_dim = base_layer.out_features
self.lora_A = nn.Linear(in_dim, r, bias=False)
self.lora_B = nn.Linear(r, out_dim, bias=False)
self.scaling = alpha / r
self.lora_dropout = nn.Dropout(dropout)
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_B.weight)
def forward(self, x):
base_out = self.base_layer(x)
lora_out = self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
return base_out + lora_out
def apply_lora(llm, r=8, alpha=16, dropout=0.1):
"""Apply LoRA to q_proj and v_proj in all attention layers. Returns LoRA params."""
lora_params = []
for layer in llm.model.layers:
attn = layer.self_attn
for name in ['q_proj', 'v_proj']:
original = getattr(attn, name)
lora_layer = LoRALayer(original, r=r, alpha=alpha, dropout=dropout)
setattr(attn, name, lora_layer)
lora_params.extend(lora_layer.lora_A.parameters())
lora_params.extend(lora_layer.lora_B.parameters())
return lora_params
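# Sanity check on the printed count, assuming Qwen2.5-0.5B's usual config
# (hidden=896, 24 layers, 2 KV heads of dim 64, so v_proj maps 896->128):
# at r=8 each layer adds 8*(896+896) q_proj + 8*(896+128) v_proj = 22,528
# params, roughly 540K over 24 layers.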
# ============================================================
# Dataset
# ============================================================
class TextPredictionDataset(Dataset):
def __init__(self, volunteers, modalities, tokenizer,
window_sec=15.0, max_text_len=48,
downsample=5, sampling_rate=100, stats=None):
self.tokenizer = tokenizer
self.max_text_len = max_text_len
self._feat_dim = None
raw_samples = []
all_features_for_stats = []
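        # Window length in (downsampled) frames; with the defaults this is
        # 15 s * 100 Hz / 5 = 300 frames per sample.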
window_frames = int(window_sec * sampling_rate / downsample)
for vol in volunteers:
vol_dir = os.path.join(DATASET_DIR, vol)
if not os.path.isdir(vol_dir):
continue
for scenario in sorted(os.listdir(vol_dir)):
scenario_dir = os.path.join(vol_dir, scenario)
if not os.path.isdir(scenario_dir):
continue
meta_path = os.path.join(scenario_dir, 'alignment_metadata.json')
if not os.path.exists(meta_path):
continue
with open(meta_path) as f:
meta = json.load(f)
if not set(modalities).issubset(set(meta['modalities'])):
continue
parts = []
for mod in modalities:
filepath = os.path.join(scenario_dir, MODALITY_FILES[mod])
arr = load_modality_array(filepath, mod)
parts.append(arr)
min_len = min(p.shape[0] for p in parts)
features = np.concatenate([p[:min_len] for p in parts], axis=1)
features = features[::downsample]
if self._feat_dim is None:
self._feat_dim = features.shape[1]
all_features_for_stats.append(features)
ann_path = os.path.join(ANNOTATION_DIR, vol, f"{scenario}.json")
if not os.path.exists(ann_path):
continue
with open(ann_path) as f:
ann = json.load(f)
segments = []
for seg in ann.get('segments', []):
m = re.match(r'(\d+:\d+(?::\d+)?)\s*-\s*(\d+:\d+(?::\d+)?)',
seg['timestamp'])
if not m:
continue
start_sec = parse_timestamp(m.group(1))
start_frame = int(start_sec * sampling_rate / downsample)
segments.append((start_frame, seg['task']))
if len(segments) < 2:
continue
T_total = features.shape[0]
for i in range(1, len(segments)):
boundary = segments[i][0]
if boundary > T_total:
break
end = boundary
start = max(0, end - window_frames)
window = features[start:end]
if window.shape[0] == 0:
continue
if window.shape[0] < window_frames:
pad = np.zeros((window_frames - window.shape[0], self._feat_dim))
window = np.concatenate([pad, window], axis=0)
raw_samples.append((window.astype(np.float32), segments[i][1]))
# Normalization
if stats is not None:
self.mean, self.std = stats
else:
if all_features_for_stats:
cat = np.concatenate(all_features_for_stats, axis=0).astype(np.float64)
self.mean = np.mean(cat, axis=0, keepdims=True)
self.std = np.std(cat, axis=0, keepdims=True)
self.std[self.std < 1e-8] = 1.0
else:
d = self._feat_dim or 1
self.mean = np.zeros((1, d))
self.std = np.ones((1, d))
self.sensor_data = [
((x - self.mean) / self.std).astype(np.float32) for x, _ in raw_samples
]
self.texts = [t for _, t in raw_samples]
# Tokenize: text + EOS
eos = tokenizer.eos_token or ''
self.tokenized = tokenizer(
[t + eos for t in self.texts],
padding='max_length', max_length=max_text_len,
truncation=True, return_tensors='np', add_special_tokens=False,
)
print(f" {len(self.sensor_data)} samples, feat_dim={self._feat_dim}, "
f"window={window_frames}f, unique_texts={len(set(self.texts))}",
flush=True)
def get_stats(self):
return (self.mean, self.std)
@property
def feat_dim(self):
return self._feat_dim
def __len__(self):
return len(self.sensor_data)
def __getitem__(self, idx):
return {
'sensor': torch.from_numpy(self.sensor_data[idx]),
'input_ids': torch.tensor(
self.tokenized['input_ids'][idx], dtype=torch.long),
'attention_mask': torch.tensor(
self.tokenized['attention_mask'][idx], dtype=torch.long),
}
# ============================================================
# Model
# ============================================================
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, max_len=5000):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
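        # Standard sinusoidal table: PE[p, 2i] = sin(p / 10000^(2i/d)) and
        # PE[p, 2i+1] = cos(p / 10000^(2i/d)), precomputed for max_len positions.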
pe = torch.zeros(max_len, d_model)
pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div = torch.exp(torch.arange(0, d_model, 2).float() *
(-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(pos * div)
pe[:, 1::2] = torch.cos(pos * div)
self.register_buffer('pe', pe.unsqueeze(0))
def forward(self, x):
return self.dropout(x + self.pe[:, :x.size(1)])
class SensorEncoder(nn.Module):
def __init__(self, input_dim, d_model=64, nhead=4, num_layers=2, dropout=0.1):
super().__init__()
self.proj = nn.Linear(input_dim, d_model)
self.pos = PositionalEncoding(d_model, dropout)
layer = nn.TransformerEncoderLayer(
d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4,
dropout=dropout, batch_first=True)
self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
def forward(self, x):
return self.encoder(self.pos(self.proj(x)))
class SensorToTextModel(nn.Module):
def __init__(self, input_dim, llm, tokenizer, n_sensor_tokens=8,
d_model=64, nhead=4, num_layers=2, dropout=0.1):
super().__init__()
self.n_sensor_tokens = n_sensor_tokens
lm_hidden = llm.config.hidden_size
self.sensor_encoder = SensorEncoder(
input_dim, d_model, nhead, num_layers, dropout)
self.pool = nn.AdaptiveAvgPool1d(n_sensor_tokens)
self.projection = nn.Linear(d_model, lm_hidden)
self.llm = llm
        # Pre-tokenize the instruction prefix ("Describe the next action:")
        inst_text = "描述接下来的动作:"
inst_ids = tokenizer(inst_text, add_special_tokens=False,
return_tensors='pt')['input_ids']
self.register_buffer('instruction_ids', inst_ids) # (1, L_inst)
self.n_inst = inst_ids.size(1)
@property
def prefix_len(self):
return self.n_sensor_tokens + self.n_inst
def encode_sensor(self, x):
feat = self.sensor_encoder(x)
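        # AdaptiveAvgPool1d pools over the last axis, so move time there,
        # average T frames down to K soft-prompt tokens, then move it back.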
feat = self.pool(feat.transpose(1, 2)).transpose(1, 2)
return self.projection(feat)
def forward(self, sensor, input_ids, attention_mask):
B = sensor.size(0)
device = sensor.device
sensor_embeds = self.encode_sensor(sensor) # (B, K, H)
inst_ids = self.instruction_ids.expand(B, -1) # (B, L_inst)
inst_embeds = self.llm.get_input_embeddings()(inst_ids)
text_embeds = self.llm.get_input_embeddings()(input_ids)
input_embeds = torch.cat(
[sensor_embeds, inst_embeds, text_embeds], dim=1)
P = self.prefix_len
prefix_attn = torch.ones(B, P, device=device, dtype=attention_mask.dtype)
full_attn = torch.cat([prefix_attn, attention_mask], dim=1)
return self.llm(inputs_embeds=input_embeds,
attention_mask=full_attn).logits
@torch.no_grad()
def generate_text(self, sensor, tokenizer, max_new_tokens=20):
self.eval()
B = sensor.size(0)
device = sensor.device
sensor_embeds = self.encode_sensor(sensor)
inst_ids = self.instruction_ids.expand(B, -1)
inst_embeds = self.llm.get_input_embeddings()(inst_ids)
prefix = torch.cat([sensor_embeds, inst_embeds], dim=1)
eos_id = tokenizer.eos_token_id
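        # Greedy decode with a KV cache: run the soft prompt through the LLM
        # once, then feed one embedded token per step via past_key_values.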
# First pass
out = self.llm(inputs_embeds=prefix, use_cache=True)
past_kv = out.past_key_values
next_id = out.logits[:, -1, :].argmax(-1)
generated = [next_id]
for _ in range(max_new_tokens - 1):
if (next_id == eos_id).all():
break
next_emb = self.llm.get_input_embeddings()(next_id).unsqueeze(1)
out = self.llm(inputs_embeds=next_emb,
past_key_values=past_kv, use_cache=True)
past_kv = out.past_key_values
next_id = out.logits[:, -1, :].argmax(-1)
generated.append(next_id)
gen_ids = torch.stack(generated, dim=1)
texts = []
for i in range(B):
ids = gen_ids[i].tolist()
if eos_id in ids:
ids = ids[:ids.index(eos_id)]
texts.append(tokenizer.decode(ids, skip_special_tokens=True))
return texts
# ============================================================
# Training & Evaluation
# ============================================================
def train_epoch(model, loader, optimizer, device):
model.train()
total_loss, n = 0, 0
P = model.prefix_len
for batch in loader:
sensor = batch['sensor'].to(device)
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
optimizer.zero_grad()
logits = model(sensor, input_ids, attention_mask)
L = input_ids.size(1)
        # Logits at position t predict token t+1, so the slice starting at
        # P-1 (the last prefix position) lines up with the L text targets.
        pred = logits[:, P - 1: P - 1 + L, :]
        # Mask padding via the attention mask, not pad_token_id: pad and EOS
        # share an id here, so ignoring pad_id would also drop the EOS target.
        targets = input_ids.masked_fill(attention_mask == 0, -100)
        loss = F.cross_entropy(
            pred.reshape(-1, pred.size(-1)),
            targets.reshape(-1),
            ignore_index=-100)
loss.backward()
torch.nn.utils.clip_grad_norm_(
[p for p in model.parameters() if p.requires_grad], 1.0)
optimizer.step()
total_loss += loss.item() * sensor.size(0)
n += sensor.size(0)
return total_loss / max(n, 1)
@torch.no_grad()
def eval_loss_only(model, loader, device):
model.eval()
total_loss, n = 0, 0
P = model.prefix_len
for batch in loader:
sensor = batch['sensor'].to(device)
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
logits = model(sensor, input_ids, attention_mask)
L = input_ids.size(1)
        pred = logits[:, P - 1: P - 1 + L, :]
        # Same masking as train_epoch: pad == eos, so mask via attention_mask.
        targets = input_ids.masked_fill(attention_mask == 0, -100)
        loss = F.cross_entropy(
            pred.reshape(-1, pred.size(-1)),
            targets.reshape(-1), ignore_index=-100)
total_loss += loss.item() * sensor.size(0)
n += sensor.size(0)
return total_loss / max(n, 1)
@torch.no_grad()
def eval_with_generation(model, loader, tokenizer, device):
model.eval()
total_loss, n = 0, 0
P = model.prefix_len
all_preds, all_refs = [], []
for batch in loader:
sensor = batch['sensor'].to(device)
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
logits = model(sensor, input_ids, attention_mask)
L = input_ids.size(1)
        pred = logits[:, P - 1: P - 1 + L, :]
        # Same masking as train_epoch: pad == eos, so mask via attention_mask.
        targets = input_ids.masked_fill(attention_mask == 0, -100)
        loss = F.cross_entropy(
            pred.reshape(-1, pred.size(-1)),
            targets.reshape(-1), ignore_index=-100)
total_loss += loss.item() * sensor.size(0)
n += sensor.size(0)
texts = model.generate_text(sensor, tokenizer, max_new_tokens=20)
all_preds.extend(texts)
refs = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
all_refs.extend(refs)
em = sum(p.strip() == r.strip()
for p, r in zip(all_preds, all_refs)) / max(len(all_preds), 1)
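    # Position-wise character overlap: a crude precision/recall proxy that
    # only credits characters matching at the same index in pred and ref.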
char_correct, char_ptot, char_rtot = 0, 0, 0
for p, r in zip(all_preds, all_refs):
ps, rs = p.strip(), r.strip()
for j in range(min(len(ps), len(rs))):
if ps[j] == rs[j]:
char_correct += 1
char_ptot += len(ps)
char_rtot += len(rs)
prec = char_correct / max(char_ptot, 1)
rec = char_correct / max(char_rtot, 1)
char_f1 = 2 * prec * rec / max(prec + rec, 1e-8)
return {
'loss': total_loss / max(n, 1),
'exact_match': em,
'char_precision': prec,
'char_recall': rec,
'char_f1': char_f1,
}, all_preds, all_refs
# ============================================================
# Main
# ============================================================
def run_experiment(args):
set_seed(args.seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
modalities = args.modalities.split(',')
print(f"\n{'='*60}", flush=True)
print(f"Sensor → LLM Text (LoRA + instruction prefix)", flush=True)
print(f"Mods: {modalities} | LLM: {args.llm_name}", flush=True)
print(f"LoRA r={args.lora_r} alpha={args.lora_alpha}", flush=True)
print(f"{'='*60}", flush=True)
# LLM
print("Loading LLM...", flush=True)
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained(
args.llm_name, trust_remote_code=True, local_files_only=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
llm = AutoModelForCausalLM.from_pretrained(
args.llm_name, trust_remote_code=True,
torch_dtype=torch.float32, local_files_only=True,
).to(device)
llm.config.pad_token_id = tokenizer.pad_token_id
# Freeze all LLM params first
for p in llm.parameters():
p.requires_grad = False
# Apply LoRA
lora_params = apply_lora(llm, r=args.lora_r, alpha=args.lora_alpha)
lora_param_count = sum(p.numel() for p in lora_params)
print(f"LoRA params: {lora_param_count:,} (r={args.lora_r})", flush=True)
# Datasets
train_ds = TextPredictionDataset(
TRAIN_VOLS, modalities, tokenizer,
window_sec=args.window_sec, max_text_len=args.max_text_len,
downsample=args.downsample)
stats = train_ds.get_stats()
val_ds = TextPredictionDataset(
VAL_VOLS, modalities, tokenizer,
window_sec=args.window_sec, max_text_len=args.max_text_len,
downsample=args.downsample, stats=stats)
test_ds = TextPredictionDataset(
TEST_VOLS, modalities, tokenizer,
window_sec=args.window_sec, max_text_len=args.max_text_len,
downsample=args.downsample, stats=stats)
if len(train_ds) == 0:
print("ERROR: No training samples!", flush=True)
return None
train_loader = DataLoader(train_ds, batch_size=args.batch_size,
shuffle=True, drop_last=False)
val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False)
# Model
model = SensorToTextModel(
train_ds.feat_dim, llm, tokenizer,
n_sensor_tokens=args.n_sensor_tokens, d_model=args.hidden_dim)
model = model.to(device) # move ALL submodules + buffers to GPU
# Collect trainable params
sensor_params = list(model.sensor_encoder.parameters()) + \
list(model.projection.parameters())
all_trainable = sensor_params + lora_params
trainable_count = sum(p.numel() for p in all_trainable)
total_count = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable_count:,} / Total: {total_count:,}", flush=True)
optimizer = torch.optim.AdamW([
{'params': sensor_params, 'lr': args.lr},
{'params': lora_params, 'lr': args.lr * 0.2},
], weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, patience=7, factor=0.5, min_lr=1e-6)
mod_str = '-'.join(modalities)
exp_name = f"pred_llm_{mod_str}"
out_dir = os.path.join(args.output_dir, exp_name)
os.makedirs(out_dir, exist_ok=True)
best_val_loss = float('inf')
best_epoch = 0
patience_ctr = 0
for epoch in range(1, args.epochs + 1):
t0 = time.time()
tr_loss = train_epoch(model, train_loader, optimizer, device)
if epoch % 5 == 0 or epoch <= 2 or patience_ctr >= args.patience - 2:
val_m, _, _ = eval_with_generation(
model, val_loader, tokenizer, device)
print(f" Epoch {epoch:3d} | TrLoss={tr_loss:.4f} | "
f"Val: loss={val_m['loss']:.4f} EM={val_m['exact_match']:.4f} "
f"charF1={val_m['char_f1']:.4f} | {time.time()-t0:.1f}s",
flush=True)
else:
val_loss = eval_loss_only(model, val_loader, device)
val_m = {'loss': val_loss}
print(f" Epoch {epoch:3d} | TrLoss={tr_loss:.4f} | "
f"Val: loss={val_loss:.4f} | {time.time()-t0:.1f}s",
flush=True)
scheduler.step(val_m['loss'])
if val_m['loss'] < best_val_loss:
best_val_loss = val_m['loss']
best_epoch = epoch
patience_ctr = 0
# Save sensor encoder + projection + LoRA weights
            save_sd = {
                k: v for k, v in model.state_dict().items()
                if not k.startswith('llm.') or 'lora_A' in k or 'lora_B' in k
            }
torch.save(save_sd, os.path.join(out_dir, 'model_best.pt'))
else:
patience_ctr += 1
if patience_ctr >= args.patience:
print(f" Early stopping at epoch {epoch}", flush=True)
break
# Test
best_sd = torch.load(os.path.join(out_dir, 'model_best.pt'),
weights_only=True)
model.load_state_dict(best_sd, strict=False)
test_m, test_preds, test_refs = eval_with_generation(
model, test_loader, tokenizer, device)
print(f"\n--- Test (best epoch {best_epoch}) ---", flush=True)
for k, v in test_m.items():
print(f" {k}: {v:.4f}", flush=True)
print("\nSample predictions:", flush=True)
indices = random.sample(range(len(test_preds)), min(15, len(test_preds)))
for i in indices:
tag = "OK" if test_preds[i].strip() == test_refs[i].strip() else "XX"
print(f" [{tag}] Pred: {test_preds[i].strip()}", flush=True)
print(f" Ref: {test_refs[i].strip()}", flush=True)
results = {
'experiment': exp_name,
'modalities': modalities,
'best_epoch': best_epoch,
'test_metrics': {k: float(v) for k, v in test_m.items()},
'trainable_params': trainable_count,
'lora_params': lora_param_count,
'train_samples': len(train_ds),
'val_samples': len(val_ds),
'test_samples': len(test_ds),
'args': vars(args),
'sample_predictions': [
{'pred': test_preds[i].strip(), 'ref': test_refs[i].strip()}
for i in indices
],
}
with open(os.path.join(out_dir, 'results.json'), 'w') as f:
json.dump(results, f, indent=2, ensure_ascii=False)
print(f" Saved to {out_dir}", flush=True)
return results
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--modalities', type=str, default='imu')
parser.add_argument('--window_sec', type=float, default=15.0)
parser.add_argument('--llm_name', type=str,
default='${PULSE_ROOT}/models/qwen2.5-0.5b')
parser.add_argument('--lora_r', type=int, default=8)
parser.add_argument('--lora_alpha', type=int, default=16)
parser.add_argument('--n_sensor_tokens', type=int, default=8)
parser.add_argument('--max_text_len', type=int, default=48)
parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--lr', type=float, default=5e-4)
parser.add_argument('--weight_decay', type=float, default=1e-4)
parser.add_argument('--hidden_dim', type=int, default=64)
parser.add_argument('--downsample', type=int, default=5)
parser.add_argument('--patience', type=int, default=15)
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--output_dir', type=str,
default='${PULSE_ROOT}/results/pred_llm2')
args = parser.parse_args()
os.makedirs(args.output_dir, exist_ok=True)
run_experiment(args)
if __name__ == '__main__':
main()