"""
Complete TestTime RLVR Pipeline
LLM solution generation → IPO extraction → task generation → LLM evaluation → reward computation
Reuses the AZR code as directly as possible at every step
"""
from typing import Dict, List, Any, Optional
import torch
import re
import os
import json
import ray
import math
from pathlib import Path
from datetime import datetime
from .benchmark_loader import BenchmarkProblemLoader
from .solution_generator import InitialSolutionGenerator
from .ipo_extractor import IPOTripleExtractor, IPOBuffer
from .task_generator import TestTimeTaskGenerator
from .config import TestTimeConfig, BenchmarkConfig
# Ray actors for data generation removed - VLLM batch processing is used instead
from .logger import TestTimeLogger
# The AZR reward manager is used directly
from ..rewards.reward_managers import CodeIORewardManager
@ray.remote
class RemoteTestTimePipeline:
"""Ray Actor๋กœ ์ž‘๋™ํ•˜๋Š” TestTime Pipeline (VeRL ํŒจํ„ด)"""
def __init__(self, config: TestTimeConfig, model_path: str):
"""Ray worker ๋‚ด๋ถ€์—์„œ ๋ชจ๋ธ ๋กœ๋”ฉ"""
self.config = config
self.model_path = model_path
# Ray worker์—์„œ VLLM ๋ชจ๋ธ ๋กœ๋”ฉ
from .solution_generator import InitialSolutionGenerator
import os
# Work around Ray runtime_env truncating comma-separated values:
# if VLLM_USE_SPECIFIC_GPUS is set, restore it into CUDA_VISIBLE_DEVICES
if 'VLLM_USE_SPECIFIC_GPUS' in os.environ:
os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['VLLM_USE_SPECIFIC_GPUS']
print(f"[RemoteTestTimePipeline] Restored CUDA_VISIBLE_DEVICES from VLLM_USE_SPECIFIC_GPUS: {os.environ['CUDA_VISIBLE_DEVICES']}")
# GPU setup - with CUDA_VISIBLE_DEVICES set, 'cuda:0' maps to the first visible GPU
device = 'cuda:0'
# Use VLLM in multi-GPU environments
cuda_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
print(f"[RemoteTestTimePipeline] CUDA_VISIBLE_DEVICES: {cuda_devices}")
# config์—์„œ ๋ช…์‹œ์ ์œผ๋กœ use_vllm_for_data_generation ์„ค์ • ํ™•์ธ
use_vllm = getattr(config, 'use_vllm_for_data_generation', len(cuda_devices.split(',')) > 1)
gpu_count = len(cuda_devices.split(','))
# To share GPUs with Step 5 (VeRL), Steps 1-4 use at most two GPUs for VLLM
vllm_tensor_parallel_size = min(2, gpu_count) if use_vllm else 1
print(f"[RemoteTestTimePipeline] GPU count: {gpu_count}, use_vllm: {use_vllm}, tensor_parallel_size: {vllm_tensor_parallel_size}")
self.model, self.tokenizer = InitialSolutionGenerator.load_model_with_optimizations(
model_path, device, config, use_vllm=use_vllm, tensor_parallel_size=vllm_tensor_parallel_size
)
# Logger setup - Ray workers write to the same log file as the driver
log_file = os.environ.get('TTRLVR_LOG_FILE', None)
if log_file:
self.logger = TestTimeLogger(log_file=log_file)
else:
self.logger = TestTimeLogger()
# Initialize the CompleteTestTimePipeline
self.pipeline = CompleteTestTimePipeline(
model=self.model,
tokenizer=self.tokenizer,
config=config,
logger=self.logger
)
def run_complete_pipeline(self, benchmark_config: BenchmarkConfig,
problem_id: str, round_num: int = 1, session_timestamp: str = None,
output_base_dir: str = None) -> Dict[str, Any]:
"""์›๊ฒฉ์—์„œ ํŒŒ์ดํ”„๋ผ์ธ ์‹คํ–‰"""
return self.pipeline.run_complete_pipeline(benchmark_config, problem_id, round_num, session_timestamp, output_base_dir)
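# Illustrative driver usage (a sketch, not part of this module): names such
# as `cfg` and `bench_cfg` are assumed to exist in the caller's scope.
#
#   import ray
#   ray.init()
#   actor = RemoteTestTimePipeline.remote(cfg, "/path/to/model")
#   result = ray.get(actor.run_complete_pipeline.remote(
#       bench_cfg, "Mbpp/2", round_num=1))
#   print(result['success'], result.get('error'))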
def generate_batch_vllm(self, prompts: List[str], max_tokens: int = 512,
temperature: float = 0.7, top_p: float = 1.0, n: int = 1) -> Dict[str, Any]:
"""
VeRL์—์„œ ํ˜ธ์ถœํ•  ์ˆ˜ ์žˆ๋Š” ๋ฐฐ์น˜ ์ƒ์„ฑ ๋ฉ”์„œ๋“œ
Step 5์˜ SharedVLLMRollout์—์„œ ์‚ฌ์šฉ
"""
from .solution_generator import InitialSolutionGenerator
# VLLM batch generation
if hasattr(self.model, 'generate'):
# VLLM model
outputs = InitialSolutionGenerator.generate_batch_vllm(
self.model, prompts,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
n=n
)
# Format the results
responses = []
for output in outputs:
generated_text = output.outputs[0].text
responses.append(generated_text)
# Build input_ids with the tokenizer (required by VeRL)
tokenized = self.tokenizer(prompts, padding=True, truncation=True, return_tensors="pt")
return {
'responses': responses,
'input_ids': tokenized['input_ids'].tolist(),
'attention_mask': tokenized['attention_mask'].tolist()
}
else:
# HuggingFace model (fallback)
raise NotImplementedError("HuggingFace batch generation not implemented for VeRL sharing")
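# Sketch of the VeRL-side call (assuming SharedVLLMRollout holds a handle to
# this actor; the handle name `actor` is illustrative):
#
#   out = ray.get(actor.generate_batch_vllm.remote(
#       ["def add(a, b):"], max_tokens=256, temperature=0.7))
#   responses = out['responses']   # list[str], one generation per prompt
#   input_ids = out['input_ids']   # tokenized prompts for VeRL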
def update_model_weights(self, model_path: str) -> bool:
"""
ํ•™์Šต๋œ ๋ชจ๋ธ๋กœ VLLM ๊ฐ€์ค‘์น˜ ์—…๋ฐ์ดํŠธ
๋งค ๋ผ์šด๋“œ ํ›„ ํ˜ธ์ถœ๋จ
"""
try:
self.logger.log_info(f"🔄 Updating VLLM weights from: {model_path}")
# VLLM does not support dynamic weight updates,
# so the engine must be replaced with a new one
from .solution_generator import InitialSolutionGenerator
import os
device = 'cuda:0'
use_vllm = len(os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(',')) > 1
gpu_count = len(os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(','))
vllm_tensor_parallel_size = min(2, gpu_count) if use_vllm else 1
# Tear down the existing VLLM engine
if hasattr(self.model, 'llm_engine'):
del self.model
import torch
torch.cuda.empty_cache()
self.logger.log_info(" - Old VLLM engine cleaned up")
# ์ƒˆ๋กœ์šด ๋ชจ๋ธ ๋กœ๋“œ
self.model, _ = InitialSolutionGenerator.load_model_with_optimizations(
model_path, device, self.config,
use_vllm=use_vllm,
tensor_parallel_size=vllm_tensor_parallel_size
)
# Pipeline ์ธ์Šคํ„ด์Šค ์—…๋ฐ์ดํŠธ
self.pipeline.model = self.model
self.pipeline.solution_generator.model = self.model
self.logger.log_info("✅ VLLM weights updated successfully")
return True
except Exception as e:
self.logger.log_error(f"Failed to update VLLM weights: {e}")
return False
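# Sketch of the per-round refresh this method supports (caller side,
# assuming the trainer writes a checkpoint directory each round):
#
#   for round_idx in range(num_rounds):                # hypothetical loop
#       ckpt_dir = f"./checkpoints/round_{round_idx}"  # assumed layout
#       ray.get(actor.update_model_weights.remote(ckpt_dir))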
def update_model_weights_from_state_dict(self, state_dict: Dict[str, Any]) -> bool:
"""
Update weights directly from a state dict (more efficient)
Receives weights trained by VeRL and applies them directly
"""
try:
self.logger.log_info("🔄 Updating VLLM weights from state dict")
# VLLM does not support dynamic updates, so this path is limited:
# it only works for HuggingFace models
if not hasattr(self.model, 'llm_engine'):
# HuggingFace model
self.model.load_state_dict(state_dict)
self.logger.log_info("✅ Model weights updated via state dict")
return True
else:
# VLLM์€ ํŒŒ์ผ ๊ธฐ๋ฐ˜ ๋กœ๋“œ๋งŒ ์ง€์›
self.logger.log_warning("โš ๏ธ VLLM requires file-based weight loading")
return False
except Exception as e:
self.logger.log_error(f"Failed to update weights from state dict: {e}")
return False
def cleanup(self):
"""Ray Actor ์ข…๋ฃŒ ์ „ ๋ฆฌ์†Œ์Šค ์ •๋ฆฌ"""
try:
self.logger.log_info("🧹 Cleaning up RemoteTestTimePipeline resources...")
# Clean up the VLLM model if present
if hasattr(self, 'model') and self.model is not None:
self.logger.log_info(" - Cleaning up VLLM model...")
# VLLM ์ธ์Šคํ„ด์Šค ์‚ญ์ œ
del self.model
self.model = None
# Clean up the pipeline
if hasattr(self, 'pipeline') and self.pipeline is not None:
if hasattr(self.pipeline, 'cleanup'):
self.pipeline.cleanup()
del self.pipeline
self.pipeline = None
# Free GPU memory
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Force a garbage collection pass
import gc
gc.collect()
self.logger.log_info("✅ Cleanup completed")
return True
except Exception as e:
self.logger.log_error(f"⚠️ Cleanup error: {e}")
return False
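# Typical actor teardown (sketch): call cleanup() before killing the actor so
# the VLLM engine releases GPU memory deterministically.
#
#   ray.get(actor.cleanup.remote())
#   ray.kill(actor)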
class CompleteTestTimePipeline:
"""์™„์ „ํ•œ TestTime RLVR ํŒŒ์ดํ”„๋ผ์ธ"""
def __init__(self, model, tokenizer, config: TestTimeConfig,
logger: Optional[TestTimeLogger] = None):
self.model = model
self.tokenizer = tokenizer
self.config = config
# Pass task_output_dir when creating the logger (per-round paths are set when run() is called)
self.logger = logger or TestTimeLogger()
# ๊ฐ ์ปดํฌ๋„ŒํŠธ ์ดˆ๊ธฐํ™”
self.benchmark_loader = BenchmarkProblemLoader(config, self.logger)
# ๋ชจ๋ธ์ด None์ด ์•„๋‹Œ ๊ฒฝ์šฐ์—๋งŒ ์ปดํฌ๋„ŒํŠธ ์ดˆ๊ธฐํ™”
if model is not None and tokenizer is not None:
# Engine selection (from config; default: VLLM)
use_vllm = getattr(config, 'use_vllm_for_data_generation', True)
self.solution_generator = InitialSolutionGenerator(model, tokenizer, config, self.logger, use_vllm=use_vllm)
self.ipo_extractor = IPOTripleExtractor(config, self.logger, model, tokenizer)
# Hand the solution generator to the IPO extractor (for batch processing)
self.ipo_extractor.solution_generator = self.solution_generator
self.reward_manager = self._setup_azr_reward_manager()
else:
# Lazy-initialization placeholders
self.solution_generator = None
self.ipo_extractor = None
self.reward_manager = None
self.task_generator = TestTimeTaskGenerator(config, self.logger)
# Execution mode
self.execution_mode = "single_gpu" # default; overridden by iterative_trainer
self.available_gpus = []
# Initialize the IPO buffer
self.ipo_buffer = IPOBuffer()
# Task output directory
self.task_output_dir = Path('./tmp/batch_results')
def _ensure_models_loaded(self):
"""๋ชจ๋ธ๊ณผ ์ปดํฌ๋„ŒํŠธ๋“ค์ด ๋กœ๋“œ๋˜์—ˆ๋Š”์ง€ ํ™•์ธํ•˜๊ณ  ํ•„์š”์‹œ ์ดˆ๊ธฐํ™”"""
if self.model is None or self.tokenizer is None:
raise RuntimeError("Model and tokenizer must be provided during initialization")
# ์ปดํฌ๋„ŒํŠธ๋“ค์ด None์ธ ๊ฒฝ์šฐ ์ดˆ๊ธฐํ™”
if self.solution_generator is None:
# Engine selection (from config; default: VLLM)
use_vllm = getattr(self.config, 'use_vllm_for_data_generation', True)
self.logger.log_info(f"🔧 Initializing solution generator with use_vllm={use_vllm}")
try:
self.solution_generator = InitialSolutionGenerator(
self.model, self.tokenizer, self.config, self.logger, use_vllm=use_vllm
)
self.logger.log_info("✅ Solution generator initialized successfully")
except Exception as e:
self.logger.log_error(f"❌ Failed to initialize solution generator: {e}")
import traceback
self.logger.log_error(f"Traceback: {traceback.format_exc()}")
raise
if self.ipo_extractor is None:
self.ipo_extractor = IPOTripleExtractor(
self.config, self.logger, self.model, self.tokenizer
)
# Hand the solution generator to the IPO extractor
self.ipo_extractor.solution_generator = self.solution_generator
if self.reward_manager is None:
self.reward_manager = self._setup_azr_reward_manager()
self.logger.log_info("✅ All components ready")
def set_execution_mode(self, execution_mode: str, available_gpus: List[int]):
"""์‹คํ–‰ ๋ชจ๋“œ ์„ค์ • (iterative_trainer์—์„œ ํ˜ธ์ถœ)"""
self.execution_mode = execution_mode
self.available_gpus = available_gpus
self.logger.log_info(f"🎯 Execution mode set to: {execution_mode}")
self.logger.log_info(f"🎯 Available GPUs: {available_gpus}")
def _setup_azr_reward_manager(self) -> CodeIORewardManager:
"""AZR Reward Manager ์„ค์ • (๊ธฐ์กด ์„ค์ • ๊ทธ๋Œ€๋กœ ์‚ฌ์šฉ)"""
# AZR์—์„œ ์‚ฌ์šฉํ•˜๋Š” ์„ค์ •์œผ๋กœ ์ดˆ๊ธฐํ™”
class SimpleConfig:
def __init__(self):
self.use_original_code_as_ref = False
self.reward_type = 'code_execution'
self.weight = 1.0
reward_manager = CodeIORewardManager(
tokenizer=self.tokenizer,
num_examine=0,
reward_fn_extraction_type='rule',
math_metric='accuracy',
split='test',
splitter='boxed',
output_path='./testtime_output',
max_prompt_length=1024,
generation_reward_config=SimpleConfig()
)
return reward_manager
def run_complete_pipeline(self, benchmark_config: BenchmarkConfig,
problem_id: str, round_num: int = 1, session_timestamp: str = None,
output_base_dir: str = None) -> Dict[str, Any]:
"""์™„์ „ํ•œ ํŒŒ์ดํ”„๋ผ์ธ ์‹คํ–‰
Args:
benchmark_config: ๋ฒค์น˜๋งˆํฌ ์„ค์ •
problem_id: ๋ฌธ์ œ ID
round_num: ๋ผ์šด๋“œ ๋ฒˆํ˜ธ
session_timestamp: ์„ธ์…˜ ํƒ€์ž„์Šคํƒฌํ”„
output_base_dir: ๋กœ๊ทธ ์ €์žฅ ๊ธฐ๋ณธ ๋””๋ ‰ํ† ๋ฆฌ (None์ด๋ฉด ๊ธฐ๋ณธ ๊ฒฝ๋กœ ์‚ฌ์šฉ)
"""
# ์„ค๊ณ„๋œ ๋””๋ ‰ํ† ๋ฆฌ ๊ตฌ์กฐ์— ๋งž๋Š” ๋กœ๊ฑฐ ์žฌ์„ค์ •
if session_timestamp is None:
# ๋…๋ฆฝ ์‹คํ–‰ ์‹œ ์ƒˆ timestamp ์ƒ์„ฑ
session_timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
benchmark_safe = benchmark_config.name
problem_safe = problem_id.replace('/', '_')
# Use output_base_dir when provided, otherwise the default path
if output_base_dir:
round_log_dir = os.path.join(output_base_dir, benchmark_safe, problem_safe, f'round_{round_num}')
else:
round_log_dir = f'/home/ubuntu/RLVR/TestTime-RLVR-v2/tmp/batch_results/ttrlvr_azr_{session_timestamp}/{benchmark_safe}/{problem_safe}/round_{round_num}'
# Reset to a new logger (using the designed layout)
self.logger = TestTimeLogger(task_output_dir=round_log_dir)
self.logger.log_info(f"🚀 Starting complete TestTime RLVR pipeline for {problem_id}")
# ๋ชจ๋ธ์ด ๋กœ๋“œ๋˜์ง€ ์•Š์•˜์œผ๋ฉด ๋กœ๋“œ
self._ensure_models_loaded()
pipeline_result = {
'problem_id': problem_id,
'benchmark': benchmark_config.name,
'round': round_num,
'output_dir': round_log_dir,
'steps': {},
'success': False,
'error': None
}
try:
# Step 1: Load the benchmark problem
self.logger.log_info("📄 Step 1: Loading benchmark problem")
problem = self.benchmark_loader.load_problem(benchmark_config, problem_id)
pipeline_result['steps']['problem_loading'] = {
'success': True,
'problem': problem
}
# Step 1.5: Baseline performance measurement (NEW)
self.logger.log_info("📊 Step 1.5: Baseline performance evaluation")
baseline_results = self._evaluate_baseline_performance(problem)
pipeline_result['steps']['baseline_evaluation'] = baseline_results
# ๐Ÿ”„ ๋ผ์šด๋“œ๋ณ„ IPO buffer ์ดˆ๊ธฐํ™” (๊ฐ ๋ผ์šด๋“œ๋Š” ๋…๋ฆฝ์ )
self.logger.log_info(f"๐Ÿ”„ Clearing IPO buffer for round {round_num}")
self.ipo_buffer.clear(problem_id)
# Step 2: Generate diverse programs and process IPO triples (NEW)
diverse_programs_results = self._generate_diverse_programs_and_ipo(problem)
pipeline_result['steps']['diverse_programs'] = diverse_programs_results
# Save the diverse-program evaluation results
self._save_diverse_programs_evaluation(problem, diverse_programs_results)
if not diverse_programs_results['success']:
self.logger.log_error("❌ No valid diverse programs generated")
pipeline_result['error'] = "No valid diverse programs could be generated"
return pipeline_result
# Step 3: Generate tasks only from IPO triples produced in the current round
self.logger.log_info("🎯 Step 3: Generating tasks from current round IPO triples")
current_round_triples = self.ipo_buffer.get_all(problem_id)
self.logger.log_info(f"🎯 Using {len(current_round_triples)} IPO triples from current round")
all_tasks = self.task_generator.generate_tasks(current_round_triples, problem_id, round_num)
total_tasks = sum(len(tasks) for tasks in all_tasks.values())
pipeline_result['steps']['task_generation'] = {
'success': total_tasks > 0,
'total_tasks': total_tasks,
'tasks_by_type': {k: len(v) for k, v in all_tasks.items()},
'all_tasks': all_tasks
}
# Step 4: Evaluate tasks with the LLM (skippable)
if getattr(self.config, 'skip_task_evaluation', False):
self.logger.log_info("⏭️ Step 4: Skipping task evaluation (fast mode)")
task_evaluations = {task_type: [] for task_type in all_tasks.keys()}
pipeline_result['steps']['task_evaluation'] = {
'success': True,
'skipped': True,
'evaluations': task_evaluations
}
else:
self.logger.log_info("💭 Step 4: Evaluating tasks with LLM")
task_evaluations = self._evaluate_tasks_with_llm(all_tasks)
pipeline_result['steps']['task_evaluation'] = {
'success': True,
'evaluations': task_evaluations
}
# Step 5: Compute rewards (using the AZR reward manager)
self.logger.log_info("🏆 Step 5: Computing rewards")
buffered_triples = self.ipo_buffer.get_all(problem_id)
rewards = self._compute_rewards_with_azr(task_evaluations, buffered_triples)
pipeline_result['steps']['reward_computation'] = {
'success': True,
'rewards': rewards
}
# Step 6: Save AZR training data
self.logger.log_info("💾 Step 6: Saving AZR training data")
output_dir = pipeline_result.get('output_dir', './testtime_output')
azr_files = self._save_azr_training_data(all_tasks, problem_id, round_num, output_dir)
pipeline_result['steps']['azr_data_saving'] = {
'success': len(azr_files) > 0,
'files': azr_files,
'total_tasks': sum(len(tasks) for tasks in all_tasks.values())
}
# Step 7: Generate the task summary (same format as batch evaluation)
self.logger.log_info("📋 Step 7: Generating task summary")
self._save_task_summary_json(problem, baseline_results, task_evaluations, round_num)
# Overall success
pipeline_result['success'] = True
pipeline_result['azr_training_data'] = azr_files # record the AZR data file paths
self.logger.log_info("✅ Complete pipeline executed successfully")
return pipeline_result
except Exception as e:
self.logger.log_error(f"💥 Pipeline failed: {e}")
pipeline_result['error'] = str(e)
return pipeline_result
finally:
# Release resources
self.ipo_extractor.cleanup()
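# The returned dict is nested per step; a caller might inspect it like this
# (a sketch based on the keys populated above):
#
#   result = pipeline.run_complete_pipeline(bench_cfg, "Mbpp/2")
#   if result['success']:
#       gen = result['steps']['task_generation']
#       print(gen['total_tasks'], gen['tasks_by_type'])
#   else:
#       print("failed:", result['error'])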
def _evaluate_tasks_with_llm(self, all_tasks: Dict[str, List[Dict[str, Any]]]) -> Dict[str, List[Dict[str, Any]]]:
"""LLM์œผ๋กœ ์ƒ์„ฑ๋œ ํƒœ์Šคํฌ๋“ค ํ‰๊ฐ€ํ•˜๊ณ  basic_accuracy ์—…๋ฐ์ดํŠธ"""
evaluations = {}
# ์ •ํ™•๋„ ๊ณ„์‚ฐ์šฉ executor ์ดˆ๊ธฐํ™”
from ..utils.code_utils.python_executor import PythonExecutor
executor = PythonExecutor()
for task_type, tasks in all_tasks.items():
self.logger.log_info(f"🔄 Evaluating {len(tasks)} {task_type} tasks")
task_evaluations = []
for task in tasks:
# Solve the task with the LLM
task_prompt = task['prompt']
# Generate in the AZR style
llm_response = self._generate_task_response(task_prompt)
# Store the evaluation result
evaluation = {
'task_id': task['task_id'],
'task_type': task_type,
'prompt': task_prompt,
'llm_response': llm_response,
'expected_solution': task['expected_solution'],
'evaluation_data': task['evaluation_data']
}
# ๐Ÿ†• ์ •ํ™•๋„ ๊ณ„์‚ฐ ๋ฐ task ์—…๋ฐ์ดํŠธ
accuracy = self._calculate_task_accuracy(evaluation, task_type, executor)
task['basic_accuracy'] = accuracy # ์›๋ณธ task ๊ฐ์ฒด ์—…๋ฐ์ดํŠธ
evaluation['basic_accuracy'] = accuracy # evaluation์—๋„ ์ถ”๊ฐ€
task_evaluations.append(evaluation)
evaluations[task_type] = task_evaluations
# LLM ์‘๋‹ต ์ €์žฅ
self._save_llm_responses(task_type, task_evaluations)
return evaluations
def _generate_task_response(self, prompt: str) -> str:
"""๋‹จ์ผ ํƒœ์Šคํฌ์— ๋Œ€ํ•œ LLM ์‘๋‹ต ์ƒ์„ฑ (AZR ๋ฐฉ์‹)"""
# VLLM ์‚ฌ์šฉ ์—ฌ๋ถ€ ํ™•์ธ
try:
from vllm import LLM
if isinstance(self.model, LLM):
# VLLM ๋ชจ๋ธ์ธ ๊ฒฝ์šฐ
from vllm import SamplingParams
sampling_params = SamplingParams(
temperature=0.05,
max_tokens=512,
top_p=0.95,
stop=["\n\n\n", "# Task:", "================================================================================"] # ๋” ๊ตฌ์ฒด์ ์ธ stop token
)
outputs = self.model.generate([prompt], sampling_params, use_tqdm=False)
response = outputs[0].outputs[0].text.replace("\t", " ")
return response.strip()
except ImportError:
pass
# HuggingFace ๋ชจ๋ธ์ธ ๊ฒฝ์šฐ
inputs = self.tokenizer(prompt, return_tensors='pt', truncation=True, max_length=4096)
# Set the attention mask explicitly
if 'attention_mask' not in inputs:
inputs['attention_mask'] = torch.ones_like(inputs['input_ids'])
inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
with torch.no_grad():
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Same settings as AZR evaluation (same temperature as VLLM)
outputs = self.model.generate(
inputs['input_ids'],
attention_mask=inputs['attention_mask'], # pass the attention mask explicitly
max_new_tokens=256, # a reasonable length for task responses
do_sample=True, # enable sampling
temperature=0.05, # same temperature as VLLM
top_p=0.95, # same top_p as VLLM
pad_token_id=self.tokenizer.eos_token_id,
eos_token_id=self.tokenizer.eos_token_id
)
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
response = response[len(prompt):].strip()
return response
def _calculate_task_accuracy(self, evaluation: Dict[str, Any], task_type: str, executor) -> float:
"""๊ฐœ๋ณ„ task์˜ ์ •ํ™•๋„ ๊ณ„์‚ฐ"""
try:
llm_response = evaluation['llm_response']
expected = evaluation['expected_solution']
evaluation_data = evaluation['evaluation_data']
# AZR ๋ฐฉ์‹์œผ๋กœ ๋‹ต๋ณ€ ์ถ”์ถœ
extracted_answer = self._extract_answer_by_task_type(llm_response, task_type)
if task_type == 'abduction':
# Abduction: feed the LLM-generated input to the function and compare the result with the expected output
code = evaluation_data['function_code']
expected_output_value = evaluation_data['expected_output']
agent_input = extracted_answer
try:
# ํ•จ์ˆ˜๋ช…์„ f๋กœ ๋ณ€๊ฒฝ (EVAL_INPUT_PREDICTION_TEMPLATE์ด f๋ฅผ ๊ธฐ๋Œ€ํ•จ)
import re
func_name_match = re.search(r'def\s+(\w+)\s*\(', code)
if func_name_match:
original_func_name = func_name_match.group(1)
# ํ•จ์ˆ˜๋ช…์„ f๋กœ ๋ณ€๊ฒฝ
code = re.sub(r'def\s+' + re.escape(original_func_name) + r'\s*\(', 'def f(', code)
from ..utils.code_utils.templates import EVAL_INPUT_PREDICTION_TEMPLATE
code_snippet = EVAL_INPUT_PREDICTION_TEMPLATE.format(
code=code,
gold_output=expected_output_value,
agent_input=agent_input
)
result, status = executor.apply(code_snippet)
if 'error' in status.lower():
accuracy = 0.0
else:
try:
if isinstance(result, bool):
agent_output = result
else:
agent_output = eval(result)
accuracy = 1.0 if agent_output else 0.0
except:
accuracy = 0.0
except:
accuracy = 0.0
elif task_type == 'deduction':
# Deduction: compare the LLM-generated output with the expected output
expected_output = expected
agent_output = extracted_answer
try:
accuracy = 1.0 if eval(expected_output) == eval(agent_output) else 0.0
except:
accuracy = 0.0
elif task_type == 'induction':
# Induction: run the inputs through the LLM-generated program and compare with the expected outputs
input_output_pairs = evaluation_data['input_output_pairs']
agent_code = extracted_answer
accuracies = []
for test_input, expected_output in input_output_pairs:
try:
accuracy = executor.eval_input_prediction(agent_code, expected_output, test_input)
accuracies.append(accuracy if accuracy is not None else 0.0)
except:
accuracies.append(0.0)
# ํ‰๊ท  ์ •ํ™•๋„
accuracy = sum(accuracies) / len(accuracies) if accuracies else 0.0
else:
# Default: string matching
accuracy = 1.0 if expected.strip() == extracted_answer.strip() else 0.0
except Exception as e:
self.logger.log_error(f"Error calculating accuracy for {task_type}: {e}")
accuracy = 0.0
return accuracy
def _extract_answer_by_task_type(self, llm_response: str, task_type: str) -> str:
"""ํƒœ์Šคํฌ ํƒ€์ž…๋ณ„ AZR ๋ฐฉ์‹ ์ •๋‹ต ์ถ”์ถœ"""
if task_type == 'induction':
# <answer> ํƒœ๊ทธ ์ถ”์ถœ (AZR ์ƒˆ ํฌ๋งท)
pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL | re.IGNORECASE)
matches = pattern.findall(llm_response)
return matches[-1].strip() if matches else llm_response.strip()
elif task_type == 'abduction':
# <answer> ํƒœ๊ทธ ์ถ”์ถœ (AZR ์ƒˆ ํฌ๋งท)
pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL | re.IGNORECASE)
matches = pattern.findall(llm_response)
return matches[-1].strip() if matches else llm_response.strip()
elif task_type == 'deduction':
# <answer> ํƒœ๊ทธ ์ถ”์ถœ (abduction๊ณผ ๋™์ผํ•œ ํฌ๋งท)
pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL | re.IGNORECASE)
matches = pattern.findall(llm_response)
return matches[-1].strip() if matches else llm_response.strip()
else:
# ๊ธฐ๋ณธ๊ฐ’: ์ „์ฒด ์‘๋‹ต ๋ฐ˜ํ™˜
return llm_response.strip()
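# Extraction behavior in brief (illustrative values): the regex keeps the
# last <answer> block and falls back to the whole stripped response.
#
#   resp = "<answer>[1, 2]</answer> ... <answer>[3]</answer>"
#   self._extract_answer_by_task_type(resp, 'deduction')       # -> "[3]"
#   self._extract_answer_by_task_type("no tags", 'deduction')  # -> "no tags"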
def _compute_rewards_with_azr(self, task_evaluations: Dict[str, List[Dict[str, Any]]],
ipo_triples: List[Dict[str, Any]]) -> Dict[str, Any]:
"""AZR Reward Manager๋กœ ๋ณด์ƒ ๊ณ„์‚ฐ (์‹ค์ œ ์ฝ”๋“œ ์‹คํ–‰ ๊ธฐ๋ฐ˜ ํ‰๊ฐ€)"""
# PythonExecutor ๊ฐ€์ ธ์˜ค๊ธฐ
from ..utils.code_utils.python_executor import PythonExecutor
executor = PythonExecutor()
rewards_by_type = {}
total_rewards = []
for task_type, evaluations in task_evaluations.items():
self.logger.log_info(f"🎯 Computing rewards for {task_type} tasks")
type_rewards = []
for evaluation in evaluations:
expected = evaluation['expected_solution']
llm_response = evaluation['llm_response']
evaluation_data = evaluation['evaluation_data']
# AZR ๋ฐฉ์‹์œผ๋กœ ์ •๋‹ต ์ถ”์ถœ
extracted_answer = self._extract_answer_by_task_type(llm_response, task_type)
# ์‹ค์ œ ์ฝ”๋“œ ์‹คํ–‰ ๊ธฐ๋ฐ˜ ํ‰๊ฐ€ (AZR ๋ฐฉ์‹)
try:
if task_type == 'abduction':
# Abduction: run the program on the LLM-predicted input and check the result equals the expected output
code = evaluation_data['function_code']
expected_output = evaluation_data['expected_output']
agent_input = extracted_answer
# ํ•จ์ˆ˜ ์ •์˜๋งŒ ์ถ”์ถœ (assert ๋ฌธ ๋“ฑ ์ œ๊ฑฐ)
import re
def extract_function_definition(code):
"""์ฝ”๋“œ์—์„œ import๋ฌธ๊ณผ ํ•จ์ˆ˜ ์ •์˜๋ฅผ ์ถ”์ถœ"""
lines = code.split('\n')
import_lines = []
func_lines = []
in_function = False
base_indent = None
for line in lines:
# Collect import statements
if line.strip().startswith('from ') or line.strip().startswith('import '):
import_lines.append(line)
# Start of the function definition
elif line.strip().startswith('def '):
in_function = True
base_indent = len(line) - len(line.lstrip())
func_lines.append(line)
elif in_function:
# Blank line or still inside the function
if line.strip() == '':
func_lines.append(line)
elif line.startswith(' ' * (base_indent + 1)) or line.startswith('\t'):
# inside the function (deeper indentation)
func_lines.append(line)
else:
# code outside the function (assert statements, etc.) - stop
break
# Return the imports and the function joined together
if import_lines:
return '\n'.join(import_lines) + '\n\n' + '\n'.join(func_lines)
else:
return '\n'.join(func_lines)
# ํ•จ์ˆ˜ ์ •์˜๋งŒ ์ถ”์ถœ
code = extract_function_definition(code)
# AZR ๋ฐฉ์‹: ํ•จ์ˆ˜๋ช…์„ f๋กœ ํ†ต์ผ (process_code_reasoning_data.py:34 ์ฐธ์กฐ)
# ํ•จ์ˆ˜๋ช… ์ถ”์ถœ
func_name_match = re.search(r'def\s+(\w+)\s*\(', code)
if func_name_match:
original_func_name = func_name_match.group(1)
# ํ•จ์ˆ˜๋ช…์„ f๋กœ ๋ณ€๊ฒฝ (AZR ๋ฐฉ์‹)
code = re.sub(r'def\s+' + re.escape(original_func_name) + r'\s*\(', 'def f(', code)
# expected_output์„ ์‹ค์ œ ๊ฐ’์œผ๋กœ ๋ณ€ํ™˜
try:
expected_output_value = eval(expected_output)
except:
expected_output_value = expected_output
# AZR ๋ฐฉ์‹: EVAL_INPUT_PREDICTION_TEMPLATE ์‚ฌ์šฉ
try:
from ..utils.code_utils.templates import EVAL_INPUT_PREDICTION_TEMPLATE
code_snippet = EVAL_INPUT_PREDICTION_TEMPLATE.format(
code=code,
gold_output=expected_output_value,
agent_input=agent_input
)
result, status = executor.apply(code_snippet)
if 'error' in status.lower():
accuracy = 0.0
else:
# Compare the execution result with the expected output
try:
# AZR style: the result is a boolean (gold_output == f(agent_input))
if isinstance(result, bool):
# result is already a boolean
agent_output = result
else:
# result is a string; use eval
agent_output = eval(result)
accuracy = 1.0 if agent_output else 0.0
except:
accuracy = 0.0
except:
accuracy = 0.0
elif task_type == 'deduction':
# Deduction: compare the LLM-generated output with the expected output
expected_output = expected
agent_output = extracted_answer
# Simple eval comparison (AZR style)
try:
accuracy = 1.0 if eval(expected_output) == eval(agent_output) else 0.0
except:
accuracy = 0.0
elif task_type == 'induction':
# Induction: run the inputs through the LLM-generated program and compare with the expected outputs
input_output_pairs = evaluation_data['input_output_pairs']
agent_code = extracted_answer
# Test against every input-output pair
accuracies = []
for test_input, expected_output in input_output_pairs:
try:
accuracy = executor.eval_input_prediction(agent_code, expected_output, test_input)
accuracies.append(accuracy if accuracy is not None else 0.0)
except:
accuracies.append(0.0)
# ํ‰๊ท  ์ •ํ™•๋„
accuracy = sum(accuracies) / len(accuracies) if accuracies else 0.0
else:
# Default: string matching
accuracy = 1.0 if expected.strip() == extracted_answer.strip() else 0.0
except Exception as e:
self.logger.log_error(f"Error in {task_type} evaluation: {e}")
accuracy = 0.0
# ๋ณด์ƒ ์ •๋ณด ์ €์žฅ
reward = {
'task_id': evaluation['task_id'],
'task_type': task_type,
'extracted_answer': extracted_answer,
'expected_solution': expected,
'basic_accuracy': accuracy,
'final_reward': accuracy
}
type_rewards.append(reward)
total_rewards.append(reward['final_reward'])
rewards_by_type[task_type] = type_rewards
# Aggregate statistics
avg_reward = sum(total_rewards) / len(total_rewards) if total_rewards else 0.0
return {
'rewards_by_type': rewards_by_type,
'total_tasks': len(total_rewards),
'average_reward': avg_reward,
'reward_distribution': {
task_type: sum(r['final_reward'] for r in rewards) / len(rewards) if rewards else 0.0
for task_type, rewards in rewards_by_type.items()
}
}
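# Note: the deduction branch above compares model output with eval(), which
# executes arbitrary expressions coming from the LLM. A safer alternative
# (a sketch, not the AZR method; the helper name is illustrative) is
# ast.literal_eval, which only accepts Python literals:
#
#   import ast
#   def _safe_literal_eq(expected: str, actual: str) -> float:
#       """1.0 when both strings parse to equal Python literals, else 0.0."""
#       try:
#           return 1.0 if ast.literal_eval(expected) == ast.literal_eval(actual) else 0.0
#       except (ValueError, SyntaxError):
#           return 0.0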
def _compute_similarity(self, expected: str, actual: str) -> float:
"""๋ฌธ์ž์—ด ์œ ์‚ฌ์„ฑ ๊ณ„์‚ฐ (๊ฐ„๋‹จํ•œ ๋ฐฉ์‹)"""
expected_words = set(expected.lower().split())
actual_words = set(actual.lower().split())
if not expected_words and not actual_words:
return 1.0
if not expected_words or not actual_words:
return 0.0
intersection = expected_words & actual_words
union = expected_words | actual_words
return len(intersection) / len(union) # Jaccard similarity
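# Worked example of the Jaccard similarity above:
#   expected = "the cat sat", actual = "the cat ran"
#   intersection = {the, cat} (2), union = {the, cat, sat, ran} (4)
#   similarity = 2 / 4 = 0.5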
def _evaluate_baseline_performance(self, problem: Dict[str, Any]) -> Dict[str, Any]:
"""๋ฒ ์ด์Šค๋ผ์ธ ์„ฑ๋Šฅ ์ธก์ • (temperature=0.05๋กœ 5๋ฒˆ ์‹คํ–‰)"""
self.logger.log_info(f"๐Ÿ“Š Evaluating baseline performance for {problem.get('task_id', 'unknown')}")
baseline_results = {
'success': True,
'total_rounds': self.config.baseline_evaluation_rounds,
'solutions': [],
'evaluations': [],
'success_count': 0,
'average_accuracy': 0.0,
'error': None
}
try:
for round_id in range(self.config.baseline_evaluation_rounds):
self.logger.log_info(f" 🔄 Baseline round {round_id + 1}/{self.config.baseline_evaluation_rounds}")
# Generate a solution at the baseline temperature
solution = self.solution_generator.generate(problem)
# Syntax validation
is_valid, syntax_error = self.solution_generator.validate_syntax(solution)
solution_result = {
'round_id': round_id,
'solution': solution,
'syntax_valid': is_valid,
'syntax_error': syntax_error,
'evaluation': None
}
# ์ •ํ™•์„ฑ ํ‰๊ฐ€
if is_valid:
evaluation = self.solution_generator.evaluate_solution(problem, solution)
solution_result['evaluation'] = evaluation
if evaluation['correct']:
baseline_results['success_count'] += 1
self.logger.log_info(f" ✅ Round {round_id + 1}: PASSED ({evaluation['passed_tests']}/{evaluation['total_tests']} tests)")
# Log the baseline success case
self.logger.log_problem_attempt(problem, solution, True, evaluation)
else:
self.logger.log_info(f" ❌ Round {round_id + 1}: FAILED ({evaluation['passed_tests']}/{evaluation['total_tests']} tests)")
# Log the baseline failure case
self.logger.log_problem_attempt(problem, solution, False, evaluation)
else:
self.logger.log_warning(f" ❌ Round {round_id + 1}: Syntax error - {syntax_error}")
# Log the syntax-error case
syntax_validation = {'syntax_valid': False, 'syntax_error': syntax_error}
self.logger.log_problem_attempt(problem, solution, False, syntax_validation)
baseline_results['solutions'].append(solution_result)
# Save detailed logs in batch-evaluation format (every round)
self._save_batch_evaluation_format(problem, solution_result, attempt_num=round_id + 1)
# ํ‰๊ท  ์ •ํ™•๋„ ๊ณ„์‚ฐ
if baseline_results['success_count'] > 0:
baseline_results['average_accuracy'] = baseline_results['success_count'] / baseline_results['total_rounds']
self.logger.log_info(f" 📈 Baseline performance: {baseline_results['success_count']}/{baseline_results['total_rounds']} success ({baseline_results['average_accuracy']:.3f})")
except Exception as e:
self.logger.log_error(f"❌ Baseline evaluation failed: {e}")
baseline_results['success'] = False
baseline_results['error'] = str(e)
return baseline_results
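# Example of the statistic reported above: with baseline_evaluation_rounds=5
# and 2 passing rounds, average_accuracy = 2 / 5 = 0.4; it stays 0.0 when no
# round passes (the success_count > 0 guard leaves the initial value).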
def _save_batch_evaluation_format(self, problem: Dict[str, Any], solution_result: Dict[str, Any], attempt_num: int):
"""Batch evaluation๊ณผ ๋™์ผํ•œ ํ˜•์‹์œผ๋กœ ์ƒ์„ธ ๋กœ๊ทธ ์ €์žฅ"""
from ..testtime.prompts import get_prompt
# Current evaluation ๋””๋ ‰ํ† ๋ฆฌ ์ƒ์„ฑ
current_dir = os.path.join(self.logger.log_dir, "current_evaluation")
os.makedirs(current_dir, exist_ok=True)
# attempt ํŒŒ์ผ ์ƒ์„ฑ
attempt_file = os.path.join(current_dir, f"attempt_{attempt_num}.txt")
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
with open(attempt_file, 'w', encoding='utf-8') as f:
# ํ—ค๋” ์ •๋ณด
f.write(f"Current Evaluation - Attempt {attempt_num}\n")
f.write(f"Problem ID: {problem.get('task_id', 'unknown')}\n")
f.write(f"Benchmark: {problem.get('benchmark_name', 'unknown')}\n")
f.write(f"Generated: {timestamp}\n")
f.write("="*80 + "\n\n")
# 1. Original problem
f.write("1. ORIGINAL PROBLEM:\n")
f.write("="*80 + "\n")
f.write(problem.get('prompt', 'No prompt available'))
f.write("\n" + "="*80 + "\n\n")
# 2. Script fed to the LLM (the prompt)
f.write("2. LLM INPUT SCRIPT (PROMPT):\n")
f.write("="*80 + "\n")
problem_prompt = problem.get('prompt', '')
# Use the central prompt system
try:
if 'HumanEval' in problem.get('task_id', ''):
full_prompt = get_prompt("solution_humaneval_basic",
problem_prompt=problem_prompt)
else:
full_prompt = get_prompt("solution_mbpp_basic",
problem_prompt=problem_prompt)
f.write(full_prompt.strip())
except Exception as e:
# ํ”„๋กฌํ”„ํŠธ ์ƒ์„ฑ ์‹คํŒจ ์‹œ ๊ธฐ๋ณธ ํ˜•์‹ ์‚ฌ์šฉ
f.write(f"You are a Python writing assistant. Complete the following Python function.\n\n{problem_prompt}\n\nPlease provide a complete implementation of the function.")
f.write("\n" + "="*80 + "\n\n")
# 3. LLM์˜ ์‘๋‹ต
f.write("3. LLM RESPONSE:\n")
f.write("="*80 + "\n")
f.write(solution_result.get('solution', 'No solution generated'))
f.write("\n" + "="*80 + "\n\n")
# 4. Correctness
f.write("4. CORRECTNESS EVALUATION:\n")
f.write("="*80 + "\n")
# Syntax validation
f.write(f"Syntax Valid: {'✅ YES' if solution_result.get('syntax_valid', False) else '❌ NO'}\n")
if solution_result.get('syntax_error'):
f.write(f"Syntax Error: {solution_result['syntax_error']}\n")
# ์ •ํ™•์„ฑ ํ‰๊ฐ€
evaluation = solution_result.get('evaluation')
if evaluation:
if evaluation.get('correct', False):
f.write(f"Result: ✅ CORRECT ({evaluation.get('passed_tests', 0)}/{evaluation.get('total_tests', 0)} tests passed)\n")
else:
f.write(f"Result: ❌ INCORRECT ({evaluation.get('passed_tests', 0)}/{evaluation.get('total_tests', 0)} tests passed)\n")
if evaluation.get('error'):
f.write(f"Evaluation Error: {evaluation['error']}\n")
else:
f.write("Result: โŒ NO EVALUATION (syntax error or evaluation failed)\n")
f.write("="*80 + "\n")
self.logger.log_info(f"📝 Batch evaluation format saved: {attempt_file}")
def _save_llm_responses(self, task_type: str, evaluations: List[Dict[str, Any]]):
"""LLM ์‘๋‹ต์„ llm_responses ๋””๋ ‰ํ† ๋ฆฌ์— ์ €์žฅ (batch evaluation๊ณผ ๋™์ผํ•œ ํ˜•์‹)"""
try:
# LLM responses ๋””๋ ‰ํ† ๋ฆฌ ์ƒ์„ฑ
llm_dir = os.path.join(self.logger.log_dir, "llm_responses")
os.makedirs(llm_dir, exist_ok=True)
# ๊ฐ task๋ณ„๋กœ ๊ฐœ๋ณ„ ํŒŒ์ผ ์ƒ์„ฑ (batch evaluation๊ณผ ๋™์ผ)
for i, evaluation in enumerate(evaluations, 1):
problem_id = evaluation['task_id'].split('_')[0] if '_' in evaluation['task_id'] else evaluation['task_id']
response_file = os.path.join(llm_dir, f"{problem_id}_{task_type}_{i}_response.txt")
with open(response_file, 'w', encoding='utf-8') as f:
f.write(f"Task Type: {task_type}\n")
f.write(f"Task ID: {evaluation['task_id']}\n")
f.write(f"Generated: {datetime.now().strftime('%Y%m%d_%H%M%S')}\n")
f.write("="*80 + "\nORIGINAL PROMPT:\n")
f.write("="*80 + "\n")
f.write(evaluation['prompt'])
f.write("\n" + "="*80 + "\n")
f.write("LLM RESPONSE:\n")
f.write("="*80 + "\n")
f.write(evaluation['llm_response'])
f.write("\n" + "="*80 + "\n")
f.write("EXPECTED SOLUTION:\n")
f.write("="*80 + "\n")
f.write(str(evaluation['expected_solution']))
f.write("\n" + "="*80 + "\n")
f.write("EXTRACTED ANSWER:\n")
f.write("="*80 + "\n")
# AZR ๋ฐฉ์‹์œผ๋กœ ๋‹ต์•ˆ ์ถ”์ถœ
extracted_answer = self._extract_answer_by_task_type(evaluation['llm_response'], task_type)
f.write(extracted_answer)
f.write("\n" + "="*80 + "\n")
f.write("MATCH RESULT:\n")
f.write("="*80 + "\n")
accuracy = evaluation.get('basic_accuracy', 0.0)
if accuracy > 0.5:
f.write(f"โœ… CORRECT (Score: {accuracy:.3f})")
else:
f.write(f"โŒ INCORRECT (Score: {accuracy:.3f})")
self.logger.log_info(f"๐Ÿ“ LLM responses saved to {llm_dir} (batch evaluation format)")
except Exception as e:
self.logger.log_warning(f"Failed to save LLM responses: {e}")
def _save_task_summary_json(self, problem: Dict[str, Any], baseline_results: Dict[str, Any],
task_evaluations: Dict[str, List[Dict[str, Any]]], round_num: int = None):
"""batch evaluation๊ณผ ๋™์ผํ•œ ํ˜•์‹์˜ summary.json ์ƒ์„ฑ"""
try:
problem_id = problem.get('task_id', 'unknown')
problem_id_safe = problem_id.replace('/', '_')
# Produce both a per-round summary and an overall summary
if round_num is not None:
# Per-round summary file (in the current round directory)
round_summary_file = os.path.join(self.logger.log_dir, f"{problem_id_safe}_round_{round_num}_summary.json")
self._save_single_summary(problem, baseline_results, task_evaluations, round_summary_file, round_num)
# Overall summary file (at the problem level)
summary_file = os.path.join(self.logger.log_dir.parent, f"{problem_id_safe}_summary.json")
self._save_single_summary(problem, baseline_results, task_evaluations, summary_file)
except Exception as e:
self.logger.log_warning(f"Failed to save task summary: {e}")
def _save_single_summary(self, problem: Dict[str, Any], baseline_results: Dict[str, Any],
task_evaluations: Dict[str, List[Dict[str, Any]]], summary_file: str, round_num: int = None):
"""๋‹จ์ผ summary ํŒŒ์ผ ์ €์žฅ"""
with open(summary_file, 'w', encoding='utf-8') as f:
problem_id = problem.get('task_id', 'unknown')
summary = {
'problem_id': problem_id,
'benchmark': problem.get('benchmark_name', 'unknown'),
'success': True,
'timestamp': datetime.now().strftime('%Y%m%d_%H%M%S'),
'initial_solution_correct': False,
'ipo_extraction_success': True, # assume IPO extraction always succeeds
'reasoning_task_results': {}
}
# Add round info (for per-round summaries)
if round_num is not None:
summary['round'] = round_num
# Initial solution result (taken from the baseline)
if baseline_results.get('success_count', 0) > 0:
summary['initial_solution_correct'] = True
# Reasoning task results (same format as batch evaluation)
for task_type, evaluations in task_evaluations.items():
if evaluations:
correct_count = sum(1 for eval_data in evaluations if eval_data.get('basic_accuracy', 0) > 0.5)
total_count = len(evaluations)
summary['reasoning_task_results'][task_type] = {
'correct': correct_count,
'total': total_count,
'accuracy': correct_count / total_count if total_count > 0 else 0
}
json.dump(summary, f, indent=2, ensure_ascii=False)
self.logger.log_info(f"📋 Summary saved: {summary_file}")
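# Shape of the emitted summary (illustrative values only):
#
#   {
#     "problem_id": "Mbpp/2",
#     "benchmark": "mbpp",
#     "success": true,
#     "initial_solution_correct": true,
#     "ipo_extraction_success": true,
#     "reasoning_task_results": {
#       "deduction": {"correct": 3, "total": 4, "accuracy": 0.75}
#     }
#   }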
def _generate_diverse_programs_and_ipo(self, problem: Dict[str, Any]) -> Dict[str, Any]:
"""๋‹ค์–‘ํ•œ ํ”„๋กœ๊ทธ๋žจ ์ƒ์„ฑ ๋ฐ IPO ์ถ”์ถœ (VLLM ๋ฐฐ์น˜ ์ฒ˜๋ฆฌ)"""
# VLLM ๋ฐฐ์น˜ ์ฒ˜๋ฆฌ ์‚ฌ์šฉ
return self._generate_programs_batch_vllm(problem)
def _generate_programs_batch_vllm(self, problem: Dict[str, Any]) -> Dict[str, Any]:
"""VLLM ๋ฐฐ์น˜ ์ฒ˜๋ฆฌ๋กœ ํ”„๋กœ๊ทธ๋žจ ์ƒ์„ฑ"""
problem_id = problem.get('task_id', 'unknown')
batch_size = getattr(self.config, 'parallel_batch_size', 4)
self.logger.log_info(f"🎨 Generating {self.config.num_program_variations} diverse programs for {problem_id} (BATCH)")
self.logger.log_info(f"📊 Using batch size: {batch_size} (concurrent prompts)")
diverse_results = {
'success': False,
'total_programs': self.config.num_program_variations,
'valid_programs': 0,
'programs': [],
'total_ipo_triples': 0,
'error': None,
'batch_processing': True,
'batch_size': batch_size
}
try:
# Generate programs batch by batch
all_programs = []
for batch_idx, batch_start in enumerate(range(0, self.config.num_program_variations, batch_size)):
batch_end = min(batch_start + batch_size, self.config.num_program_variations)
batch_ids = list(range(batch_start, batch_end))
self.logger.log_info(f" 🎯 Processing batch {batch_idx + 1}: programs {batch_start}-{batch_end-1}")
# Build the prompts for this batch
batch_prompts = []
for variation_id in batch_ids:
prompt = self._create_diverse_generation_prompt(problem, variation_id)
batch_prompts.append(prompt)
# VLLM batch inference
batch_solutions = self.solution_generator.generate_batch(
batch_prompts,
temperature=self.config.diverse_generation_temperature
)
self.logger.log_info(f" 📊 Generated {len(batch_solutions)} solutions")
# Process the batch results - extract IPO triples first
batch_program_results = []
for i, (variation_id, solution) in enumerate(zip(batch_ids, batch_solutions)):
program_result = self._process_single_program_basic(problem, solution, variation_id)
batch_program_results.append(program_result)
# Batch input generation for the programs that succeeded
successful_programs = [p for p in batch_program_results if p.get('success', False)]
self.logger.log_info(f" 📊 Batch results: {len(batch_program_results)} programs, {len(successful_programs)} successful")
for i, prog in enumerate(batch_program_results):
self.logger.log_info(f" Program {i}: success={prog.get('success')}, IPO triples={prog.get('num_ipo_triples', 0)}")
if successful_programs:
self.logger.log_info(f" 🎲 Generating inputs for {len(successful_programs)} valid programs (BATCH)")
# Prepare batch data for input generation
input_generation_pairs = []
for program_result in successful_programs:
for round_num in range(getattr(self.config, 'input_generation_rounds', 3)):
# ํ˜„์žฌ๊นŒ์ง€์˜ ๋ชจ๋“  ์˜ˆ์ œ ์ˆ˜์ง‘
existing_examples = [(triple['full_input_str'], triple['actual_output'])
for triple in program_result['ipo_triples']]
# ์ด๋ฏธ ์ƒ์„ฑ๋œ ์ž…๋ ฅ๋„ ๊ธฐ์กด ์˜ˆ์ œ์— ํฌํ•จ
for prev_input in program_result.get('all_generated_inputs', []):
if 'input_args' in prev_input and 'expected_output' in prev_input:
existing_examples.append((
str(prev_input['input_args']),
str(prev_input['expected_output'])
))
input_generation_pairs.append({
'problem': problem,
'solution': program_result['extracted_function_code'],
'existing_examples': existing_examples,
'program_result': program_result,
'round_num': round_num
})
# Run input generation as a batch
if input_generation_pairs:
self.logger.log_info(f" 📊 Total input generation pairs: {len(input_generation_pairs)}")
batch_input_results, batch_generation_info = self.ipo_extractor.generate_diverse_inputs_batch(input_generation_pairs)
self.logger.log_info(f" 📊 Batch input results: {len(batch_input_results)} responses")
# Regroup the results by program
pair_idx = 0
for program_result in successful_programs:
program_result['all_generated_inputs'] = []
program_result['input_generation_info'] = [] # Store generation info for each round
input_generation_rounds = getattr(self.config, 'input_generation_rounds', 3)
for round_num in range(input_generation_rounds):
if pair_idx < len(batch_input_results):
round_inputs = batch_input_results[pair_idx]
program_result['all_generated_inputs'].extend(round_inputs)
# Store generation info for this round
if pair_idx < len(batch_generation_info) and batch_generation_info[pair_idx]:
program_result['input_generation_info'].append(batch_generation_info[pair_idx])
# Create IPO triples from the new inputs
for new_input in round_inputs:
new_triple = self.ipo_extractor.create_ipo_from_input(
problem, program_result['extracted_function_code'], new_input
)
if new_triple:
program_result['ipo_triples'].append(new_triple)
pair_idx += 1
# Update final statistics
program_result['num_generated_inputs'] = len(program_result['all_generated_inputs'])
program_result['num_ipo_triples'] = len(program_result['ipo_triples'])
program_result['input_generation_rounds'] = input_generation_rounds
# Append to the overall results
for prog_idx, program_result in enumerate(batch_program_results):
all_programs.append(program_result)
if program_result.get('success', False):
diverse_results['valid_programs'] += 1
diverse_results['total_ipo_triples'] += program_result.get('num_ipo_triples', 0)
# Add IPO triples to the buffer (same as sequential mode)
program_id = f'program_{batch_idx * batch_size + prog_idx}'
for ipo_idx, triple in enumerate(program_result.get('ipo_triples', [])):
# Attach mapping info to the IPO triple
triple['source_program_id'] = program_id
triple['ipo_index'] = ipo_idx
self.ipo_buffer.add(problem_id, triple)
if program_result.get('ipo_triples'):
self.logger.log_info(f" 📥 Added {len(program_result['ipo_triples'])} IPO triples to buffer from {program_id}")
diverse_results['programs'] = all_programs
diverse_results['success'] = diverse_results['valid_programs'] > 0
self.logger.log_info("✅ Batch processing completed:")
self.logger.log_info(f" - Valid programs: {diverse_results['valid_programs']}/{diverse_results['total_programs']}")
self.logger.log_info(f" - Total IPO triples: {diverse_results['total_ipo_triples']}")
# Save the per-program directory structure (batch mode uses the same function)
self._save_diverse_programs_evaluation(problem, diverse_results)
except Exception as e:
self.logger.log_error(f"Batch processing failed: {e}")
diverse_results['error'] = str(e)
# Fall back to sequential processing on failure
self.logger.log_info("🔄 Falling back to sequential processing")
return self._generate_programs_sequential(problem)
return diverse_results
def _create_diverse_generation_prompt(self, problem: Dict[str, Any], variation_id: int) -> str:
"""๋‹ค์–‘ํ•œ ํ”„๋กœ๊ทธ๋žจ ์ƒ์„ฑ์šฉ ํ”„๋กฌํ”„ํŠธ ์ƒ์„ฑ"""
# solution_generator์˜ ๊ธฐ์กด ๋กœ์ง ํ™œ์šฉ
problem_description = problem.get('prompt', '')
if not problem_description:
problem_description = problem.get('description', '')
# ๋‹ค์–‘์„ฑ์„ ์œ„ํ•œ ํ”„๋กฌํ”„ํŠธ
diversity_prompts = [
"Please generate a complete, self-contained Python script that solves the following problem.",
"Write a Python solution for this problem using a different approach.",
"Create an alternative Python implementation for the given problem.",
"Solve this problem with a unique Python solution approach."
]
base_prompt = diversity_prompts[variation_id % len(diversity_prompts)]
prompt = f"""{base_prompt}
Problem statement:
\"\"\"
{problem_description}
\"\"\"
Please provide a complete solution with proper function implementation."""
return prompt
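# For variation_id=1 the assembled prompt starts with the second diversity
# instruction and embeds the problem text, e.g. (sketch):
#
#   Write a Python solution for this problem using a different approach.
#   Problem statement:
#   """
#   ...problem description...
#   """
#   Please provide a complete solution with proper function implementation.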
def _process_single_program(self, problem: Dict[str, Any], solution: str, variation_id: int) -> Dict[str, Any]:
"""๋‹จ์ผ ํ”„๋กœ๊ทธ๋žจ ์ฒ˜๋ฆฌ (๊ฒ€์ฆ + IPO ์ถ”์ถœ)"""
# ๊ตฌ๋ฌธ ๊ฒ€์ฆ
is_valid, syntax_error = self.solution_generator.validate_syntax(solution)
program_result = {
'variation_id': variation_id,
'solution': solution,
'syntax_valid': is_valid,
'syntax_error': syntax_error,
'ipo_triples': [],
'num_ipo_triples': 0,
'generated_inputs': [],
'num_generated_inputs': 0,
'input_generation_rounds': 0,
'success': False
}
if is_valid:
try:
# IPO extraction
extracted_function_code = self.solution_generator._extract_function_code(solution)
ipo_triples = self.ipo_extractor.extract_triples(problem, extracted_function_code)
if ipo_triples:
program_result['ipo_triples'] = ipo_triples
program_result['num_ipo_triples'] = len(ipo_triples)
# ๋‹ค์ค‘ ๋ผ์šด๋“œ Input ์ฆ๊ฐ•
all_generated_inputs = []
input_generation_rounds = getattr(self.config, 'input_generation_rounds', 3)
for round_num in range(input_generation_rounds):
# ํ˜„์žฌ๊นŒ์ง€์˜ ๋ชจ๋“  ์˜ˆ์ œ ์ˆ˜์ง‘
existing_examples = [(triple['full_input_str'], triple['actual_output'])
for triple in ipo_triples]
# ์ด๋ฏธ ์ƒ์„ฑ๋œ ์ž…๋ ฅ๋„ ๊ธฐ์กด ์˜ˆ์ œ์— ํฌํ•จ
for prev_input in all_generated_inputs:
if 'input_args' in prev_input and 'expected_output' in prev_input:
existing_examples.append((
str(prev_input['input_args']),
str(prev_input['expected_output'])
))
# Generate new diverse inputs
diverse_inputs = self.ipo_extractor.generate_diverse_inputs(
problem, extracted_function_code, existing_examples
)
if diverse_inputs:
# Create additional IPO triples from the new inputs
for new_input in diverse_inputs:
new_triple = self.ipo_extractor.create_ipo_from_input(
problem, extracted_function_code, new_input
)
if new_triple:
ipo_triples.append(new_triple)
all_generated_inputs.extend(diverse_inputs)
program_result['generated_inputs'] = all_generated_inputs
program_result['num_generated_inputs'] = len(all_generated_inputs)
program_result['ipo_triples'] = ipo_triples # updated triple list
program_result['num_ipo_triples'] = len(ipo_triples)
program_result['input_generation_rounds'] = input_generation_rounds
program_result['success'] = True
except Exception as e:
program_result['error'] = str(e)
self.logger.log_error(f"IPO extraction failed for variation {variation_id}: {e}")
return program_result
def _process_single_program_basic(self, problem: Dict[str, Any], solution: str, variation_id: int) -> Dict[str, Any]:
"""๋‹จ์ผ ํ”„๋กœ๊ทธ๋žจ ๊ธฐ๋ณธ ์ฒ˜๋ฆฌ (IPO ์ถ”์ถœ๋งŒ, input generation ์ œ์™ธ)"""
# ๊ตฌ๋ฌธ ๊ฒ€์ฆ
is_valid, syntax_error = self.solution_generator.validate_syntax(solution)
program_result = {
'variation_id': variation_id,
'solution': solution,
'syntax_valid': is_valid,
'syntax_error': syntax_error,
'ipo_triples': [],
'num_ipo_triples': 0,
'all_generated_inputs': [],
'num_generated_inputs': 0,
'input_generation_rounds': 0,
'extracted_function_code': None,
'success': False
}
if is_valid:
try:
# IPO extraction
extracted_function_code = self.solution_generator._extract_function_code(solution)
program_result['extracted_function_code'] = extracted_function_code
ipo_triples = self.ipo_extractor.extract_triples(problem, extracted_function_code)
if ipo_triples:
program_result['ipo_triples'] = ipo_triples
program_result['num_ipo_triples'] = len(ipo_triples)
program_result['success'] = True
except Exception as e:
program_result['error'] = str(e)
self.logger.log_error(f"IPO extraction failed for variation {variation_id}: {e}")
return program_result
def _generate_programs_sequential(self, problem: Dict[str, Any]) -> Dict[str, Any]:
"""๋‹ค์–‘ํ•œ ํ”„๋กœ๊ทธ๋žจ ์ƒ์„ฑ ๋ฐ ๊ฐ๊ฐ์—์„œ IPO ์ถ”์ถœ"""
self.logger.log_info(f"๐ŸŽจ Generating {self.config.num_program_variations} diverse programs for {problem.get('task_id', 'unknown')}")
diverse_results = {
'success': False,
'total_programs': self.config.num_program_variations,
'valid_programs': 0,
'programs': [],
'total_ipo_triples': 0,
'error': None
}
try:
for variation_id in range(self.config.num_program_variations):
self.logger.log_info(f" 🎯 Generating program variation {variation_id + 1}/{self.config.num_program_variations}")
# Generate a diverse solution (temperature=0.7)
diverse_solution = self.solution_generator.generate_diverse(
problem,
temperature=self.config.diverse_generation_temperature,
variation_id=variation_id
)
# Syntax validation
is_valid, syntax_error = self.solution_generator.validate_syntax(diverse_solution)
program_result = {
'variation_id': variation_id,
'solution': diverse_solution,
'syntax_valid': is_valid,
'syntax_error': syntax_error,
'ipo_triples': [],
'num_ipo_triples': 0,
'generated_inputs': [],
'num_generated_inputs': 0
}
if is_valid:
diverse_results['valid_programs'] += 1
# IPO extraction
extracted_function_code = self.solution_generator._extract_function_code(diverse_solution)
ipo_triples = self.ipo_extractor.extract_triples(problem, extracted_function_code)
if ipo_triples:
program_result['ipo_triples'] = ipo_triples
program_result['num_ipo_triples'] = len(ipo_triples)
# ๋‹ค์ค‘ ๋ผ์šด๋“œ Input ์ฆ๊ฐ•
all_generated_inputs = []
input_generation_rounds = getattr(self.config, 'input_generation_rounds', 3) # ๊ธฐ๋ณธ 3๋ผ์šด๋“œ
for round_num in range(input_generation_rounds):
self.logger.log_info(f" 🎯 Input generation round {round_num + 1}/{input_generation_rounds}")
# Collect all examples gathered so far
existing_examples = [(triple['full_input_str'], triple['actual_output'])
for triple in ipo_triples]
# ์ด๋ฏธ ์ƒ์„ฑ๋œ ์ž…๋ ฅ๋„ ๊ธฐ์กด ์˜ˆ์ œ์— ํฌํ•จ
for prev_input in all_generated_inputs:
if 'input_args' in prev_input and 'expected_output' in prev_input:
existing_examples.append((
str(prev_input['input_args']),
str(prev_input['expected_output'])
))
# Generate new diverse inputs
diverse_inputs = self.ipo_extractor.generate_diverse_inputs(
problem, extracted_function_code, existing_examples
)
if not diverse_inputs:
self.logger.log_warning(f" ⚠️ Round {round_num + 1}: No valid inputs generated")
continue
self.logger.log_info(f" ✅ Round {round_num + 1}: Generated {len(diverse_inputs)} new inputs")
# Save input generation info (first round only)
if round_num == 0 and hasattr(self.ipo_extractor, 'last_input_generation_info'):
self._save_input_generation_details(
problem,
variation_id + 1,
self.ipo_extractor.last_input_generation_info
)
# Create additional IPO triples from the new inputs
round_ipo_count = 0
for new_input in diverse_inputs:
new_triple = self.ipo_extractor.create_ipo_from_input(
problem, extracted_function_code, new_input
)
if new_triple:
ipo_triples.append(new_triple)
round_ipo_count += 1
self.logger.log_info(f" 📊 Round {round_num + 1}: Created {round_ipo_count} IPO triples")
all_generated_inputs.extend(diverse_inputs)
program_result['generated_inputs'] = all_generated_inputs
program_result['input_generation_rounds'] = input_generation_rounds
program_result['num_generated_inputs'] = len(all_generated_inputs) # count inputs across all rounds (was only the last round's)
program_result['num_ipo_triples'] = len(ipo_triples)
# Attach input generation info
if hasattr(self.ipo_extractor, 'last_input_generation_info'):
program_result['input_generation_info'] = self.ipo_extractor.last_input_generation_info
# Buffer์— ์ €์žฅ (source_program_id ์ถ”๊ฐ€)
problem_id = problem.get('task_id', 'unknown')
program_id = f'program_{variation_id}'
for ipo_idx, triple in enumerate(ipo_triples):
# 🆕 Attach mapping info to the IPO triple
triple['source_program_id'] = program_id
triple['ipo_index'] = ipo_idx
self.ipo_buffer.add(problem_id, triple)
diverse_results['total_ipo_triples'] += len(ipo_triples)
self.logger.log_info(f" ✅ Program {variation_id + 1}: {len(ipo_triples)} IPO triples generated")
else:
self.logger.log_warning(f" ⚠️ Program {variation_id + 1}: No IPO triples extracted")
else:
self.logger.log_warning(f" ❌ Program {variation_id + 1}: Syntax error - {syntax_error}")
diverse_results['programs'].append(program_result)
# ์„ฑ๊ณต ํŒ์ •: ์ตœ์†Œ 1๊ฐœ ์ด์ƒ์˜ ์œ ํšจํ•œ ํ”„๋กœ๊ทธ๋žจ์ด ์žˆ์–ด์•ผ ํ•จ
diverse_results['success'] = diverse_results['valid_programs'] > 0
self.logger.log_info(f" 📊 Diverse programs: {diverse_results['valid_programs']}/{diverse_results['total_programs']} valid, {diverse_results['total_ipo_triples']} total IPO triples")
except Exception as e:
self.logger.log_error(f"❌ Diverse program generation failed: {e}")
diverse_results['error'] = str(e)
return diverse_results
def _save_input_generation_details(self, problem: Dict[str, Any], program_id: int,
input_gen_info: Dict[str, Any]) -> None:
"""Input generation ์ƒ์„ธ ์ •๋ณด๋ฅผ ํŒŒ์ผ๋กœ ์ €์žฅ"""
try:
problem_id = problem.get('task_id', 'unknown')
# Use the input_generation_info stored by the IPO extractor
# This info will later be written out by batch_evaluate_testtime.py;
# here it is only kept in memory
if not hasattr(self, '_input_generation_infos'):
self._input_generation_infos = {}
key = f"{problem_id}_program_{program_id}"
self._input_generation_infos[key] = input_gen_info
self.logger.log_info(f"Input generation info collected for {problem_id} program {program_id}")
return
# NOTE: everything below is unreachable because of the early return above;
# it also references `input_gen_dir` and `self.timestamp`, which are never
# defined in this class. Kept as legacy reference code.
# Create a per-program detail file
detail_file = input_gen_dir / f"program_{program_id}_details.txt"
with open(detail_file, 'w', encoding='utf-8') as f:
f.write(f"Input Generation Details\n")
f.write(f"Problem ID: {problem_id}\n")
f.write(f"Program ID: {program_id}\n")
f.write(f"Generated: {self.timestamp}\n")
f.write("=" * 80 + "\n\n")
f.write("1. FUNCTION INFO:\n")
f.write("=" * 80 + "\n")
func_info = input_gen_info.get('function_info', {})
f.write(f"Function Name: {func_info.get('name', 'N/A')}\n")
f.write(f"Parameters: {func_info.get('args', 'N/A')}\n")
f.write(f"Return Type: {func_info.get('return_type', 'N/A')}\n\n")
f.write("2. ARGUMENT TYPE INFO:\n")
f.write("=" * 80 + "\n")
f.write(input_gen_info.get('arg_type_info', 'N/A') + "\n\n")
f.write("3. EXISTING EXAMPLES:\n")
f.write("=" * 80 + "\n")
for i, (inp, out) in enumerate(input_gen_info.get('existing_examples', [])):
f.write(f"Example {i+1}: Input: {inp} โ†’ Output: {out}\n")
f.write("\n")
f.write("4. LLM PROMPT:\n")
f.write("=" * 80 + "\n")
f.write(input_gen_info.get('prompt', 'N/A') + "\n\n")
f.write("5. LLM RESPONSE:\n")
f.write("=" * 80 + "\n")
f.write(input_gen_info.get('llm_response', 'N/A') + "\n\n")
f.write("6. EXTRACTED INPUTS:\n")
f.write("=" * 80 + "\n")
extracted = input_gen_info.get('extracted_inputs', [])
if extracted:
for i, inp_data in enumerate(extracted):
f.write(f"Input {i+1}: {inp_data}\n")
else:
f.write("No inputs extracted\n")
# Update the overall summary file
summary_file = input_gen_dir / "input_generation_summary.txt"
mode = 'a' if summary_file.exists() else 'w'
with open(summary_file, mode, encoding='utf-8') as f:
if mode == 'w':
f.write(f"Input Generation Summary\n")
f.write(f"Problem ID: {problem_id}\n")
f.write(f"Generated: {self.timestamp}\n")
f.write("=" * 80 + "\n\n")
f.write(f"Program {program_id}: {len(extracted)} inputs generated\n")
except Exception as e:
self.logger.log_warning(f"Failed to save input generation details: {e}")
def _save_diverse_programs_evaluation(self, problem: Dict[str, Any],
diverse_results: Dict[str, Any]) -> None:
"""๋‹ค์–‘ํ•œ ํ”„๋กœ๊ทธ๋žจ๋“ค์˜ ํ‰๊ฐ€ ๊ฒฐ๊ณผ๋ฅผ ์ €์žฅ (batch evaluation๊ณผ ๋™์ผํ•œ ํ˜•์‹)"""
try:
problem_id = problem.get('task_id', 'unknown')
            # Create the diverse-programs output directory
diverse_dir = os.path.join(self.logger.log_dir, "diverse_programs")
os.makedirs(diverse_dir, exist_ok=True)
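            # Resulting layout:
            #   diverse_programs/diverse_summary.txt
            #   diverse_programs/program_<i>/generation_details.txt
            #   diverse_programs/program_<i>/solution.py
            #   diverse_programs/program_<i>/ipo_triples/triple_<j>.json
            #   diverse_programs/program_<i>/input_generation_details.txt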
# ์š”์•ฝ ํŒŒ์ผ ์ƒ์„ฑ (batch evaluation๊ณผ ๋™์ผ)
summary_file = os.path.join(diverse_dir, "diverse_summary.txt")
with open(summary_file, 'w', encoding='utf-8') as f:
f.write(f"Diverse Programs Summary\n")
f.write(f"Problem ID: {problem_id}\n")
f.write(f"Generated: {datetime.now().strftime('%Y%m%d_%H%M%S')}\n")
f.write(f"Total Programs: {len(diverse_results.get('programs', []))}\n")
f.write("="*50 + "\n\n")
for i, program in enumerate(diverse_results.get('programs', []), 1):
f.write(f"Program {i}: {'โœ… Valid' if program.get('syntax_valid', False) else 'โŒ Invalid'}\n")
f.write(f"IPO Triples: {program.get('num_ipo_triples', 0)}\n")
f.write(f"Generated Inputs: {program.get('num_generated_inputs', 0)}\n\n")
# ๊ฐ ํ”„๋กœ๊ทธ๋žจ๋ณ„ ๋””๋ ‰ํ† ๋ฆฌ์™€ ํŒŒ์ผ ์ƒ์„ฑ (batch evaluation๊ณผ ๋™์ผํ•œ ๊ตฌ์กฐ)
for i, program in enumerate(diverse_results.get('programs', []), 1):
program_dir = os.path.join(diverse_dir, f"program_{i}")
os.makedirs(program_dir, exist_ok=True)
                # 1. generation_details.txt (same format as batch evaluation)
details_file = os.path.join(program_dir, "generation_details.txt")
with open(details_file, 'w', encoding='utf-8') as f:
f.write(f"Diverse Program {i} - Generation Details\n")
f.write(f"Problem ID: {problem_id}\n")
f.write(f"Generated: {datetime.now().strftime('%Y%m%d_%H%M%S')}\n")
f.write("="*80 + "\n\n")
f.write("1. ORIGINAL PROBLEM:\n")
f.write("="*80 + "\n")
f.write(problem.get('prompt', 'No prompt available') + "\n")
f.write("="*80 + "\n\n")
f.write("2. DIVERSITY PROMPT USED:\n")
f.write("="*80 + "\n")
f.write(program.get('diversity_instruction', 'Standard generation') + "\n")
f.write("="*80 + "\n\n")
f.write("3. LLM RESPONSE:\n")
f.write("="*80 + "\n")
f.write(program.get('solution', 'N/A') + "\n")
f.write("="*80 + "\n\n")
f.write("4. EVALUATION RESULTS:\n")
f.write("="*80 + "\n")
f.write(f"Syntax Valid: {'โœ… YES' if program.get('syntax_valid', False) else 'โŒ NO'}\n")
f.write(f"IPO Triples Generated: {program.get('num_ipo_triples', 0)}\n")
f.write(f"Input Generation: {program.get('num_generated_inputs', 0)} new inputs\n")
f.write("="*80 + "\n")
# 2. solution.py
solution_file = os.path.join(program_dir, "solution.py")
with open(solution_file, 'w', encoding='utf-8') as f:
f.write(f"# Diverse Program {i}\n")
f.write(f"# Problem ID: {problem_id}\n")
f.write(f"# Generated: {datetime.now().strftime('%Y%m%d_%H%M%S')}\n")
f.write(f"# Syntax Valid: {program.get('syntax_valid', False)}\n")
f.write(f"# IPO Triples: {program.get('num_ipo_triples', 0)}\n\n")
# ์ถ”์ถœ๋œ ํ•จ์ˆ˜ ์ฝ”๋“œ๊ฐ€ ์žˆ์œผ๋ฉด ์‚ฌ์šฉ, ์—†์œผ๋ฉด ์›๋ณธ ์†”๋ฃจ์…˜ ์‚ฌ์šฉ
# ์ด๋ฏธ generate_batch์—์„œ ํ›„์ฒ˜๋ฆฌ๊ฐ€ ๋˜์—ˆ์œผ๋ฏ€๋กœ solution์„ ๊ทธ๋Œ€๋กœ ์‚ฌ์šฉ
f.write(program.get('solution', '# No solution available'))
                # 3. ipo_triples directory with one JSON file per triple
if program.get('ipo_triples'):
ipo_dir = os.path.join(program_dir, "ipo_triples")
os.makedirs(ipo_dir, exist_ok=True)
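                    # Each triple JSON holds the input, the program snippet, and the computed output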
for j, triple in enumerate(program['ipo_triples'], 1):
triple_file = os.path.join(ipo_dir, f"triple_{j}.json")
with open(triple_file, 'w', encoding='utf-8') as f:
json.dump(triple, f, indent=2)
                # 4. input_generation_details.txt
                if program.get('input_generation_info'):
                    input_details_file = os.path.join(program_dir, "input_generation_details.txt")
                    with open(input_details_file, 'w', encoding='utf-8') as f:
f.write(f"Input Generation Details - Program {i}\n")
f.write(f"Problem ID: {problem_id}\n")
f.write(f"Generated: {datetime.now().strftime('%Y%m%d_%H%M%S')}\n")
f.write("="*80 + "\n\n")
# ๊ฐ ๋ผ์šด๋“œ์˜ ์ •๋ณด ์ €์žฅ
for round_idx, gen_info in enumerate(program['input_generation_info'], 1):
if 'error' in gen_info:
f.write(f"ROUND {round_idx} - ERROR:\n")
f.write("="*80 + "\n")
f.write(f"Error: {gen_info.get('error', 'Unknown error')}\n")
f.write(f"Traceback:\n{gen_info.get('traceback', 'No traceback')}\n")
f.write("\n")
continue
f.write(f"ROUND {round_idx}:\n")
f.write("="*80 + "\n\n")
# Function info
func_info = gen_info.get('function_info', {})
f.write("1. FUNCTION INFO:\n")
f.write("="*80 + "\n")
f.write(f"Function Name: {func_info.get('name', 'N/A')}\n")
f.write(f"Parameters: {func_info.get('args', 'N/A')}\n")
f.write(f"Parameters String: {func_info.get('signature', 'N/A')}\n\n")
# Argument type info
f.write("2. ARGUMENT TYPE INFO:\n")
f.write("="*80 + "\n")
arg_types = gen_info.get('arg_type_info', {})
if arg_types:
f.write("Argument types:\n")
for arg, arg_type in arg_types.items():
f.write(f"- {arg}: {arg_type}\n")
else:
f.write("No argument type information available\n")
f.write("\n")
# Existing examples
f.write("3. EXISTING EXAMPLES:\n")
f.write("="*80 + "\n")
existing_examples = gen_info.get('existing_examples', [])
if existing_examples:
for idx, example in enumerate(existing_examples, 1):
f.write(f"Example {idx}: Input: {example[0]} โ†’ Output: {example[1]}\n")
else:
f.write("No existing examples\n")
f.write("\n")
# LLM prompt
f.write("4. LLM PROMPT:\n")
f.write("="*80 + "\n")
f.write(gen_info.get('prompt', 'No prompt available'))
f.write("\n"*2 + "="*80 + "\n\n")
# LLM response
f.write("5. LLM RESPONSE:\n")
f.write("="*80 + "\n")
f.write(gen_info.get('llm_response', 'No response available'))
f.write("\n"*2 + "="*80 + "\n\n")
# Extracted inputs
f.write("6. EXTRACTED INPUTS:\n")
f.write("="*80 + "\n")
extracted = gen_info.get('extracted_inputs', [])
if extracted:
for idx, inp in enumerate(extracted, 1):
f.write(f"Input {idx}: {inp}\n")
else:
f.write("No inputs extracted\n")
# Valid inputs (if different from extracted)
valid = gen_info.get('valid_inputs', [])
if valid != extracted:
f.write("\n7. VALID INPUTS (after validation):\n")
f.write("="*80 + "\n")
if valid:
for idx, inp in enumerate(valid, 1):
f.write(f"Input {idx}: {inp}\n")
else:
f.write("No valid inputs after validation\n")
f.write("\n")
self.logger.log_info(f"๐Ÿ’พ Diverse programs saved to {diverse_dir} (batch evaluation format)")
except Exception as e:
self.logger.log_error(f"Failed to save diverse programs evaluation: {e}")
def _save_azr_training_data(self, all_tasks: Dict[str, List[Dict[str, Any]]],
problem_id: str, round_num: int,
output_dir: str) -> Dict[str, str]:
"""AZR ํ•™์Šต์šฉ ๋ฐ์ดํ„ฐ๋ฅผ parquet ํ˜•์‹์œผ๋กœ ์ €์žฅ"""
try:
            import pandas as pd  # local import: pandas is only needed for this parquet export
            # Create the AZR training-data directory
azr_dir = os.path.join(output_dir, 'azr_training_data')
os.makedirs(azr_dir, exist_ok=True)
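            # Resulting layout: azr_training_data/<task_type>.parquet (one file per
            # task type that has tasks) plus azr_training_data/training_stats.json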
saved_files = {}
total_tasks = 0
            # Save one parquet file per task type
for task_type, tasks in all_tasks.items():
if not tasks:
continue
                # Convert to the AZR parquet format
azr_data = []
for task in tasks:
                    # The prompt is already a fully formatted string; convert it to the
                    # format AZR uses (a list of chat-message dicts) so that Phase 5's
                    # RLHFDataset applies the chat template correctly.
prompt_dict_list = [{"role": "user", "content": task['prompt']}]
print(f"[DEBUG AZR DATA SAVE] Converting prompt to dict list format")
print(f"[DEBUG] Original prompt type: {type(task['prompt'])}, length: {len(task['prompt']) if isinstance(task['prompt'], str) else 'N/A'}")
print(f"[DEBUG] Converted prompt type: {type(prompt_dict_list)}, first elem: {type(prompt_dict_list[0])}")
azr_record = {
                        'prompt': prompt_dict_list,  # stored as a list of message dicts (same as AZR)
'uid': task['uid'],
'ipo_group_id': task['ipo_group_id'],
'source_program_id': task['source_program_id'],
'ipo_index': task['ipo_index'],
'problem': {
'input': task['ipo_triple']['input'],
'output': task['ipo_triple']['output'],
'snippet': task['ipo_triple']['program']
},
'ground_truth': task['ground_truth'],
'extra_info': task['extra_info'],
'basic_accuracy': task['basic_accuracy'],
'original_problem_id': task['original_problem_id'],
'round': task['round']
}
azr_data.append(azr_record)
                # Sort by ipo_group_id to keep each batch group contiguous
azr_data.sort(key=lambda x: x['ipo_group_id'])
# Parquet ํŒŒ์ผ๋กœ ์ €์žฅ
df = pd.DataFrame(azr_data)
file_path = os.path.join(azr_dir, f'{task_type}.parquet')
df.to_parquet(file_path, index=False)
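                # Note: pyarrow maps the nested prompt/problem fields to list and
                # struct columns (this assumes a uniform schema across rows)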
                # Debug: verify the saved data
print(f"[DEBUG] Saved {task_type}.parquet with {len(df)} records")
if len(df) > 0:
saved_prompt = df.iloc[0]['prompt']
print(f"[DEBUG] First saved prompt type: {type(saved_prompt)}")
saved_files[task_type] = file_path
total_tasks += len(tasks)
self.logger.log_info(f"๐Ÿ’พ Saved {len(tasks)} {task_type} tasks to {file_path}")
            # Save summary statistics
stats = {
'problem_id': problem_id,
'round': round_num,
'total_tasks': total_tasks,
'tasks_by_type': {k: len(v) for k, v in all_tasks.items()},
'files': saved_files,
'batch_groups': len(set(task['ipo_group_id'] for tasks in all_tasks.values() for task in tasks))
}
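            # batch_groups = number of distinct ipo_group_id values across all task
            # types; logged below as the unit used for batch alignment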
            stats_file = os.path.join(azr_dir, 'training_stats.json')
            with open(stats_file, 'w') as f:
                json.dump(stats, f, indent=2)
self.logger.log_info(f"โœ… AZR training data saved: {total_tasks} tasks in {len(saved_files)} files")
self.logger.log_info(f"๐Ÿ“Š Batch groups: {stats['batch_groups']} (for batch alignment)")
return saved_files
except Exception as e:
self.logger.log_error(f"Failed to save AZR training data: {e}")
return {}