# PULSE-code/experiments/analysis/gen_val_comparison.py
# Uploaded by velvet-pine-22 using huggingface_hub (commit b4b2877, verified).
# Writes side-by-side model predictions and reference texts for the validation and test splits.
import os, sys, json, torch

# Make the project packages (tasks/, data/) importable. Python does not expand
# shell-style "${VAR}" placeholders inside string literals, so expand it
# explicitly; if PULSE_ROOT is unset, expandvars returns the string unchanged.
sys.path.insert(0, os.path.expandvars('${PULSE_ROOT}'))
# Force Hugging Face libraries to use locally cached files only. Set before the
# imports below, presumably so the HF libraries see them when first imported.
os.environ['HF_HUB_OFFLINE'] = '1'
os.environ['TRANSFORMERS_OFFLINE'] = '1'

from tasks.train_pred import (
    TextPredictionDataset, SensorToTextModel, apply_lora, set_seed
)
from data.dataset import TRAIN_VOLS, VAL_VOLS, TEST_VOLS

# Fixed seed for reproducible decoding / data handling.
set_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the tokenizer and base causal LM from a local checkpoint (no network).
from transformers import AutoTokenizer, AutoModelForCausalLM

# Expand the ${PULSE_ROOT} placeholder; Python would otherwise treat it literally.
llm_path = os.path.expandvars('${PULSE_ROOT}/models/qwen2.5-0.5b')
tokenizer = AutoTokenizer.from_pretrained(llm_path, trust_remote_code=True, local_files_only=True)
# Fall back to EOS as the padding token when the tokenizer defines none, so
# batched inputs can be padded.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
llm = AutoModelForCausalLM.from_pretrained(
    llm_path, trust_remote_code=True, torch_dtype=torch.float32, local_files_only=True
).to(device)
# Keep the model config's pad id in sync with the tokenizer.
llm.config.pad_token_id = tokenizer.pad_token_id
# Freeze all base-LLM weights before injecting LoRA adapters (r=8, alpha=16);
# presumably the adapters are the trainable part — see tasks.train_pred.apply_lora.
for p in llm.parameters():
    p.requires_grad = False
lora_params = apply_lora(llm, r=8, alpha=16)
# Sensor streams fed to the model and the windowing shared by every split.
modalities = ['mocap', 'emg', 'imu']
window_kwargs = dict(window_sec=15.0, downsample=5)

# Build datasets. Statistics are taken from the training split only
# (presumably for feature normalisation) and reused for validation and test.
train_ds = TextPredictionDataset(TRAIN_VOLS, modalities, tokenizer, **window_kwargs)
stats = train_ds.get_stats()
val_ds, test_ds = (
    TextPredictionDataset(vols, modalities, tokenizer, stats=stats, **window_kwargs)
    for vols in (VAL_VOLS, TEST_VOLS)
)
# Build the sensor-to-text model around the LoRA-wrapped LLM and restore the
# best checkpoint saved during training.
model = SensorToTextModel(train_ds.feat_dim, llm, tokenizer, n_sensor_tokens=8, d_model=64)
model.to(device)
# Expand the ${PULSE_ROOT} placeholder; Python would otherwise treat it literally.
ckpt_path = os.path.expandvars(
    '${PULSE_ROOT}/results/pred_llm2/pred_llm_mocap-emg-imu/model_best.pt'
)
sd = torch.load(ckpt_path, weights_only=True, map_location=device)
# strict=False tolerates model keys absent from the checkpoint (it may store
# only a subset, e.g. trainable parameters — TODO confirm). It also silently
# drops mismatched keys, so report them rather than evaluating a model that is
# partly or entirely randomly initialised.
load_result = model.load_state_dict(sd, strict=False)
if load_result.unexpected_keys:
    print(
        f"WARNING: {len(load_result.unexpected_keys)} checkpoint keys were not "
        f"found in the model, e.g. {load_result.unexpected_keys[:5]}"
    )
n_loaded = len(sd) - len(load_result.unexpected_keys)
if n_loaded == 0:
    raise RuntimeError(f"No parameters in {ckpt_path} matched the model; refusing to evaluate.")
print(
    f"Loaded {n_loaded} tensors from {ckpt_path} "
    f"({len(load_result.missing_keys)} model keys not in checkpoint)"
)
model.eval()
# Decode every validation/test window and write prediction vs. reference pairs
# to a text file for manual inspection, with a per-split exact-match summary.
# Expand the ${PULSE_ROOT} placeholder; Python would otherwise treat it literally.
out_path = os.path.expandvars('${PULSE_ROOT}/docs/pred_llm2_val_comparison.txt')
from torch.utils.data import DataLoader
os.makedirs(os.path.dirname(out_path) or '.', exist_ok=True)
# Explicit UTF-8: the header contains an em dash and generated text may contain
# arbitrary Unicode, which can fail under a non-UTF-8 default locale.
# no_grad: pure inference, no autograd graph needed.
with open(out_path, 'w', encoding='utf-8') as f, torch.no_grad():
    for split_name, ds in [('Validation', val_ds), ('Test', test_ds)]:
        # shuffle=False keeps batches in dataset order; the running idx below
        # assumes ds[i] corresponds to ds.texts[i] — TODO confirm in
        # TextPredictionDataset.
        loader = DataLoader(ds, batch_size=8, shuffle=False)
        f.write(f"{'='*70}\n")
        # NOTE: the charF1 value in this header is hard-coded, not computed here.
        f.write(f"{split_name} Set — mocap,emg,imu (best charF1=0.0324)\n")
        f.write(f"Samples: {len(ds)}\n")
        f.write(f"{'='*70}\n\n")
        idx = 0
        n_exact = 0
        for batch in loader:
            sensor = batch['sensor'].to(device)
            preds = model.generate_text(sensor, tokenizer, max_new_tokens=20)
            refs = [ds.texts[idx + i] for i in range(len(preds))]
            for p, r in zip(preds, refs):
                is_exact = p.strip() == r.strip()
                n_exact += is_exact
                match = "OK" if is_exact else "XX"
                f.write(f"[{match}] #{idx+1}\n")
                f.write(f"  Pred: {p.strip()}\n")
                f.write(f"  Ref:  {r.strip()}\n\n")
                idx += 1
        # Stats
        f.write(f"\n--- {split_name} Summary ---\n")
        f.write(f"Total: {idx}\n")
        f.write(f"Exact match: {n_exact}/{idx}\n\n")
print(f"Written to {out_path}")