| |
| """ |
Sensor-to-text action prediction with a LoRA-tuned LLM.

Improvements over v1:
1. LoRA on the LLM's q_proj/v_proj — lets the LLM learn to interpret sensor tokens
2. Instruction prefix "描述接下来的动作:" — guides generation
3. Short generation limit (max 20 tokens) — prevents rambling

Architecture:
SensorEncoder → pool to K soft-prompt tokens → project to LLM space
→ [sensor_tokens] + [instruction] → LoRA-tuned Qwen2.5-0.5B → action text
| """ |
|
|
| import os |
| import sys |
| import json |
| import time |
| import math |
| import re |
| import random |
| import argparse |
| import glob |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
|
|
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| from data.dataset import ( |
| DATASET_DIR, MODALITY_FILES, TRAIN_VOLS, VAL_VOLS, TEST_VOLS, |
| load_modality_array, |
| ) |
|
|
# Root directory of the annotation files; they are looked up as
# <ANNOTATION_DIR>/<volunteer>/<scenario>.json.  Python does not expand
# "${VAR}" inside string literals, so the environment variable is expanded
# explicitly.  os.path.expandvars leaves the text untouched when the variable
# is not set, so the previous value is preserved in that case.
ANNOTATION_DIR = os.path.expandvars("${PULSE_ROOT}")
|
|
|
|
def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
|
|
|
|
def parse_timestamp(ts_str):
    """Convert ``"MM:SS"`` or ``"HH:MM:SS"`` into a number of seconds.

    Any other number of colon-separated fields yields ``0``.  Non-numeric
    fields raise ``ValueError`` (from ``int``).
    """
    fields = ts_str.strip().split(':')
    if len(fields) not in (2, 3):
        return 0
    # Base-60 fold: ((h * 60) + m) * 60 + s, or m * 60 + s for two fields.
    seconds = 0
    for field in fields:
        seconds = seconds * 60 + int(field)
    return seconds
|
|
|
|
| |
| |
| |
|
|
class LoRALayer(nn.Module):
    """Frozen ``nn.Linear`` plus a trainable low-rank (LoRA) correction.

    The output is ``base_layer(x) + lora_B(lora_A(dropout(x))) * (alpha / r)``.
    ``lora_B`` is zero-initialised, so a freshly wrapped layer returns exactly
    what the base layer returns.

    Args:
        base_layer: Linear layer to wrap; all of its parameters are frozen.
        r: Rank of the low-rank update.
        alpha: Scaling numerator; the update is multiplied by ``alpha / r``.
        dropout: Dropout probability applied to the input of the LoRA branch.
    """

    def __init__(self, base_layer, r=8, alpha=16, dropout=0.1):
        super().__init__()
        self.base_layer = base_layer
        # Only the low-rank factors below are meant to receive gradients.
        self.base_layer.requires_grad_(False)

        self.lora_A = nn.Linear(base_layer.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base_layer.out_features, bias=False)
        self.scaling = alpha / r
        self.lora_dropout = nn.Dropout(dropout)

        # A gets a random init, B starts at zero: the update is initially 0.
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)

    def forward(self, x):
        """Return the base projection of ``x`` plus the scaled LoRA update."""
        frozen_out = self.base_layer(x)
        low_rank = self.lora_A(self.lora_dropout(x))
        update = self.lora_B(low_rank) * self.scaling
        return frozen_out + update
|
|
|
|
def apply_lora(llm, r=8, alpha=16, dropout=0.1):
    """Wrap ``q_proj`` and ``v_proj`` of every attention layer in a LoRALayer.

    The projections are replaced in place on ``llm.model.layers[*].self_attn``.

    Returns:
        List of the newly created LoRA parameters (``lora_A`` then ``lora_B``
        for each wrapped projection, in layer order).
    """
    adapter_params = []
    for block in llm.model.layers:
        attention = block.self_attn
        for proj_name in ('q_proj', 'v_proj'):
            wrapped = LoRALayer(
                getattr(attention, proj_name),
                r=r, alpha=alpha, dropout=dropout)
            setattr(attention, proj_name, wrapped)
            adapter_params += [
                *wrapped.lora_A.parameters(),
                *wrapped.lora_B.parameters(),
            ]
    return adapter_params
