| import json |
| import torch |
| import numpy as np |
| import random |
| from collections import defaultdict |
| from torch.utils.data import Dataset |
| from v0_core.data.utils import load_jsonl_lines |
|
|
| |
| |
| |
class ValueModelDataset(Dataset):
    """Map-style dataset that pairs batches of query samples with support contexts.

    Each task (one ``__getitem__`` item) bundles a chunk of queries from a single
    (dataset, model, step) key together with up to ``support_size`` context
    samples drawn from a pool keyed the same way, falling back to a
    (model, step)-keyed pool when the exact key has no context. In 'train'
    mode, queries are formed into balanced positive/negative pairs via cyclic
    oversampling; in other modes they are chunked sequentially.
    """

    def __init__(self,
                 context_paths,
                 query_paths,
                 prompt_dict_path,
                 label_strategy='binary',
                 query_batch_size=8,
                 support_size=256,
                 mode='train'):
        """
        args:
            context_paths: List of paths to context_pool jsonl files
            query_paths: List of paths to query_pool jsonl files (train/test/validity)
            prompt_dict_path: List of paths to prompt dictionaries
            label_strategy: How scores become labels — 'binary' or 'minmax_norm';
                any other value passes the raw score through (see _process_label)
            query_batch_size: Number of queries in one forward pass (all from same step)
            support_size: Number of context samples to sample
            mode: 'train' (shuffle queries before chunking) or 'eval' (sequential)
        """
        self.label_strategy = label_strategy
        self.query_batch_size = query_batch_size
        self.support_size = support_size
        self.mode = mode

        # Merge all prompt dictionaries into one map keyed by "<dataset>_<id>".
        # Later files silently override earlier ones on key collisions.
        print(f"Loading prompts from {prompt_dict_path}...")
        self.prompt_map = {}
        for path in prompt_dict_path:
            with open(path, 'r', encoding='utf-8') as f:
                self.prompt_map.update(json.load(f))

        # Context (support) pool, indexed twice: by the exact
        # (dataset, model, step) key and by a (model, step) fallback key so a
        # task can still get context when its exact dataset has none.
        print("Loading Context Pool...")
        self.context_pool = defaultdict(list)
        self.context_pool_fallback = defaultdict(list)
        raw_context = load_jsonl_lines(context_paths)
        for item in raw_context:
            key = (item['dataset'], item['model'], item['step'])
            self.context_pool[key].append(item)
            fallback_key = (item['model'], item['step'])
            self.context_pool_fallback[fallback_key].append(item)
        print(f"Loaded Context Pool with {len(self.context_pool)} unique (dataset, model, step) keys.")

        print(f"Loading Query Pool from {query_paths}...")
        raw_queries = load_jsonl_lines(query_paths)

        # Group queries by (dataset, model, step) and attach the prompt text.
        # NOTE(review): a query whose "<dataset>_<id>" is absent from the prompt
        # maps raises KeyError here — confirm that is the intended failure mode.
        self.queries_by_key = defaultdict(list)
        print("Grouping Queries...")
        for item in raw_queries:
            key = (item['dataset'], item['model'], item['step'])

            s_id_str = f"{item['dataset']}_{item['id']}"
            item['text'] = self.prompt_map[s_id_str]
            self.queries_by_key[key].append(item)

        # Per-key positive/negative counts, exposed to the consumer via the
        # "stats" field of each task (train mode only).
        # NOTE(review): despite "Context Statistics" in the log line, these are
        # computed from the QUERY pool, and use a raw score >= 0 threshold
        # (missing score defaults to -1, i.e. negative) regardless of
        # label_strategy — confirm both are intended.
        print("Calculating Global Context Statistics for Re-weighting...")
        self.context_stats = {}
        if mode == 'train':
            for key, items in self.queries_by_key.items():

                n_pos = sum(1 for x in items if float(x.get('score', -1)) >= 0)
                n_neg = len(items) - n_pos
                self.context_stats[key] = {'n_pos': n_pos, 'n_neg': n_neg}

        # Debug summary: print the first 10 keys (sorted) with their counts.
        # In non-train mode context_stats is empty, so the table has no rows.
        print("\n" + "="*60)
        print(f"Top 10 Steps Statistics ({mode} mode)")
        print(f"{'Dataset':<15} | {'Model':<15} | {'Step':<6} | {'n_pos':<6} | {'n_neg':<6} | {'Total':<6}")
        print("-" * 60)

        sorted_keys = sorted(list(self.context_stats.keys()))

        for i, key in enumerate(sorted_keys[:10]):

            dataset_name, model_name, step_val = key
            stats = self.context_stats[key]
            total = stats['n_pos'] + stats['n_neg']
            print(f"{dataset_name:<15} | {model_name:<15} | {str(step_val):<6} | "
                  f"{stats['n_pos']:<6} | {stats['n_neg']:<6} | {total:<6}")

        print(f"... (Total {len(sorted_keys)} steps loaded)")
        print("="*60 + "\n")

        # Materialize the task list; shuffling only happens in train mode.
        self.tasks = []
        self.generate_tasks(shuffle=(self.mode == 'train'))

        print(f"Dataset Initialized. Total Tasks: {len(self.tasks)}")

    def generate_tasks(self, shuffle: bool = True) -> None:
        """Rebuild ``self.tasks`` from the grouped query pool.

        Pairwise Task Generation with Cyclic Oversampling.
        Goal: keep every sample — nothing is discarded. The side (positive or
        negative) with fewer samples is reused cyclically so it matches the
        size of the larger side.

        In 'train' mode each task chunk is an interleaved [pos, neg, pos, neg,
        ...] sequence of even length; keys with zero positives or zero
        negatives are skipped entirely. In any other mode, queries are simply
        chunked in order (shuffled first only if ``shuffle`` is True). Tasks
        whose key has no context (exact or fallback) are dropped silently.
        """
        new_tasks = []
        # Sort first so the key order is deterministic before any shuffling.
        keys = sorted(list(self.queries_by_key.keys()))

        if shuffle:
            random.shuffle(keys)

        dropped_steps = 0
        total_pairs = 0

        for key in keys:
            samples = list(self.queries_by_key[key])

            if self.mode == 'train':
                # Split on the processed label; 0.5 is the midpoint for both
                # 'binary' (0/1) and 'minmax_norm' ([0, 1]) strategies.
                pos_list = [x for x in samples if self._process_label(x['score']) >= 0.5]
                neg_list = [x for x in samples if self._process_label(x['score']) < 0.5]

                n_pos = len(pos_list)
                n_neg = len(neg_list)

                # A pair needs one of each; one-sided steps cannot contribute.
                if n_pos == 0 or n_neg == 0:
                    dropped_steps += 1
                    continue

                if shuffle:
                    random.shuffle(pos_list)
                    random.shuffle(neg_list)

                # Cyclic oversampling: the minority list wraps around via the
                # modulo index until the majority list is exhausted.
                n_pairs = max(n_pos, n_neg)

                paired_samples = []
                for i in range(n_pairs):
                    p = pos_list[i % n_pos]
                    n = neg_list[i % n_neg]

                    paired_samples.append(p)
                    paired_samples.append(n)

                total_pairs += n_pairs

                # Force an even chunk size so pairs are never split across
                # tasks; minimum of 2 keeps at least one pair per chunk.
                bs = self.query_batch_size
                if bs % 2 != 0:
                    bs -= 1
                if bs < 2: bs = 2

                for i in range(0, len(paired_samples), bs):
                    chunk = paired_samples[i : i + bs]

                    # Defensive: drop a dangling half-pair. (With an even bs
                    # and even paired_samples this should never trigger.)
                    if len(chunk) % 2 != 0:
                        chunk = chunk[:-1]

                    # Prefer exact-key context; otherwise fall back to the
                    # (model, step) pool; otherwise drop this chunk.
                    context_key_to_use = None
                    if key in self.context_pool and len(self.context_pool[key]) > 0:
                        context_key_to_use = key
                    else:
                        fallback_key = (key[1], key[2])
                        if fallback_key in self.context_pool_fallback and len(self.context_pool_fallback[fallback_key]) > 0:
                            context_key_to_use = fallback_key

                    if len(chunk) > 0 and context_key_to_use is not None:
                        new_tasks.append({
                            'key': key,
                            'context_key': context_key_to_use,
                            'queries': chunk,
                            'is_pairwise': True
                        })

            else:
                # Eval path: plain sequential chunking, no pairing and no
                # oversampling; the last chunk may be short.
                if shuffle: random.shuffle(samples)
                for i in range(0, len(samples), self.query_batch_size):
                    chunk = samples[i : i + self.query_batch_size]
                    # Same exact-key / fallback context resolution as above.
                    context_key_to_use = None
                    if key in self.context_pool and len(self.context_pool[key]) > 0:
                        context_key_to_use = key
                    else:
                        fallback_key = (key[1], key[2])
                        if fallback_key in self.context_pool_fallback and len(self.context_pool_fallback[fallback_key]) > 0:
                            context_key_to_use = fallback_key

                    if context_key_to_use is not None:
                        new_tasks.append({
                            'key': key,
                            'context_key': context_key_to_use,
                            'queries': chunk,
                            'is_pairwise': False
                        })

        self.tasks = new_tasks
        if self.mode == 'train':
            print(f"  >>> [Dataset] Generated {len(self.tasks)} tasks from {len(keys)} contexts.")
            print(f"  >>> [Pairwise Stats] Total Pairs: {total_pairs} (Using Oversampling). Dropped Steps (0 pos or 0 neg): {dropped_steps}")

    def _process_label(self, reward) -> float:
        """Map a raw score to a training label per ``self.label_strategy``.

        'binary': 1.0 if score >= 0 else 0.0.
        'minmax_norm': clip score to [-1, 1], then rescale to [0, 1].
        Anything else: return the raw score as a float.
        """
        val = float(reward)
        if self.label_strategy == "binary":
            return 1.0 if val >= 0 else 0.0
        elif self.label_strategy == "minmax_norm":
            return (np.clip(val, -1.0, 1.0) + 1.0) / 2.0
        return val

    def __len__(self) -> int:
        """Number of pre-generated tasks (see generate_tasks)."""
        return len(self.tasks)

    def __getitem__(self, idx: int):
        """Assemble one task: sampled support context followed by its queries.

        Returns a dict with:
            prompts: support prompt texts then query prompt texts (one list)
            labels: float tensor of processed labels, aligned with prompts
            split_idx: index separating support (< split_idx) from queries
            q_ids: query sample ids
            pair_ids / pair_types: optional per-query metadata
            key: the task's (dataset, model, step) key
            stats: {'n_pos', 'n_neg'} counts for the key (zeros if absent)

        Support is re-sampled randomly on every call, so repeated access to
        the same index yields different context sets.
        """
        task = self.tasks[idx]
        key = task['key']
        query_samples = task['queries']

        # context_key equals key for an exact match (3-tuple) or the
        # (model, step) fallback (2-tuple); pick the matching pool.
        context_key = task.get('context_key', key)
        available_context = self.context_pool[key] if context_key == key else self.context_pool_fallback[context_key]

        # Sample without replacement when the pool is large enough; otherwise
        # take the whole pool as-is.
        if len(available_context) >= self.support_size:
            support_samples = random.sample(available_context, self.support_size)
        else:
            support_samples = available_context

        prompts = []
        labels = []

        # Support first. Items whose prompt text is falsy (e.g. empty string)
        # are skipped here, so split_idx can be smaller than support_size.
        for item in support_samples:
            s_id_str = f"{item['dataset']}_{item['id']}"
            text = self.prompt_map[s_id_str]
            if text:
                prompts.append(text)
                labels.append(self._process_label(item['score']))

        # Boundary between support and query entries in prompts/labels.
        split_idx = len(prompts)

        q_ids = []
        pair_ids = []
        pair_types = []

        # Queries use the text attached in __init__ and are never filtered.
        # NOTE(review): pair_ids/pair_types are appended only when the keys
        # exist, so their lengths can diverge from q_ids if only some query
        # items carry them — confirm consumers handle that.
        for item in query_samples:
            prompts.append(item['text'])
            labels.append(self._process_label(item['score']))
            q_ids.append(item['id'])
            if 'pair_id' in item:
                pair_ids.append(item['pair_id'])
            if 'pair_type' in item:
                pair_types.append(item['pair_type'])

        # context_stats is only populated in train mode; fall back to zeros.
        stats = self.context_stats.get(key, {'n_pos': 0, 'n_neg': 0})

        return {
            "prompts": prompts,
            "labels": torch.tensor(labels, dtype=torch.float),
            "split_idx": split_idx,
            "q_ids": q_ids,
            "pair_ids": pair_ids,
            "pair_types": pair_types,
            "key": key,
            "stats": stats
        }
|
|