""" TTRLVR Dataset for AZR Integration TTRLVR parquet 파일을 읽어 AZR 학습에 사용할 수 있도록 변환 """ import os import pandas as pd import numpy as np from typing import Dict, List, Any, Optional from torch.utils.data import Dataset from transformers import AutoTokenizer from .rl_dataset import RLHFDataset class TTRLVRDataset(RLHFDataset): """TTRLVR 데이터를 AZR 형식으로 로드하는 Dataset""" def __init__(self, parquet_files: str, tokenizer: AutoTokenizer, task_type: Optional[str] = None, **kwargs): """ Args: parquet_files: TTRLVR parquet 파일 경로 (또는 디렉토리) tokenizer: 토크나이저 task_type: 특정 task 타입만 로드 (induction/deduction/abduction) """ # parquet_files가 ListConfig인 경우 리스트로 변환 from omegaconf import ListConfig if isinstance(parquet_files, ListConfig): parquet_files = list(parquet_files) # parquet_files가 디렉토리인 경우 처리 if isinstance(parquet_files, str) and os.path.isdir(parquet_files): if task_type: parquet_files = os.path.join(parquet_files, f"{task_type}.parquet") else: # 모든 task 타입 파일 수집 files = [] for t in ['induction', 'deduction', 'abduction']: f = os.path.join(parquet_files, f"{t}.parquet") if os.path.exists(f): files.append(f) parquet_files = files super().__init__( parquet_files=parquet_files, tokenizer=tokenizer, prompt_key='prompt', # TTRLVR은 'prompt' 키 사용 **kwargs ) self.task_type = task_type def __getitem__(self, idx): """단일 샘플 반환""" # TTRLVR 특별 처리 - 원본 데이터에 먼저 접근 # RLHFDataset.__getitem__이 prompt를 pop하기 전에 원본 데이터 가져오기 if hasattr(self, 'dataframe') and hasattr(self.dataframe, 'iloc'): # pandas DataFrame인 경우 original_row = self.dataframe.iloc[idx].to_dict() # prompt 필드 백업 original_prompt = original_row.get('prompt', None) else: original_row = {} original_prompt = None # 기본 데이터 가져오기 (RLHFDataset.__getitem__ 호출) data = super().__getitem__(idx) # 백업한 prompt 사용 row = original_row # prompt 처리 - numpy array, list, dict 등 다양한 형태 처리 prompt_text = None if original_prompt is not None: if isinstance(original_prompt, np.ndarray): if len(original_prompt) > 0 and isinstance(original_prompt[0], dict): prompt_text = original_prompt[0].get('content', '') else: prompt_text = str(original_prompt) elif isinstance(original_prompt, list): if len(original_prompt) > 0 and isinstance(original_prompt[0], dict): prompt_text = original_prompt[0].get('content', '') else: prompt_text = str(original_prompt) elif isinstance(original_prompt, str): prompt_text = original_prompt else: prompt_text = str(original_prompt) # prompt를 data에 추가 (문자열로) if prompt_text is not None: data['prompt'] = prompt_text # TTRLVR 메타데이터 추가 ttrlvr_metadata = { 'task_type': self._extract_task_type(row), 'expected_solution': row.get('ground_truth', ''), 'problem': row.get('problem', {}), 'ipo_group_id': row.get('ipo_group_id', ''), 'uid': row.get('uid', ''), 'evaluation_data': self._prepare_evaluation_data(row) } # 메타데이터를 data에 추가 data['ttrlvr_metadata'] = ttrlvr_metadata return data def _extract_task_type(self, row: pd.Series) -> str: """행에서 task 타입 추출""" uid = row.get('uid', '') if 'induction' in uid: return 'induction' elif 'deduction' in uid: return 'deduction' elif 'abduction' in uid: return 'abduction' return 'unknown' def _prepare_evaluation_data(self, row: pd.Series) -> Dict[str, Any]: """Task 타입별 evaluation data 준비""" # 1. 먼저 저장된 evaluation_data가 있는지 확인 (Phase 1-4에서 생성된 데이터) if 'evaluation_data' in row and row['evaluation_data']: eval_data = row['evaluation_data'] # pandas/numpy 객체를 Python 네이티브 타입으로 변환 if hasattr(eval_data, 'item'): eval_data = eval_data.item() return eval_data if isinstance(eval_data, dict) else {} # 2. Fallback: problem 필드에서 구성 (호환성 유지) task_type = self._extract_task_type(row) problem = row.get('problem', {}) if task_type == 'induction': # IPO에서 input/output 쌍 추출 return { 'input_output_pairs': [ (problem.get('input', ''), problem.get('output', '')) ] } elif task_type == 'deduction': return { 'function_code': problem.get('snippet', ''), 'input': problem.get('input', '') } elif task_type == 'abduction': return { 'function_code': problem.get('snippet', ''), 'expected_output': problem.get('output', '') } return {}