|
|
|
|
| |
| |
| |
|
|
class TextPredictionDataset(Dataset):
    """Sensor windows paired with the text of the action segment that follows.

    For every scenario directory of every volunteer the requested modalities
    are loaded, truncated to their shortest length, concatenated along the
    feature axis and downsampled.  Each annotated segment except the first
    yields one sample: the sensor frames of the ``window_sec`` seconds that end
    at the segment's start (left-padded with zeros when fewer frames exist),
    together with that segment's ``task`` text.

    Args:
        volunteers: Volunteer directory names under ``DATASET_DIR``.
        modalities: Modality names (keys of ``MODALITY_FILES``).
        tokenizer: Tokenizer used to encode the target texts.
        window_sec: Length of the sensor window in seconds.
        max_text_len: Target texts are padded/truncated to this many tokens.
        downsample: Keep every ``downsample``-th frame.
        sampling_rate: Frames per second of the raw recordings.
        stats: Optional ``(mean, std)`` used for normalisation.  When ``None``
            they are computed from all feature arrays loaded for
            ``volunteers`` (including scenarios that have no annotation file).
    """

    # "start - end" timestamps, each either MM:SS or HH:MM:SS.
    _TIMESTAMP_RE = re.compile(r'(\d+:\d+(?::\d+)?)\s*-\s*(\d+:\d+(?::\d+)?)')

    def __init__(self, volunteers, modalities, tokenizer,
                 window_sec=15.0, max_text_len=48,
                 downsample=5, sampling_rate=100, stats=None):
        self.tokenizer = tokenizer
        self.max_text_len = max_text_len
        self._feat_dim = None
        raw_samples = []
        all_features_for_stats = []
        window_frames = int(window_sec * sampling_rate / downsample)

        for vol in volunteers:
            vol_dir = os.path.join(DATASET_DIR, vol)
            if not os.path.isdir(vol_dir):
                continue
            for scenario in sorted(os.listdir(vol_dir)):
                features = self._load_features(
                    os.path.join(vol_dir, scenario), modalities, downsample)
                if features is None:
                    continue
                if self._feat_dim is None:
                    self._feat_dim = features.shape[1]
                # The arrays are only needed when statistics must be computed
                # here; skipping them otherwise avoids holding every
                # recording in memory.
                if stats is None:
                    all_features_for_stats.append(features)

                segments = self._load_segments(
                    os.path.join(ANNOTATION_DIR, vol, f"{scenario}.json"),
                    sampling_rate, downsample)
                if len(segments) < 2:
                    continue
                raw_samples.extend(
                    self._cut_windows(features, segments, window_frames))

        if stats is not None:
            self.mean, self.std = stats
        elif all_features_for_stats:
            cat = np.concatenate(all_features_for_stats, axis=0).astype(np.float64)
            self.mean = np.mean(cat, axis=0, keepdims=True)
            self.std = np.std(cat, axis=0, keepdims=True)
            # Constant channels would divide by ~0; leave them unscaled.
            self.std[self.std < 1e-8] = 1.0
        else:
            d = self._feat_dim or 1
            self.mean = np.zeros((1, d))
            self.std = np.ones((1, d))

        # NOTE(review): windows were zero-padded before this normalisation, so
        # padded frames end up as -mean/std rather than 0 — confirm intended.
        self.sensor_data = [
            ((x - self.mean) / self.std).astype(np.float32) for x, _ in raw_samples
        ]
        self.texts = [t for _, t in raw_samples]

        # The EOS token is appended as text because special tokens are not
        # added by the tokenizer call below.
        eos = tokenizer.eos_token or ''
        self.tokenized = tokenizer(
            [t + eos for t in self.texts],
            padding='max_length', max_length=max_text_len,
            truncation=True, return_tensors='np', add_special_tokens=False,
        )
        print(f"  {len(self.sensor_data)} samples, feat_dim={self._feat_dim}, "
              f"window={window_frames}f, unique_texts={len(set(self.texts))}",
              flush=True)

    @staticmethod
    def _load_features(scenario_dir, modalities, downsample):
        """Return the downsampled feature array of a scenario, or ``None`` to skip it."""
        if not os.path.isdir(scenario_dir):
            return None
        meta_path = os.path.join(scenario_dir, 'alignment_metadata.json')
        if not os.path.exists(meta_path):
            return None
        with open(meta_path, encoding='utf-8') as f:
            meta = json.load(f)
        if not set(modalities).issubset(set(meta['modalities'])):
            return None

        parts = [
            load_modality_array(
                os.path.join(scenario_dir, MODALITY_FILES[mod]), mod)
            for mod in modalities
        ]
        # Modalities may differ in length; align on the shortest one.
        min_len = min(p.shape[0] for p in parts)
        features = np.concatenate([p[:min_len] for p in parts], axis=1)
        return features[::downsample]

    @classmethod
    def _load_segments(cls, ann_path, sampling_rate, downsample):
        """Return ``[(start_frame, task_text), ...]`` from an annotation file.

        A missing file yields an empty list; segments whose timestamp does not
        match ``_TIMESTAMP_RE`` are skipped.
        """
        if not os.path.exists(ann_path):
            return []
        with open(ann_path, encoding='utf-8') as f:
            ann = json.load(f)
        segments = []
        for seg in ann.get('segments', []):
            m = cls._TIMESTAMP_RE.match(seg['timestamp'])
            if not m:
                continue
            start_sec = parse_timestamp(m.group(1))
            start_frame = int(start_sec * sampling_rate / downsample)
            segments.append((start_frame, seg['task']))
        return segments

    def _cut_windows(self, features, segments, window_frames):
        """Return one ``(window, text)`` pair per segment after the first."""
        samples = []
        total_frames = features.shape[0]
        for boundary, text in segments[1:]:
            # Stops at the first boundary past the end of the recording.
            if boundary > total_frames:
                break
            window = features[max(0, boundary - window_frames):boundary]
            if window.shape[0] == 0:
                continue
            if window.shape[0] < window_frames:
                # Left-pad so the most recent frames stay at the end.
                pad = np.zeros((window_frames - window.shape[0], self._feat_dim))
                window = np.concatenate([pad, window], axis=0)
            samples.append((window.astype(np.float32), text))
        return samples

    def get_stats(self):
        """Return the ``(mean, std)`` pair used to normalise this dataset."""
        return (self.mean, self.std)

    @property
    def feat_dim(self):
        """Feature width of the first loaded scenario (``None`` if none loaded)."""
        return self._feat_dim

    def __len__(self):
        return len(self.sensor_data)

    def __getitem__(self, idx):
        """Return the normalised sensor window and tokenised target text."""
        return {
            'sensor': torch.from_numpy(self.sensor_data[idx]),
            'input_ids': torch.tensor(
                self.tokenized['input_ids'][idx], dtype=torch.long),
            'attention_mask': torch.tensor(
                self.tokenized['attention_mask'][idx], dtype=torch.long),
        }
|
|
|
|
| |
| |
| |
|
|
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    The table is stored as the buffer ``pe`` with shape
    ``(1, max_len, d_model)``; ``forward`` uses ``x.size(1)`` as the sequence
    length, so inputs are expected as ``(batch, seq, d_model)`` with
    ``seq <= max_len``.

    Args:
        d_model: Embedding width (even or odd).
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions precomputed.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() *
                        (-math.log(10000.0) / d_model))
        angles = pos * div
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns; slicing keeps the assignment shapes consistent.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + pe[:, :seq_len])``."""
        return self.dropout(x + self.pe[:, :x.size(1)])
|
|
|
|
class SensorEncoder(nn.Module):
    """Linear projection, positional encoding and a Transformer encoder stack.

    Args:
        input_dim: Width of each input sensor frame.
        d_model: Hidden width of the encoder.
        nhead: Number of attention heads per layer.
        num_layers: Number of Transformer encoder layers.
        dropout: Dropout used by the positional encoding and encoder layers.
    """

    def __init__(self, input_dim, d_model=64, nhead=4, num_layers=2, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(input_dim, d_model)
        self.pos = PositionalEncoding(d_model, dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, x):
        """Encode ``x``; the sequence length is preserved (batch-first layout)."""
        embedded = self.proj(x)
        positioned = self.pos(embedded)
        return self.encoder(positioned)
|
|
|
|
class SensorToTextModel(nn.Module):
    """Feed pooled sensor embeddings plus an instruction to an LLM as a prefix.

    The LLM input is ``[sensor tokens | instruction tokens | text tokens]``,
    all given as embeddings.  Sensor tokens come from ``SensorEncoder``,
    average-pooled over time to ``n_sensor_tokens`` vectors and projected to
    the LLM hidden size.

    Args:
        input_dim: Width of each sensor frame.
        llm: Causal language model exposing ``config.hidden_size``,
            ``get_input_embeddings()`` and an ``inputs_embeds`` forward.
        tokenizer: Tokenizer used once to encode the instruction text.
        n_sensor_tokens: Number of soft-prompt tokens produced per window.
        d_model, nhead, num_layers, dropout: Passed to ``SensorEncoder``.
    """

    def __init__(self, input_dim, llm, tokenizer, n_sensor_tokens=8,
                 d_model=64, nhead=4, num_layers=2, dropout=0.1):
        super().__init__()
        self.n_sensor_tokens = n_sensor_tokens
        lm_hidden = llm.config.hidden_size

        self.sensor_encoder = SensorEncoder(
            input_dim, d_model, nhead, num_layers, dropout)
        self.pool = nn.AdaptiveAvgPool1d(n_sensor_tokens)
        self.projection = nn.Linear(d_model, lm_hidden)
        self.llm = llm

        # Fixed instruction; stored as a buffer so it follows the module's device.
        inst_text = "描述接下来的动作:"
        inst_ids = tokenizer(inst_text, add_special_tokens=False,
                             return_tensors='pt')['input_ids']
        self.register_buffer('instruction_ids', inst_ids)
        self.n_inst = inst_ids.size(1)

    @property
    def prefix_len(self):
        """Number of prefix positions (sensor tokens + instruction tokens)."""
        return self.n_sensor_tokens + self.n_inst

    def encode_sensor(self, x):
        """Return ``n_sensor_tokens`` LLM-space embeddings per sensor window."""
        feat = self.sensor_encoder(x)
        # AdaptiveAvgPool1d pools the last axis, so time is moved there and back.
        feat = self.pool(feat.transpose(1, 2)).transpose(1, 2)
        return self.projection(feat)

    def forward(self, sensor, input_ids, attention_mask):
        """Return LLM logits for ``[prefix | text]``.

        Args:
            sensor: Batch of sensor windows.
            input_ids: Token ids of the target text, one row per sample.
            attention_mask: Mask for ``input_ids``; prefix positions are
                always attended.

        Returns:
            Logits over every position of the concatenated sequence
            (``prefix_len + input_ids.size(1)`` positions).
        """
        B = sensor.size(0)
        device = sensor.device
        embed = self.llm.get_input_embeddings()

        sensor_embeds = self.encode_sensor(sensor)
        inst_embeds = embed(self.instruction_ids.expand(B, -1))
        text_embeds = embed(input_ids)

        input_embeds = torch.cat(
            [sensor_embeds, inst_embeds, text_embeds], dim=1)
        prefix_attn = torch.ones(
            B, self.prefix_len, device=device, dtype=attention_mask.dtype)
        full_attn = torch.cat([prefix_attn, attention_mask], dim=1)

        return self.llm(inputs_embeds=input_embeds,
                        attention_mask=full_attn).logits

    @torch.no_grad()
    def generate_text(self, sensor, tokenizer, max_new_tokens=20):
        """Greedily decode up to ``max_new_tokens`` tokens per sensor window.

        Switches the module to eval mode.  Decoding stops once every row has
        produced the tokenizer's EOS token at some step; each returned text
        is cut at its first EOS.

        Returns:
            List of decoded strings, one per row of ``sensor``.
        """
        self.eval()
        B = sensor.size(0)
        embed = self.llm.get_input_embeddings()

        sensor_embeds = self.encode_sensor(sensor)
        inst_embeds = embed(self.instruction_ids.expand(B, -1))
        prefix = torch.cat([sensor_embeds, inst_embeds], dim=1)

        eos_id = tokenizer.eos_token_id

        # First step consumes the whole prefix and fills the KV cache.
        out = self.llm(inputs_embeds=prefix, use_cache=True)
        past_kv = out.past_key_values
        next_id = out.logits[:, -1, :].argmax(-1)
        generated = [next_id]
        # Rows that have emitted EOS at any earlier step.  Checking only the
        # current step would keep decoding until all rows hit EOS at once.
        finished = torch.zeros_like(next_id, dtype=torch.bool)

        for _ in range(max_new_tokens - 1):
            if eos_id is not None:
                finished |= (next_id == eos_id)
                if finished.all():
                    break
            next_emb = embed(next_id).unsqueeze(1)
            out = self.llm(inputs_embeds=next_emb,
                           past_key_values=past_kv, use_cache=True)
            past_kv = out.past_key_values
            next_id = out.logits[:, -1, :].argmax(-1)
            generated.append(next_id)

        gen_ids = torch.stack(generated, dim=1)
        texts = []
        for row in gen_ids.tolist():
            if eos_id in row:
                row = row[:row.index(eos_id)]
            texts.append(tokenizer.decode(row, skip_special_tokens=True))
        return texts
|
|
|
|
| |
| |
| |
|
|
def train_epoch(model, loader, optimizer, device):
    """Run one training epoch and return the sample-weighted mean loss.

    The model input is ``[prefix | text]``; the logit at position
    ``prefix_len - 1 + i`` is scored against text token ``i``.  Padding
    positions are excluded via ``attention_mask`` rather than by comparing
    against the pad token id: when the pad token is the same as EOS, an
    id-based mask would also drop the real EOS target and the model would
    never be trained to stop.

    Args:
        model: Module exposing ``prefix_len`` and callable as
            ``model(sensor, input_ids, attention_mask) -> logits``.
        loader: Iterable of batches with ``sensor``, ``input_ids`` and
            ``attention_mask`` tensors.
        optimizer: Optimizer over the trainable parameters.
        device: Device the batch tensors are moved to.

    Returns:
        Mean of the per-batch losses weighted by batch size (0 for an empty
        loader).
    """
    model.train()
    total_loss, n = 0, 0
    P = model.prefix_len
    trainable = [p for p in model.parameters() if p.requires_grad]

    for batch in loader:
        sensor = batch['sensor'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        optimizer.zero_grad()
        logits = model(sensor, input_ids, attention_mask)

        L = input_ids.size(1)
        pred = logits[:, P - 1: P - 1 + L, :]
        # -100 marks padded positions for cross_entropy's ignore_index.
        labels = input_ids.masked_fill(attention_mask == 0, -100)
        loss = F.cross_entropy(
            pred.reshape(-1, pred.size(-1)),
            labels.reshape(-1),
            ignore_index=-100)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable, 1.0)
        optimizer.step()

        total_loss += loss.item() * sensor.size(0)
        n += sensor.size(0)
    return total_loss / max(n, 1)
|
|
|
|
@torch.no_grad()
def eval_loss_only(model, loader, device):
    """Return the sample-weighted mean loss over ``loader`` without generating.

    The logit at position ``prefix_len - 1 + i`` is scored against text token
    ``i``.  Padding is excluded via ``attention_mask`` instead of the pad
    token id, so a real EOS target still counts when pad and EOS share an id.

    Args:
        model: Module exposing ``prefix_len`` and callable as
            ``model(sensor, input_ids, attention_mask) -> logits``.
        loader: Iterable of batches with ``sensor``, ``input_ids`` and
            ``attention_mask`` tensors.
        device: Device the batch tensors are moved to.

    Returns:
        Mean of the per-batch losses weighted by batch size (0 for an empty
        loader).
    """
    model.eval()
    total_loss, n = 0, 0
    P = model.prefix_len
    for batch in loader:
        sensor = batch['sensor'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        logits = model(sensor, input_ids, attention_mask)
        L = input_ids.size(1)
        pred = logits[:, P - 1: P - 1 + L, :]
        # -100 marks padded positions for cross_entropy's ignore_index.
        labels = input_ids.masked_fill(attention_mask == 0, -100)
        loss = F.cross_entropy(
            pred.reshape(-1, pred.size(-1)),
            labels.reshape(-1), ignore_index=-100)
        total_loss += loss.item() * sensor.size(0)
        n += sensor.size(0)
    return total_loss / max(n, 1)
|
|
|
|
@torch.no_grad()
def eval_with_generation(model, loader, tokenizer, device):
    """Compute the loss and text metrics from greedy generations.

    The loss is computed as in training: the logit at position
    ``prefix_len - 1 + i`` is scored against text token ``i``, with padding
    excluded via ``attention_mask`` (not the pad token id, which may equal
    EOS).  Predictions come from ``model.generate_text`` with at most 20 new
    tokens; references are the decoded ``input_ids``.

    Character metrics compare stripped prediction and reference position by
    position: a character counts as correct when both strings have the same
    character at the same index.

    Returns:
        ``(metrics, predictions, references)`` where ``metrics`` has the keys
        ``loss``, ``exact_match``, ``char_precision``, ``char_recall`` and
        ``char_f1``.
    """
    model.eval()
    total_loss, n = 0, 0
    P = model.prefix_len
    all_preds, all_refs = [], []

    for batch in loader:
        sensor = batch['sensor'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        logits = model(sensor, input_ids, attention_mask)
        L = input_ids.size(1)
        pred = logits[:, P - 1: P - 1 + L, :]
        # -100 marks padded positions for cross_entropy's ignore_index.
        labels = input_ids.masked_fill(attention_mask == 0, -100)
        loss = F.cross_entropy(
            pred.reshape(-1, pred.size(-1)),
            labels.reshape(-1), ignore_index=-100)
        total_loss += loss.item() * sensor.size(0)
        n += sensor.size(0)

        all_preds.extend(
            model.generate_text(sensor, tokenizer, max_new_tokens=20))
        all_refs.extend(
            tokenizer.batch_decode(input_ids, skip_special_tokens=True))

    em = sum(p.strip() == r.strip()
             for p, r in zip(all_preds, all_refs)) / max(len(all_preds), 1)

    char_correct, char_ptot, char_rtot = 0, 0, 0
    for p, r in zip(all_preds, all_refs):
        ps, rs = p.strip(), r.strip()
        # zip stops at the shorter string, i.e. the overlapping positions.
        char_correct += sum(a == b for a, b in zip(ps, rs))
        char_ptot += len(ps)
        char_rtot += len(rs)
    prec = char_correct / max(char_ptot, 1)
    rec = char_correct / max(char_rtot, 1)
    char_f1 = 2 * prec * rec / max(prec + rec, 1e-8)

    return {
        'loss': total_loss / max(n, 1),
        'exact_match': em,
        'char_precision': prec,
        'char_recall': rec,
        'char_f1': char_f1,
    }, all_preds, all_refs
|
|
|
|
| |
| |
| |
|
|
def run_experiment(args):
    """Train, select and test the sensor-to-text model for one configuration.

    Steps: seed the RNGs, load tokenizer and LLM from ``args.llm_name``
    (local files only), freeze the LLM and add LoRA adapters, build the
    train/val/test datasets (val/test normalised with the training
    statistics), train with ReduceLROnPlateau and early stopping on the
    validation loss, reload the best checkpoint and evaluate on the test
    split.  Only non-LLM weights and LoRA weights are checkpointed.

    Args:
        args: Namespace with the options defined in ``main``.

    Returns:
        The results dict that is also written to ``results.json`` inside the
        experiment directory, or ``None`` when the training split is empty.
    """
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    modalities = args.modalities.split(',')

    print(f"\n{'='*60}", flush=True)
    print("Sensor → LLM Text (LoRA + instruction prefix)", flush=True)
    print(f"Mods: {modalities} | LLM: {args.llm_name}", flush=True)
    print(f"LoRA r={args.lora_r} alpha={args.lora_alpha}", flush=True)
    print(f"{'='*60}", flush=True)

    print("Loading LLM...", flush=True)
    from transformers import AutoTokenizer, AutoModelForCausalLM
    tokenizer = AutoTokenizer.from_pretrained(
        args.llm_name, trust_remote_code=True, local_files_only=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    llm = AutoModelForCausalLM.from_pretrained(
        args.llm_name, trust_remote_code=True,
        torch_dtype=torch.float32, local_files_only=True,
    ).to(device)
    llm.config.pad_token_id = tokenizer.pad_token_id

    # Freeze the whole LLM; only the LoRA factors added below are trained.
    for p in llm.parameters():
        p.requires_grad = False

    lora_params = apply_lora(llm, r=args.lora_r, alpha=args.lora_alpha)
    lora_param_count = sum(p.numel() for p in lora_params)
    print(f"LoRA params: {lora_param_count:,} (r={args.lora_r})", flush=True)

    train_ds = TextPredictionDataset(
        TRAIN_VOLS, modalities, tokenizer,
        window_sec=args.window_sec, max_text_len=args.max_text_len,
        downsample=args.downsample)
    # Validation and test data reuse the training normalisation statistics.
    stats = train_ds.get_stats()
    val_ds = TextPredictionDataset(
        VAL_VOLS, modalities, tokenizer,
        window_sec=args.window_sec, max_text_len=args.max_text_len,
        downsample=args.downsample, stats=stats)
    test_ds = TextPredictionDataset(
        TEST_VOLS, modalities, tokenizer,
        window_sec=args.window_sec, max_text_len=args.max_text_len,
        downsample=args.downsample, stats=stats)

    if len(train_ds) == 0:
        print("ERROR: No training samples!", flush=True)
        return None

    train_loader = DataLoader(train_ds, batch_size=args.batch_size,
                              shuffle=True, drop_last=False)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False)

    model = SensorToTextModel(
        train_ds.feat_dim, llm, tokenizer,
        n_sensor_tokens=args.n_sensor_tokens, d_model=args.hidden_dim)
    model = model.to(device)

    sensor_params = list(model.sensor_encoder.parameters()) + \
        list(model.projection.parameters())
    all_trainable = sensor_params + lora_params
    trainable_count = sum(p.numel() for p in all_trainable)
    total_count = sum(p.numel() for p in model.parameters())
    print(f"Trainable: {trainable_count:,} / Total: {total_count:,}", flush=True)

    # The LoRA factors use a smaller learning rate than the sensor branch.
    optimizer = torch.optim.AdamW([
        {'params': sensor_params, 'lr': args.lr},
        {'params': lora_params, 'lr': args.lr * 0.2},
    ], weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=7, factor=0.5, min_lr=1e-6)

    mod_str = '-'.join(modalities)
    exp_name = f"pred_llm_{mod_str}"
    out_dir = os.path.join(args.output_dir, exp_name)
    os.makedirs(out_dir, exist_ok=True)
    ckpt_path = os.path.join(out_dir, 'model_best.pt')

    best_val_loss = float('inf')
    best_epoch = 0
    patience_ctr = 0

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        tr_loss = train_epoch(model, train_loader, optimizer, device)

        # Generation is slow, so it only runs on the first two epochs, every
        # fifth epoch, and when early stopping is close.
        if epoch % 5 == 0 or epoch <= 2 or patience_ctr >= args.patience - 2:
            val_m, _, _ = eval_with_generation(
                model, val_loader, tokenizer, device)
            print(f"  Epoch {epoch:3d} | TrLoss={tr_loss:.4f} | "
                  f"Val: loss={val_m['loss']:.4f} EM={val_m['exact_match']:.4f} "
                  f"charF1={val_m['char_f1']:.4f} | {time.time()-t0:.1f}s",
                  flush=True)
        else:
            val_loss = eval_loss_only(model, val_loader, device)
            val_m = {'loss': val_loss}
            print(f"  Epoch {epoch:3d} | TrLoss={tr_loss:.4f} | "
                  f"Val: loss={val_loss:.4f} | {time.time()-t0:.1f}s",
                  flush=True)

        scheduler.step(val_m['loss'])

        if val_m['loss'] < best_val_loss:
            best_val_loss = val_m['loss']
            best_epoch = epoch
            patience_ctr = 0
            # Keep everything outside the LLM plus the LoRA factors; the
            # frozen LLM weights are not written to the checkpoint.
            save_sd = {
                k: v for k, v in model.state_dict().items()
                if not k.startswith('llm.') or 'lora_A' in k or 'lora_B' in k
            }
            torch.save(save_sd, ckpt_path)
        else:
            patience_ctr += 1
            if patience_ctr >= args.patience:
                print(f"  Early stopping at epoch {epoch}", flush=True)
                break

    if best_epoch > 0:
        # strict=False because the frozen LLM weights are not in the file.
        best_sd = torch.load(ckpt_path, map_location=device, weights_only=True)
        model.load_state_dict(best_sd, strict=False)
    else:
        # No epoch improved on the initial best (e.g. NaN validation loss or
        # zero epochs).  A file at ckpt_path would belong to an earlier run,
        # so it is not loaded.
        print("  WARNING: no checkpoint was saved in this run; "
              "evaluating the current weights.", flush=True)
    test_m, test_preds, test_refs = eval_with_generation(
        model, test_loader, tokenizer, device)

    print(f"\n--- Test (best epoch {best_epoch}) ---", flush=True)
    for k, v in test_m.items():
        print(f"  {k}: {v:.4f}", flush=True)

    print("\nSample predictions:", flush=True)
    indices = random.sample(range(len(test_preds)), min(15, len(test_preds)))
    for i in indices:
        tag = "OK" if test_preds[i].strip() == test_refs[i].strip() else "XX"
        print(f"  [{tag}] Pred: {test_preds[i].strip()}", flush=True)
        print(f"       Ref:  {test_refs[i].strip()}", flush=True)

    results = {
        'experiment': exp_name,
        'modalities': modalities,
        'best_epoch': best_epoch,
        'test_metrics': {k: float(v) for k, v in test_m.items()},
        'trainable_params': trainable_count,
        'lora_params': lora_param_count,
        'train_samples': len(train_ds),
        'val_samples': len(val_ds),
        'test_samples': len(test_ds),
        'args': vars(args),
        'sample_predictions': [
            {'pred': test_preds[i].strip(), 'ref': test_refs[i].strip()}
            for i in indices
        ],
    }
    # ensure_ascii=False writes non-ASCII text verbatim, so the file encoding
    # must be fixed rather than left to the locale default.
    with open(os.path.join(out_dir, 'results.json'), 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"  Saved to {out_dir}", flush=True)
    return results
|
|
|
|
def main():
    """Parse command-line options and run one experiment.

    Environment variables in ``--llm_name`` and ``--output_dir`` (including
    the ``${PULSE_ROOT}`` used by the defaults) are expanded before use;
    variables that are not set are left as written.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--modalities', type=str, default='imu')
    parser.add_argument('--window_sec', type=float, default=15.0)
    parser.add_argument('--llm_name', type=str,
                        default='${PULSE_ROOT}/models/qwen2.5-0.5b')
    parser.add_argument('--lora_r', type=int, default=8)
    parser.add_argument('--lora_alpha', type=int, default=16)
    parser.add_argument('--n_sensor_tokens', type=int, default=8)
    parser.add_argument('--max_text_len', type=int, default=48)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=5e-4)
    parser.add_argument('--weight_decay', type=float, default=1e-4)
    parser.add_argument('--hidden_dim', type=int, default=64)
    parser.add_argument('--downsample', type=int, default=5)
    parser.add_argument('--patience', type=int, default=15)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--output_dir', type=str,
                        default='${PULSE_ROOT}/results/pred_llm2')
    args = parser.parse_args()
    # argparse does not expand "${VAR}"; without this the defaults would
    # create and read a directory literally named "${PULSE_ROOT}".
    args.llm_name = os.path.expandvars(args.llm_name)
    args.output_dir = os.path.expandvars(args.output_dir)
    os.makedirs(args.output_dir, exist_ok=True)
    run_experiment(args)
|
|
|
|
# Run the experiment only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|