|
|
""" |
|
|
Complete TestTime RLVR Pipeline |
|
|
|
|
|
LLM ์๋ฃจ์
์์ฑ โ IPO ์ถ์ถ โ ํ์คํฌ ์์ฑ โ LLM ํ๊ฐ โ Reward ๊ณ์ฐ |
|
|
๋ชจ๋ ๋จ๊ณ์์ AZR ์ฝ๋๋ฅผ ์ต๋ํ ๊ทธ๋๋ก ํ์ฉ |
|
|
""" |
|
|
|
|
|
from typing import Dict, List, Any, Optional |
|
|
import torch |
|
|
import re |
|
|
import os |
|
|
import json |
|
|
import ray |
|
|
import math |
|
|
from pathlib import Path |
|
|
from datetime import datetime |
|
|
|
|
|
from .benchmark_loader import BenchmarkProblemLoader |
|
|
from .solution_generator import InitialSolutionGenerator |
|
|
from .ipo_extractor import IPOTripleExtractor, IPOBuffer |
|
|
from .task_generator import TestTimeTaskGenerator |
|
|
from .config import TestTimeConfig, BenchmarkConfig |
|
|
|
|
|
from .logger import TestTimeLogger |
|
|
|
|
|
|
|
|
from ..rewards.reward_managers import CodeIORewardManager |
|
|
|
|
|
|
|
|
@ray.remote |
|
|
class RemoteTestTimePipeline: |
|
|
"""Ray Actor๋ก ์๋ํ๋ TestTime Pipeline (VeRL ํจํด)""" |
|
|
|
|
|
def __init__(self, config: TestTimeConfig, model_path: str): |
|
|
"""Ray worker ๋ด๋ถ์์ ๋ชจ๋ธ ๋ก๋ฉ""" |
|
|
self.config = config |
|
|
self.model_path = model_path |
|
|
|
|
|
|
|
|
from .solution_generator import InitialSolutionGenerator |
|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
if 'VLLM_USE_SPECIFIC_GPUS' in os.environ: |
|
|
os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['VLLM_USE_SPECIFIC_GPUS'] |
|
|
print(f"[RemoteTestTimePipeline] Restored CUDA_VISIBLE_DEVICES from VLLM_USE_SPECIFIC_GPUS: {os.environ['CUDA_VISIBLE_DEVICES']}") |
|
|
|
|
|
|
|
|
        # With CUDA_VISIBLE_DEVICES set, the first visible device is always cuda:0.
        device = 'cuda:0'
|
|
|
|
|
|
|
|
cuda_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0") |
|
|
print(f"[RemoteTestTimePipeline] CUDA_VISIBLE_DEVICES: {cuda_devices}") |
|
|
|
|
|
|
|
|
use_vllm = getattr(config, 'use_vllm_for_data_generation', len(cuda_devices.split(',')) > 1) |
|
|
gpu_count = len(cuda_devices.split(',')) |
|
|
|
|
|
|
|
|
vllm_tensor_parallel_size = min(2, gpu_count) if use_vllm else 1 |
|
|
|
|
|
print(f"[RemoteTestTimePipeline] GPU count: {gpu_count}, use_vllm: {use_vllm}, tensor_parallel_size: {vllm_tensor_parallel_size}") |
|
|
|
|
|
self.model, self.tokenizer = InitialSolutionGenerator.load_model_with_optimizations( |
|
|
model_path, device, config, use_vllm=use_vllm, tensor_parallel_size=vllm_tensor_parallel_size |
|
|
) |
|
|
|
|
|
|
|
|
import os |
|
|
log_file = os.environ.get('TTRLVR_LOG_FILE', None) |
|
|
if log_file: |
|
|
self.logger = TestTimeLogger(log_file=log_file) |
|
|
else: |
|
|
self.logger = TestTimeLogger() |
|
|
|
|
|
|
|
|
self.pipeline = CompleteTestTimePipeline( |
|
|
model=self.model, |
|
|
tokenizer=self.tokenizer, |
|
|
config=config, |
|
|
logger=self.logger |
|
|
) |
|
|
|
|
|
    def run_complete_pipeline(self, benchmark_config: BenchmarkConfig,
                              problem_id: str, round_num: int = 1, session_timestamp: Optional[str] = None,
                              output_base_dir: Optional[str] = None) -> Dict[str, Any]:
|
|
"""์๊ฒฉ์์ ํ์ดํ๋ผ์ธ ์คํ""" |
|
|
return self.pipeline.run_complete_pipeline(benchmark_config, problem_id, round_num, session_timestamp, output_base_dir) |
|
|
|
|
|
def generate_batch_vllm(self, prompts: List[str], max_tokens: int = 512, |
|
|
temperature: float = 0.7, top_p: float = 1.0, n: int = 1) -> Dict[str, Any]: |
|
|
""" |
|
|
VeRL์์ ํธ์ถํ ์ ์๋ ๋ฐฐ์น ์์ฑ ๋ฉ์๋ |
|
|
Step 5์ SharedVLLMRollout์์ ์ฌ์ฉ |
|
|
""" |
|
|
from .solution_generator import InitialSolutionGenerator |
|
|
|
|
|
|
|
|
if hasattr(self.model, 'generate'): |
|
|
|
|
|
outputs = InitialSolutionGenerator.generate_batch_vllm( |
|
|
self.model, prompts, |
|
|
max_tokens=max_tokens, |
|
|
temperature=temperature, |
|
|
top_p=top_p, |
|
|
n=n |
|
|
) |
|
|
|
|
|
|
|
|
responses = [] |
|
|
for output in outputs: |
|
|
generated_text = output.outputs[0].text |
|
|
responses.append(generated_text) |
|
|
|
|
|
|
|
|
tokenized = self.tokenizer(prompts, padding=True, truncation=True, return_tensors="pt") |
|
|
|
|
|
return { |
|
|
'responses': responses, |
|
|
'input_ids': tokenized['input_ids'].tolist(), |
|
|
'attention_mask': tokenized['attention_mask'].tolist() |
|
|
} |
|
|
else: |
|
|
|
|
|
raise NotImplementedError("HuggingFace batch generation not implemented for VeRL sharing") |
|
|
|
|
|
def update_model_weights(self, model_path: str) -> bool: |
|
|
""" |
|
|
ํ์ต๋ ๋ชจ๋ธ๋ก VLLM ๊ฐ์ค์น ์
๋ฐ์ดํธ |
|
|
๋งค ๋ผ์ด๋ ํ ํธ์ถ๋จ |
|
|
""" |
|
|
try: |
|
|
self.logger.log_info(f"๐ Updating VLLM weights from: {model_path}") |
|
|
|
|
|
|
|
|
|
|
|
from .solution_generator import InitialSolutionGenerator |
|
|
import os |
|
|
|
|
|
device = 'cuda:0' |
|
|
use_vllm = len(os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(',')) > 1 |
|
|
gpu_count = len(os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(',')) |
|
|
vllm_tensor_parallel_size = min(2, gpu_count) if use_vllm else 1 |
|
|
|
|
|
|
|
|
if hasattr(self.model, 'llm_engine'): |
|
|
del self.model |
|
|
import torch |
|
|
torch.cuda.empty_cache() |
|
|
self.logger.log_info(" - Old VLLM engine cleaned up") |
|
|
|
|
|
|
|
|
self.model, _ = InitialSolutionGenerator.load_model_with_optimizations( |
|
|
model_path, device, self.config, |
|
|
use_vllm=use_vllm, |
|
|
tensor_parallel_size=vllm_tensor_parallel_size |
|
|
) |
|
|
|
|
|
|
|
|
self.pipeline.model = self.model |
|
|
self.pipeline.solution_generator.model = self.model |
|
|
|
|
|
self.logger.log_info(f"โ
VLLM weights updated successfully") |
|
|
return True |
|
|
|
|
|
except Exception as e: |
|
|
self.logger.log_error(f"Failed to update VLLM weights: {e}") |
|
|
return False |
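    # Sketch of the per-round weight refresh this method supports (trainer-side
    # names here are hypothetical):
    #
    #   trainer.save_hf_checkpoint(ckpt_dir)                       # write HF-format weights
    #   ok = ray.get(actor.update_model_weights.remote(ckpt_dir))  # rebuild the vLLM engine
    #
    # Rebuilding the engine from disk is heavier than an in-place update but keeps
    # all tensor-parallel loading logic in one place.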
|
|
|
|
|
def update_model_weights_from_state_dict(self, state_dict: Dict[str, Any]) -> bool: |
|
|
""" |
|
|
State dict๋ก ์ง์ ๊ฐ์ค์น ์
๋ฐ์ดํธ (๋ ํจ์จ์ ) |
|
|
VeRL์์ ํ์ต๋ ๊ฐ์ค์น๋ฅผ ์ง์ ์ ๋ฌ๋ฐ์ ์
๋ฐ์ดํธ |
|
|
""" |
|
|
try: |
|
|
self.logger.log_info("๐ Updating VLLM weights from state dict") |
|
|
|
|
|
|
|
|
|
|
|
if not hasattr(self.model, 'llm_engine'): |
|
|
|
|
|
self.model.load_state_dict(state_dict) |
|
|
self.logger.log_info("โ
Model weights updated via state dict") |
|
|
return True |
|
|
else: |
|
|
|
|
|
self.logger.log_warning("โ ๏ธ VLLM requires file-based weight loading") |
|
|
return False |
|
|
|
|
|
except Exception as e: |
|
|
self.logger.log_error(f"Failed to update weights from state dict: {e}") |
|
|
return False |
|
|
|
|
|
def cleanup(self): |
|
|
"""Ray Actor ์ข
๋ฃ ์ ๋ฆฌ์์ค ์ ๋ฆฌ""" |
|
|
try: |
|
|
self.logger.log_info("๐งน Cleaning up RemoteTestTimePipeline resources...") |
|
|
|
|
|
|
|
|
if hasattr(self, 'model') and self.model is not None: |
|
|
self.logger.log_info(" - Cleaning up VLLM model...") |
|
|
|
|
|
del self.model |
|
|
self.model = None |
|
|
|
|
|
|
|
|
if hasattr(self, 'pipeline') and self.pipeline is not None: |
|
|
if hasattr(self.pipeline, 'cleanup'): |
|
|
self.pipeline.cleanup() |
|
|
del self.pipeline |
|
|
self.pipeline = None |
|
|
|
|
|
|
|
|
import torch |
|
|
if torch.cuda.is_available(): |
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
|
|
|
import gc |
|
|
gc.collect() |
|
|
|
|
|
self.logger.log_info("โ
Cleanup completed") |
|
|
return True |
|
|
except Exception as e: |
|
|
self.logger.log_error(f"โ ๏ธ Cleanup error: {e}") |
|
|
return False |
|
|
|
|
|
|
|
|
class CompleteTestTimePipeline: |
|
|
"""์์ ํ TestTime RLVR ํ์ดํ๋ผ์ธ""" |
|
|
|
|
|
def __init__(self, model, tokenizer, config: TestTimeConfig, |
|
|
logger: Optional[TestTimeLogger] = None): |
|
|
self.model = model |
|
|
self.tokenizer = tokenizer |
|
|
self.config = config |
|
|
|
|
|
self.logger = logger or TestTimeLogger() |
|
|
|
|
|
|
|
|
self.benchmark_loader = BenchmarkProblemLoader(config, self.logger) |
|
|
|
|
|
|
|
|
if model is not None and tokenizer is not None: |
|
|
|
|
|
use_vllm = getattr(config, 'use_vllm_for_data_generation', True) |
|
|
self.solution_generator = InitialSolutionGenerator(model, tokenizer, config, self.logger, use_vllm=use_vllm) |
|
|
self.ipo_extractor = IPOTripleExtractor(config, self.logger, model, tokenizer) |
|
|
|
|
|
self.ipo_extractor.solution_generator = self.solution_generator |
|
|
self.reward_manager = self._setup_azr_reward_manager() |
|
|
else: |
|
|
|
|
|
self.solution_generator = None |
|
|
self.ipo_extractor = None |
|
|
self.reward_manager = None |
|
|
|
|
|
self.task_generator = TestTimeTaskGenerator(config, self.logger) |
|
|
|
|
|
|
|
|
self.execution_mode = "single_gpu" |
|
|
self.available_gpus = [] |
|
|
|
|
|
|
|
|
self.ipo_buffer = IPOBuffer() |
|
|
|
|
|
|
|
|
self.task_output_dir = Path('./tmp/batch_results') |
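    # Minimal sketch of the IPOBuffer interface this class depends on, inferred
    # from its usage in this module (the real implementation lives in .ipo_extractor):
    #
    #   class IPOBuffer:
    #       def add(self, problem_id: str, triple: Dict[str, Any]) -> None: ...
    #       def get_all(self, problem_id: str) -> List[Dict[str, Any]]: ...
    #       def clear(self, problem_id: str) -> None: ...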
|
|
|
|
|
def _ensure_models_loaded(self): |
|
|
"""๋ชจ๋ธ๊ณผ ์ปดํฌ๋ํธ๋ค์ด ๋ก๋๋์๋์ง ํ์ธํ๊ณ ํ์์ ์ด๊ธฐํ""" |
|
|
|
|
|
if self.model is None or self.tokenizer is None: |
|
|
raise RuntimeError("Model and tokenizer must be provided during initialization") |
|
|
|
|
|
|
|
|
if self.solution_generator is None: |
|
|
|
|
|
use_vllm = getattr(self.config, 'use_vllm_for_data_generation', True) |
|
|
self.logger.log_info(f"๐ง Initializing solution generator with use_vllm={use_vllm}") |
|
|
try: |
|
|
self.solution_generator = InitialSolutionGenerator( |
|
|
self.model, self.tokenizer, self.config, self.logger, use_vllm=use_vllm |
|
|
) |
|
|
self.logger.log_info(f"โ
Solution generator initialized successfully") |
|
|
except Exception as e: |
|
|
self.logger.log_error(f"โ Failed to initialize solution generator: {e}") |
|
|
import traceback |
|
|
self.logger.log_error(f"Traceback: {traceback.format_exc()}") |
|
|
raise |
|
|
|
|
|
if self.ipo_extractor is None: |
|
|
self.ipo_extractor = IPOTripleExtractor( |
|
|
self.config, self.logger, self.model, self.tokenizer |
|
|
) |
|
|
|
|
|
self.ipo_extractor.solution_generator = self.solution_generator |
|
|
|
|
|
if self.reward_manager is None: |
|
|
self.reward_manager = self._setup_azr_reward_manager() |
|
|
|
|
|
self.logger.log_info("โ
All components ready") |
|
|
|
|
|
def set_execution_mode(self, execution_mode: str, available_gpus: List[int]): |
|
|
"""์คํ ๋ชจ๋ ์ค์ (iterative_trainer์์ ํธ์ถ)""" |
|
|
self.execution_mode = execution_mode |
|
|
self.available_gpus = available_gpus |
|
|
|
|
|
self.logger.log_info(f"๐ฏ Execution mode set to: {execution_mode}") |
|
|
self.logger.log_info(f"๐ฏ Available GPUs: {available_gpus}") |
|
|
|
|
|
|
|
|
def _setup_azr_reward_manager(self) -> CodeIORewardManager: |
|
|
"""AZR Reward Manager ์ค์ (๊ธฐ์กด ์ค์ ๊ทธ๋๋ก ์ฌ์ฉ)""" |
|
|
|
|
|
|
|
|
class SimpleConfig: |
|
|
def __init__(self): |
|
|
self.use_original_code_as_ref = False |
|
|
self.reward_type = 'code_execution' |
|
|
self.weight = 1.0 |
|
|
|
|
|
reward_manager = CodeIORewardManager( |
|
|
tokenizer=self.tokenizer, |
|
|
num_examine=0, |
|
|
reward_fn_extraction_type='rule', |
|
|
math_metric='accuracy', |
|
|
split='test', |
|
|
splitter='boxed', |
|
|
output_path='./testtime_output', |
|
|
max_prompt_length=1024, |
|
|
generation_reward_config=SimpleConfig() |
|
|
) |
|
|
|
|
|
return reward_manager |
|
|
|
|
|
    def run_complete_pipeline(self, benchmark_config: BenchmarkConfig,
                              problem_id: str, round_num: int = 1, session_timestamp: Optional[str] = None,
                              output_base_dir: Optional[str] = None) -> Dict[str, Any]:
|
|
"""์์ ํ ํ์ดํ๋ผ์ธ ์คํ |
|
|
|
|
|
Args: |
|
|
benchmark_config: ๋ฒค์น๋งํฌ ์ค์ |
|
|
problem_id: ๋ฌธ์ ID |
|
|
round_num: ๋ผ์ด๋ ๋ฒํธ |
|
|
session_timestamp: ์ธ์
ํ์์คํฌํ |
|
|
output_base_dir: ๋ก๊ทธ ์ ์ฅ ๊ธฐ๋ณธ ๋๋ ํ ๋ฆฌ (None์ด๋ฉด ๊ธฐ๋ณธ ๊ฒฝ๋ก ์ฌ์ฉ) |
|
|
""" |
|
|
|
|
|
|
|
|
if session_timestamp is None: |
|
|
|
|
|
session_timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') |
|
|
|
|
|
benchmark_safe = benchmark_config.name |
|
|
problem_safe = problem_id.replace('/', '_') |
|
|
|
|
|
|
|
|
if output_base_dir: |
|
|
round_log_dir = os.path.join(output_base_dir, benchmark_safe, problem_safe, f'round_{round_num}') |
|
|
else: |
|
|
round_log_dir = f'/home/ubuntu/RLVR/TestTime-RLVR-v2/tmp/batch_results/ttrlvr_azr_{session_timestamp}/{benchmark_safe}/{problem_safe}/round_{round_num}' |
|
|
|
|
|
|
|
|
self.logger = TestTimeLogger(task_output_dir=round_log_dir) |
|
|
|
|
|
self.logger.log_info(f"๐ Starting complete TestTime RLVR pipeline for {problem_id}") |
|
|
|
|
|
|
|
|
self._ensure_models_loaded() |
|
|
|
|
|
pipeline_result = { |
|
|
'problem_id': problem_id, |
|
|
'benchmark': benchmark_config.name, |
|
|
'round': round_num, |
|
|
'output_dir': round_log_dir, |
|
|
'steps': {}, |
|
|
'success': False, |
|
|
'error': None |
|
|
} |
|
|
|
|
|
try: |
|
|
|
|
|
self.logger.log_info("๐ Step 1: Loading benchmark problem") |
|
|
problem = self.benchmark_loader.load_problem(benchmark_config, problem_id) |
|
|
pipeline_result['steps']['problem_loading'] = { |
|
|
'success': True, |
|
|
'problem': problem |
|
|
} |
|
|
|
|
|
|
|
|
self.logger.log_info("๐ Step 1.5: Baseline performance evaluation") |
|
|
baseline_results = self._evaluate_baseline_performance(problem) |
|
|
pipeline_result['steps']['baseline_evaluation'] = baseline_results |
|
|
|
|
|
|
|
|
self.logger.log_info(f"๐ Clearing IPO buffer for round {round_num}") |
|
|
self.ipo_buffer.clear(problem_id) |
|
|
|
|
|
|
|
|
diverse_programs_results = self._generate_diverse_programs_and_ipo(problem) |
|
|
pipeline_result['steps']['diverse_programs'] = diverse_programs_results |
|
|
|
|
|
|
|
|
self._save_diverse_programs_evaluation(problem, diverse_programs_results) |
|
|
|
|
|
if not diverse_programs_results['success']: |
|
|
self.logger.log_error(f"โ No valid diverse programs generated") |
|
|
pipeline_result['error'] = "No valid diverse programs could be generated" |
|
|
return pipeline_result |
|
|
|
|
|
|
|
|
self.logger.log_info("๐ฏ Step 3: Generating tasks from current round IPO triples") |
|
|
current_round_triples = self.ipo_buffer.get_all(problem_id) |
|
|
self.logger.log_info(f"๐ฏ Using {len(current_round_triples)} IPO triples from current round") |
|
|
all_tasks = self.task_generator.generate_tasks(current_round_triples, problem_id, round_num) |
|
|
|
|
|
total_tasks = sum(len(tasks) for tasks in all_tasks.values()) |
|
|
pipeline_result['steps']['task_generation'] = { |
|
|
'success': total_tasks > 0, |
|
|
'total_tasks': total_tasks, |
|
|
'tasks_by_type': {k: len(v) for k, v in all_tasks.items()}, |
|
|
'all_tasks': all_tasks |
|
|
} |
|
|
|
|
|
if getattr(self.config, 'skip_task_evaluation', False): |
|
|
self.logger.log_info("โญ๏ธ Step 4: Skipping task evaluation (fast mode)") |
|
|
task_evaluations = {task_type: [] for task_type in all_tasks.keys()} |
|
|
pipeline_result['steps']['task_evaluation'] = { |
|
|
'success': True, |
|
|
'skipped': True, |
|
|
'evaluations': task_evaluations |
|
|
} |
|
|
else: |
|
|
self.logger.log_info("๐ญ Step 4: Evaluating tasks with LLM") |
|
|
task_evaluations = self._evaluate_tasks_with_llm(all_tasks) |
|
|
|
|
|
pipeline_result['steps']['task_evaluation'] = { |
|
|
'success': True, |
|
|
'evaluations': task_evaluations |
|
|
} |
|
|
|
|
|
|
|
|
self.logger.log_info("๐ Step 5: Computing rewards") |
|
|
buffered_triples = self.ipo_buffer.get_all(problem_id) |
|
|
rewards = self._compute_rewards_with_azr(task_evaluations, buffered_triples) |
|
|
|
|
|
pipeline_result['steps']['reward_computation'] = { |
|
|
'success': True, |
|
|
'rewards': rewards |
|
|
} |
|
|
|
|
|
|
|
|
self.logger.log_info("๐พ Step 6: Saving AZR training data") |
|
|
output_dir = pipeline_result.get('output_dir', './testtime_output') |
|
|
azr_files = self._save_azr_training_data(all_tasks, problem_id, round_num, output_dir) |
|
|
|
|
|
pipeline_result['steps']['azr_data_saving'] = { |
|
|
'success': len(azr_files) > 0, |
|
|
'files': azr_files, |
|
|
'total_tasks': sum(len(tasks) for tasks in all_tasks.values()) |
|
|
} |
|
|
|
|
|
|
|
|
self.logger.log_info("๐ Step 7: Generating task summary") |
|
|
self._save_task_summary_json(problem, baseline_results, task_evaluations, round_num) |
|
|
|
|
|
|
|
|
pipeline_result['success'] = True |
|
|
pipeline_result['azr_training_data'] = azr_files |
|
|
self.logger.log_info("โ
Complete pipeline executed successfully") |
|
|
|
|
|
return pipeline_result |
|
|
|
|
|
except Exception as e: |
|
|
self.logger.log_error(f"๐ฅ Pipeline failed: {e}") |
|
|
pipeline_result['error'] = str(e) |
|
|
return pipeline_result |
|
|
|
|
|
finally: |
|
|
|
|
|
            # ipo_extractor may still be None if loading failed before
            # _ensure_models_loaded() completed
            if self.ipo_extractor is not None:
                self.ipo_extractor.cleanup()
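    # Illustrative pipeline_result on success (paths and counts are made up):
    #
    #   {
    #       'problem_id': 'HumanEval/0', 'benchmark': 'humaneval', 'round': 1,
    #       'output_dir': '.../HumanEval_0/round_1', 'success': True, 'error': None,
    #       'steps': {'problem_loading': ..., 'baseline_evaluation': ...,
    #                 'diverse_programs': ..., 'task_generation': ...,
    #                 'task_evaluation': ..., 'reward_computation': ...,
    #                 'azr_data_saving': ...},
    #       'azr_training_data': [...],
    #   }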
|
|
|
|
|
def _evaluate_tasks_with_llm(self, all_tasks: Dict[str, List[Dict[str, Any]]]) -> Dict[str, List[Dict[str, Any]]]: |
|
|
"""LLM์ผ๋ก ์์ฑ๋ ํ์คํฌ๋ค ํ๊ฐํ๊ณ basic_accuracy ์
๋ฐ์ดํธ""" |
|
|
|
|
|
evaluations = {} |
|
|
|
|
|
|
|
|
from ..utils.code_utils.python_executor import PythonExecutor |
|
|
executor = PythonExecutor() |
|
|
|
|
|
for task_type, tasks in all_tasks.items(): |
|
|
self.logger.log_info(f"๐ Evaluating {len(tasks)} {task_type} tasks") |
|
|
task_evaluations = [] |
|
|
|
|
|
for task in tasks: |
|
|
|
|
|
task_prompt = task['prompt'] |
|
|
|
|
|
|
|
|
llm_response = self._generate_task_response(task_prompt) |
|
|
|
|
|
|
|
|
evaluation = { |
|
|
'task_id': task['task_id'], |
|
|
'task_type': task_type, |
|
|
'prompt': task_prompt, |
|
|
'llm_response': llm_response, |
|
|
'expected_solution': task['expected_solution'], |
|
|
'evaluation_data': task['evaluation_data'] |
|
|
} |
|
|
|
|
|
|
|
|
accuracy = self._calculate_task_accuracy(evaluation, task_type, executor) |
|
|
task['basic_accuracy'] = accuracy |
|
|
evaluation['basic_accuracy'] = accuracy |
|
|
|
|
|
task_evaluations.append(evaluation) |
|
|
|
|
|
evaluations[task_type] = task_evaluations |
|
|
|
|
|
|
|
|
self._save_llm_responses(task_type, task_evaluations) |
|
|
|
|
|
return evaluations |
|
|
|
|
|
def _generate_task_response(self, prompt: str) -> str: |
|
|
"""๋จ์ผ ํ์คํฌ์ ๋ํ LLM ์๋ต ์์ฑ (AZR ๋ฐฉ์)""" |
|
|
|
|
|
|
|
|
try: |
|
|
from vllm import LLM |
|
|
if isinstance(self.model, LLM): |
|
|
|
|
|
from vllm import SamplingParams |
|
|
|
|
|
sampling_params = SamplingParams( |
|
|
temperature=0.05, |
|
|
max_tokens=512, |
|
|
top_p=0.95, |
|
|
stop=["\n\n\n", "# Task:", "================================================================================"] |
|
|
) |
|
|
|
|
|
outputs = self.model.generate([prompt], sampling_params, use_tqdm=False) |
|
|
response = outputs[0].outputs[0].text.replace("\t", " ") |
|
|
return response.strip() |
|
|
except ImportError: |
|
|
pass |
|
|
|
|
|
|
|
|
inputs = self.tokenizer(prompt, return_tensors='pt', truncation=True, max_length=4096) |
|
|
|
|
|
|
|
|
if 'attention_mask' not in inputs: |
|
|
inputs['attention_mask'] = torch.ones_like(inputs['input_ids']) |
|
|
|
|
|
inputs = {k: v.to(self.model.device) for k, v in inputs.items()} |
|
|
|
|
|
with torch.no_grad(): |
|
|
if torch.cuda.is_available(): |
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
|
|
|
outputs = self.model.generate( |
|
|
inputs['input_ids'], |
|
|
attention_mask=inputs['attention_mask'], |
|
|
max_new_tokens=256, |
|
|
do_sample=True, |
|
|
temperature=0.05, |
|
|
top_p=0.95, |
|
|
pad_token_id=self.tokenizer.eos_token_id, |
|
|
eos_token_id=self.tokenizer.eos_token_id |
|
|
) |
|
|
|
|
|
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
|
response = response[len(prompt):].strip() |
|
|
|
|
|
return response |
|
|
|
|
|
def _calculate_task_accuracy(self, evaluation: Dict[str, Any], task_type: str, executor) -> float: |
|
|
"""๊ฐ๋ณ task์ ์ ํ๋ ๊ณ์ฐ""" |
|
|
|
|
|
try: |
|
|
llm_response = evaluation['llm_response'] |
|
|
expected = evaluation['expected_solution'] |
|
|
evaluation_data = evaluation['evaluation_data'] |
|
|
|
|
|
|
|
|
extracted_answer = self._extract_answer_by_task_type(llm_response, task_type) |
|
|
|
|
|
if task_type == 'abduction': |
|
|
|
|
|
code = evaluation_data['function_code'] |
|
|
expected_output_value = evaluation_data['expected_output'] |
|
|
agent_input = extracted_answer |
|
|
|
|
|
try: |
|
|
|
|
|
import re |
|
|
func_name_match = re.search(r'def\s+(\w+)\s*\(', code) |
|
|
if func_name_match: |
|
|
original_func_name = func_name_match.group(1) |
|
|
|
|
|
code = re.sub(r'def\s+' + re.escape(original_func_name) + r'\s*\(', 'def f(', code) |
|
|
|
|
|
from ..utils.code_utils.templates import EVAL_INPUT_PREDICTION_TEMPLATE |
|
|
code_snippet = EVAL_INPUT_PREDICTION_TEMPLATE.format( |
|
|
code=code, |
|
|
gold_output=expected_output_value, |
|
|
agent_input=agent_input |
|
|
) |
|
|
result, status = executor.apply(code_snippet) |
|
|
|
|
|
if 'error' in status.lower(): |
|
|
accuracy = 0.0 |
|
|
else: |
|
|
try: |
|
|
if isinstance(result, bool): |
|
|
agent_output = result |
|
|
else: |
|
|
agent_output = eval(result) |
|
|
accuracy = 1.0 if agent_output else 0.0 |
|
|
                        except Exception:
|
|
accuracy = 0.0 |
|
|
                except Exception:
|
|
accuracy = 0.0 |
|
|
|
|
|
elif task_type == 'deduction': |
|
|
|
|
|
expected_output = expected |
|
|
agent_output = extracted_answer |
|
|
|
|
|
try: |
|
|
accuracy = 1.0 if eval(expected_output) == eval(agent_output) else 0.0 |
|
|
                except Exception:
|
|
accuracy = 0.0 |
|
|
|
|
|
elif task_type == 'induction': |
|
|
|
|
|
input_output_pairs = evaluation_data['input_output_pairs'] |
|
|
agent_code = extracted_answer |
|
|
|
|
|
accuracies = [] |
|
|
for test_input, expected_output in input_output_pairs: |
|
|
try: |
|
|
accuracy = executor.eval_input_prediction(agent_code, expected_output, test_input) |
|
|
accuracies.append(accuracy if accuracy is not None else 0.0) |
|
|
                    except Exception:
|
|
accuracies.append(0.0) |
|
|
|
|
|
|
|
|
accuracy = sum(accuracies) / len(accuracies) if accuracies else 0.0 |
|
|
|
|
|
else: |
|
|
|
|
|
accuracy = 1.0 if expected.strip() == extracted_answer.strip() else 0.0 |
|
|
|
|
|
except Exception as e: |
|
|
self.logger.log_error(f"Error calculating accuracy for {task_type}: {e}") |
|
|
accuracy = 0.0 |
|
|
|
|
|
return accuracy |
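    # Scoring recap (grounded in the branches above):
    #   - abduction: the agent's input is run through the (renamed) function and
    #     scored 1.0 if it reproduces the gold output, so it need not equal the
    #     original input.
    #   - deduction: both sides are eval()'d before comparison, so "'ab'" and
    #     '"ab"' compare equal even though the raw strings differ.
    #   - induction: the mean score over all stored input/output pairs, e.g.
    #     passing 2 of 4 pairs scores 0.5.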
|
|
|
|
|
def _extract_answer_by_task_type(self, llm_response: str, task_type: str) -> str: |
|
|
"""ํ์คํฌ ํ์
๋ณ AZR ๋ฐฉ์ ์ ๋ต ์ถ์ถ""" |
|
|
|
|
|
        # All three reasoning task types use the same AZR answer format:
        # take the last <answer>...</answer> block when present.
        if task_type in ('induction', 'abduction', 'deduction'):
            pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL | re.IGNORECASE)
            matches = pattern.findall(llm_response)
            return matches[-1].strip() if matches else llm_response.strip()
        else:
            # Unknown task type: return the raw response
            return llm_response.strip()
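    # Example: given the response
    #   "reasoning ... <answer>42</answer> ... <answer>43</answer>"
    # this returns "43" (the last <answer> block wins); responses without
    # <answer> tags are returned whole, stripped.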
|
|
|
|
|
def _compute_rewards_with_azr(self, task_evaluations: Dict[str, List[Dict[str, Any]]], |
|
|
ipo_triples: List[Dict[str, Any]]) -> Dict[str, Any]: |
|
|
"""AZR Reward Manager๋ก ๋ณด์ ๊ณ์ฐ (์ค์ ์ฝ๋ ์คํ ๊ธฐ๋ฐ ํ๊ฐ)""" |
|
|
|
|
|
|
|
|
from ..utils.code_utils.python_executor import PythonExecutor |
|
|
executor = PythonExecutor() |
|
|
|
|
|
rewards_by_type = {} |
|
|
total_rewards = [] |
|
|
|
|
|
for task_type, evaluations in task_evaluations.items(): |
|
|
self.logger.log_info(f"๐ฏ Computing rewards for {task_type} tasks") |
|
|
|
|
|
type_rewards = [] |
|
|
|
|
|
for evaluation in evaluations: |
|
|
expected = evaluation['expected_solution'] |
|
|
llm_response = evaluation['llm_response'] |
|
|
evaluation_data = evaluation['evaluation_data'] |
|
|
|
|
|
|
|
|
extracted_answer = self._extract_answer_by_task_type(llm_response, task_type) |
|
|
|
|
|
|
|
|
try: |
|
|
if task_type == 'abduction': |
|
|
|
|
|
code = evaluation_data['function_code'] |
|
|
expected_output = evaluation_data['expected_output'] |
|
|
agent_input = extracted_answer |
|
|
|
|
|
|
|
|
import re |
|
|
def extract_function_definition(code): |
|
|
"""์ฝ๋์์ import๋ฌธ๊ณผ ํจ์ ์ ์๋ฅผ ์ถ์ถ""" |
|
|
lines = code.split('\n') |
|
|
import_lines = [] |
|
|
func_lines = [] |
|
|
in_function = False |
|
|
base_indent = None |
|
|
|
|
|
for line in lines: |
|
|
|
|
|
if line.strip().startswith('from ') or line.strip().startswith('import '): |
|
|
import_lines.append(line) |
|
|
|
|
|
elif line.strip().startswith('def '): |
|
|
in_function = True |
|
|
base_indent = len(line) - len(line.lstrip()) |
|
|
func_lines.append(line) |
|
|
elif in_function: |
|
|
|
|
|
if line.strip() == '': |
|
|
func_lines.append(line) |
|
|
elif line.startswith(' ' * (base_indent + 1)) or line.startswith('\t'): |
|
|
|
|
|
func_lines.append(line) |
|
|
else: |
|
|
|
|
|
break |
|
|
|
|
|
|
|
|
if import_lines: |
|
|
return '\n'.join(import_lines) + '\n\n' + '\n'.join(func_lines) |
|
|
else: |
|
|
return '\n'.join(func_lines) |
|
|
|
|
|
|
|
|
code = extract_function_definition(code) |
|
|
|
|
|
|
|
|
|
|
|
func_name_match = re.search(r'def\s+(\w+)\s*\(', code) |
|
|
if func_name_match: |
|
|
original_func_name = func_name_match.group(1) |
|
|
|
|
|
code = re.sub(r'def\s+' + re.escape(original_func_name) + r'\s*\(', 'def f(', code) |
|
|
|
|
|
|
|
|
try: |
|
|
expected_output_value = eval(expected_output) |
|
|
                        except Exception:
|
|
expected_output_value = expected_output |
|
|
|
|
|
|
|
|
try: |
|
|
from ..utils.code_utils.templates import EVAL_INPUT_PREDICTION_TEMPLATE |
|
|
code_snippet = EVAL_INPUT_PREDICTION_TEMPLATE.format( |
|
|
code=code, |
|
|
gold_output=expected_output_value, |
|
|
agent_input=agent_input |
|
|
) |
|
|
result, status = executor.apply(code_snippet) |
|
|
|
|
|
if 'error' in status.lower(): |
|
|
accuracy = 0.0 |
|
|
else: |
|
|
|
|
|
try: |
|
|
|
|
|
if isinstance(result, bool): |
|
|
|
|
|
agent_output = result |
|
|
else: |
|
|
|
|
|
agent_output = eval(result) |
|
|
accuracy = 1.0 if agent_output else 0.0 |
|
|
                                except Exception:
|
|
accuracy = 0.0 |
|
|
                        except Exception:
|
|
accuracy = 0.0 |
|
|
|
|
|
elif task_type == 'deduction': |
|
|
|
|
|
expected_output = expected |
|
|
agent_output = extracted_answer |
|
|
|
|
|
|
|
|
try: |
|
|
accuracy = 1.0 if eval(expected_output) == eval(agent_output) else 0.0 |
|
|
                        except Exception:
|
|
accuracy = 0.0 |
|
|
|
|
|
elif task_type == 'induction': |
|
|
|
|
|
input_output_pairs = evaluation_data['input_output_pairs'] |
|
|
agent_code = extracted_answer |
|
|
|
|
|
|
|
|
accuracies = [] |
|
|
for test_input, expected_output in input_output_pairs: |
|
|
try: |
|
|
accuracy = executor.eval_input_prediction(agent_code, expected_output, test_input) |
|
|
accuracies.append(accuracy if accuracy is not None else 0.0) |
|
|
                            except Exception:
|
|
accuracies.append(0.0) |
|
|
|
|
|
|
|
|
accuracy = sum(accuracies) / len(accuracies) if accuracies else 0.0 |
|
|
|
|
|
else: |
|
|
|
|
|
accuracy = 1.0 if expected.strip() == extracted_answer.strip() else 0.0 |
|
|
|
|
|
except Exception as e: |
|
|
self.logger.log_error(f"Error in {task_type} evaluation: {e}") |
|
|
accuracy = 0.0 |
|
|
|
|
|
|
|
|
reward = { |
|
|
'task_id': evaluation['task_id'], |
|
|
'task_type': task_type, |
|
|
'extracted_answer': extracted_answer, |
|
|
'expected_solution': expected, |
|
|
'basic_accuracy': accuracy, |
|
|
'final_reward': accuracy |
|
|
} |
|
|
|
|
|
type_rewards.append(reward) |
|
|
total_rewards.append(reward['final_reward']) |
|
|
|
|
|
rewards_by_type[task_type] = type_rewards |
|
|
|
|
|
|
|
|
avg_reward = sum(total_rewards) / len(total_rewards) if total_rewards else 0.0 |
|
|
|
|
|
return { |
|
|
'rewards_by_type': rewards_by_type, |
|
|
'total_tasks': len(total_rewards), |
|
|
'average_reward': avg_reward, |
|
|
'reward_distribution': { |
|
|
task_type: sum(r['final_reward'] for r in rewards) / len(rewards) if rewards else 0.0 |
|
|
for task_type, rewards in rewards_by_type.items() |
|
|
} |
|
|
} |
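    # Illustrative return value of _compute_rewards_with_azr (numbers made up):
    #
    #   {
    #       'rewards_by_type': {'deduction': [{'task_id': '...', 'task_type': 'deduction',
    #                                          'basic_accuracy': 1.0, 'final_reward': 1.0,
    #                                          ...}, ...], ...},
    #       'total_tasks': 12,
    #       'average_reward': 0.58,
    #       'reward_distribution': {'abduction': 0.5, 'deduction': 0.75, 'induction': 0.5},
    #   }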
|
|
|
|
|
def _compute_similarity(self, expected: str, actual: str) -> float: |
|
|
"""๋ฌธ์์ด ์ ์ฌ์ฑ ๊ณ์ฐ (๊ฐ๋จํ ๋ฐฉ์)""" |
|
|
|
|
|
expected_words = set(expected.lower().split()) |
|
|
actual_words = set(actual.lower().split()) |
|
|
|
|
|
if not expected_words and not actual_words: |
|
|
return 1.0 |
|
|
if not expected_words or not actual_words: |
|
|
return 0.0 |
|
|
|
|
|
intersection = expected_words & actual_words |
|
|
union = expected_words | actual_words |
|
|
|
|
|
return len(intersection) / len(union) |
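    # Worked example: expected = "the quick fox", actual = "the slow fox"
    #   intersection = {"the", "fox"} -> 2, union = {"the", "quick", "slow", "fox"} -> 4
    #   similarity = 2 / 4 = 0.5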
|
|
|
|
|
def _evaluate_baseline_performance(self, problem: Dict[str, Any]) -> Dict[str, Any]: |
|
|
"""๋ฒ ์ด์ค๋ผ์ธ ์ฑ๋ฅ ์ธก์ (temperature=0.05๋ก 5๋ฒ ์คํ)""" |
|
|
|
|
|
self.logger.log_info(f"๐ Evaluating baseline performance for {problem.get('task_id', 'unknown')}") |
|
|
|
|
|
baseline_results = { |
|
|
'success': True, |
|
|
'total_rounds': self.config.baseline_evaluation_rounds, |
|
|
'solutions': [], |
|
|
'evaluations': [], |
|
|
'success_count': 0, |
|
|
'average_accuracy': 0.0, |
|
|
'error': None |
|
|
} |
|
|
|
|
|
try: |
|
|
for round_id in range(self.config.baseline_evaluation_rounds): |
|
|
self.logger.log_info(f" ๐ Baseline round {round_id + 1}/{self.config.baseline_evaluation_rounds}") |
|
|
|
|
|
|
|
|
solution = self.solution_generator.generate(problem) |
|
|
|
|
|
|
|
|
is_valid, syntax_error = self.solution_generator.validate_syntax(solution) |
|
|
|
|
|
solution_result = { |
|
|
'round_id': round_id, |
|
|
'solution': solution, |
|
|
'syntax_valid': is_valid, |
|
|
'syntax_error': syntax_error, |
|
|
'evaluation': None |
|
|
} |
|
|
|
|
|
|
|
|
if is_valid: |
|
|
evaluation = self.solution_generator.evaluate_solution(problem, solution) |
|
|
solution_result['evaluation'] = evaluation |
|
|
|
|
|
if evaluation['correct']: |
|
|
baseline_results['success_count'] += 1 |
|
|
self.logger.log_info(f" โ
Round {round_id + 1}: PASSED ({evaluation['passed_tests']}/{evaluation['total_tests']} tests)") |
|
|
|
|
|
|
|
|
self.logger.log_problem_attempt(problem, solution, True, evaluation) |
|
|
else: |
|
|
self.logger.log_info(f" โ Round {round_id + 1}: FAILED ({evaluation['passed_tests']}/{evaluation['total_tests']} tests)") |
|
|
|
|
|
|
|
|
self.logger.log_problem_attempt(problem, solution, False, evaluation) |
|
|
else: |
|
|
self.logger.log_warning(f" โ Round {round_id + 1}: Syntax error - {syntax_error}") |
|
|
|
|
|
|
|
|
syntax_validation = {'syntax_valid': False, 'syntax_error': syntax_error} |
|
|
self.logger.log_problem_attempt(problem, solution, False, syntax_validation) |
|
|
|
|
|
baseline_results['solutions'].append(solution_result) |
|
|
|
|
|
|
|
|
self._save_batch_evaluation_format(problem, solution_result, attempt_num=round_id + 1) |
|
|
|
|
|
|
|
|
if baseline_results['success_count'] > 0: |
|
|
baseline_results['average_accuracy'] = baseline_results['success_count'] / baseline_results['total_rounds'] |
|
|
|
|
|
self.logger.log_info(f" ๐ Baseline performance: {baseline_results['success_count']}/{baseline_results['total_rounds']} success ({baseline_results['average_accuracy']:.3f})") |
|
|
|
|
|
except Exception as e: |
|
|
self.logger.log_error(f"โ Baseline evaluation failed: {e}") |
|
|
baseline_results['success'] = False |
|
|
baseline_results['error'] = str(e) |
|
|
|
|
|
return baseline_results |
|
|
|
|
|
def _save_batch_evaluation_format(self, problem: Dict[str, Any], solution_result: Dict[str, Any], attempt_num: int): |
|
|
"""Batch evaluation๊ณผ ๋์ผํ ํ์์ผ๋ก ์์ธ ๋ก๊ทธ ์ ์ฅ""" |
|
|
|
|
|
from ..testtime.prompts import get_prompt |
|
|
|
|
|
|
|
|
current_dir = os.path.join(self.logger.log_dir, "current_evaluation") |
|
|
os.makedirs(current_dir, exist_ok=True) |
|
|
|
|
|
|
|
|
attempt_file = os.path.join(current_dir, f"attempt_{attempt_num}.txt") |
|
|
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') |
|
|
|
|
|
with open(attempt_file, 'w', encoding='utf-8') as f: |
|
|
|
|
|
f.write(f"Current Evaluation - Attempt {attempt_num}\n") |
|
|
f.write(f"Problem ID: {problem.get('task_id', 'unknown')}\n") |
|
|
f.write(f"Benchmark: {problem.get('benchmark_name', 'unknown')}\n") |
|
|
f.write(f"Generated: {timestamp}\n") |
|
|
f.write("="*80 + "\n\n") |
|
|
|
|
|
|
|
|
f.write("1. ORIGINAL PROBLEM:\n") |
|
|
f.write("="*80 + "\n") |
|
|
f.write(problem.get('prompt', 'No prompt available')) |
|
|
f.write("\n" + "="*80 + "\n\n") |
|
|
|
|
|
|
|
|
f.write("2. LLM INPUT SCRIPT (PROMPT):\n") |
|
|
f.write("="*80 + "\n") |
|
|
problem_prompt = problem.get('prompt', '') |
|
|
|
|
|
|
|
|
try: |
|
|
if 'HumanEval' in problem.get('task_id', ''): |
|
|
full_prompt = get_prompt("solution_humaneval_basic", |
|
|
problem_prompt=problem_prompt) |
|
|
else: |
|
|
full_prompt = get_prompt("solution_mbpp_basic", |
|
|
problem_prompt=problem_prompt) |
|
|
f.write(full_prompt.strip()) |
|
|
except Exception as e: |
|
|
|
|
|
f.write(f"You are a Python writing assistant. Complete the following Python function.\n\n{problem_prompt}\n\nPlease provide a complete implementation of the function.") |
|
|
|
|
|
f.write("\n" + "="*80 + "\n\n") |
|
|
|
|
|
|
|
|
f.write("3. LLM RESPONSE:\n") |
|
|
f.write("="*80 + "\n") |
|
|
f.write(solution_result.get('solution', 'No solution generated')) |
|
|
f.write("\n" + "="*80 + "\n\n") |
|
|
|
|
|
|
|
|
f.write("4. CORRECTNESS EVALUATION:\n") |
|
|
f.write("="*80 + "\n") |
|
|
|
|
|
|
|
|
f.write(f"Syntax Valid: {'โ
YES' if solution_result.get('syntax_valid', False) else 'โ NO'}\n") |
|
|
if solution_result.get('syntax_error'): |
|
|
f.write(f"Syntax Error: {solution_result['syntax_error']}\n") |
|
|
|
|
|
|
|
|
evaluation = solution_result.get('evaluation') |
|
|
if evaluation: |
|
|
if evaluation.get('correct', False): |
|
|
f.write(f"Result: โ
CORRECT ({evaluation.get('passed_tests', 0)}/{evaluation.get('total_tests', 0)} tests passed)\n") |
|
|
else: |
|
|
f.write(f"Result: โ INCORRECT ({evaluation.get('passed_tests', 0)}/{evaluation.get('total_tests', 0)} tests passed)\n") |
|
|
|
|
|
if evaluation.get('error'): |
|
|
f.write(f"Evaluation Error: {evaluation['error']}\n") |
|
|
else: |
|
|
f.write("Result: โ NO EVALUATION (syntax error or evaluation failed)\n") |
|
|
|
|
|
f.write("="*80 + "\n") |
|
|
|
|
|
self.logger.log_info(f"๐ Batch evaluation format saved: {attempt_file}") |
|
|
|
|
|
def _save_llm_responses(self, task_type: str, evaluations: List[Dict[str, Any]]): |
|
|
"""LLM ์๋ต์ llm_responses ๋๋ ํ ๋ฆฌ์ ์ ์ฅ (batch evaluation๊ณผ ๋์ผํ ํ์)""" |
|
|
|
|
|
try: |
|
|
|
|
|
llm_dir = os.path.join(self.logger.log_dir, "llm_responses") |
|
|
os.makedirs(llm_dir, exist_ok=True) |
|
|
|
|
|
|
|
|
for i, evaluation in enumerate(evaluations, 1): |
|
|
problem_id = evaluation['task_id'].split('_')[0] if '_' in evaluation['task_id'] else evaluation['task_id'] |
|
|
response_file = os.path.join(llm_dir, f"{problem_id}_{task_type}_{i}_response.txt") |
|
|
|
|
|
with open(response_file, 'w', encoding='utf-8') as f: |
|
|
f.write(f"Task Type: {task_type}\n") |
|
|
f.write(f"Task ID: {evaluation['task_id']}\n") |
|
|
f.write(f"Generated: {datetime.now().strftime('%Y%m%d_%H%M%S')}\n") |
|
|
f.write("="*80 + "\nORIGINAL PROMPT:\n") |
|
|
f.write("="*80 + "\n") |
|
|
f.write(evaluation['prompt']) |
|
|
f.write("\n" + "="*80 + "\n") |
|
|
f.write("LLM RESPONSE:\n") |
|
|
f.write("="*80 + "\n") |
|
|
f.write(evaluation['llm_response']) |
|
|
f.write("\n" + "="*80 + "\n") |
|
|
f.write("EXPECTED SOLUTION:\n") |
|
|
f.write("="*80 + "\n") |
|
|
f.write(str(evaluation['expected_solution'])) |
|
|
f.write("\n" + "="*80 + "\n") |
|
|
f.write("EXTRACTED ANSWER:\n") |
|
|
f.write("="*80 + "\n") |
|
|
|
|
|
extracted_answer = self._extract_answer_by_task_type(evaluation['llm_response'], task_type) |
|
|
f.write(extracted_answer) |
|
|
f.write("\n" + "="*80 + "\n") |
|
|
f.write("MATCH RESULT:\n") |
|
|
f.write("="*80 + "\n") |
|
|
accuracy = evaluation.get('basic_accuracy', 0.0) |
|
|
if accuracy > 0.5: |
|
|
f.write(f"โ
CORRECT (Score: {accuracy:.3f})") |
|
|
else: |
|
|
f.write(f"โ INCORRECT (Score: {accuracy:.3f})") |
|
|
|
|
|
self.logger.log_info(f"๐ LLM responses saved to {llm_dir} (batch evaluation format)") |
|
|
|
|
|
except Exception as e: |
|
|
self.logger.log_warning(f"Failed to save LLM responses: {e}") |
|
|
|
|
|
def _save_task_summary_json(self, problem: Dict[str, Any], baseline_results: Dict[str, Any], |
|
|
task_evaluations: Dict[str, List[Dict[str, Any]]], round_num: int = None): |
|
|
"""batch evaluation๊ณผ ๋์ผํ ํ์์ summary.json ์์ฑ""" |
|
|
|
|
|
try: |
|
|
problem_id = problem.get('task_id', 'unknown') |
|
|
problem_id_safe = problem_id.replace('/', '_') |
|
|
|
|
|
|
|
|
if round_num is not None: |
|
|
|
|
|
round_summary_file = os.path.join(self.logger.log_dir, f"{problem_id_safe}_round_{round_num}_summary.json") |
|
|
self._save_single_summary(problem, baseline_results, task_evaluations, round_summary_file, round_num) |
|
|
|
|
|
|
|
|
            # Path() makes .parent safe whether log_dir is a str or a Path
            summary_file = os.path.join(Path(self.logger.log_dir).parent, f"{problem_id_safe}_summary.json")
|
|
self._save_single_summary(problem, baseline_results, task_evaluations, summary_file) |
|
|
|
|
|
except Exception as e: |
|
|
self.logger.log_warning(f"Failed to save task summary: {e}") |
|
|
|
|
|
def _save_single_summary(self, problem: Dict[str, Any], baseline_results: Dict[str, Any], |
|
|
task_evaluations: Dict[str, List[Dict[str, Any]]], summary_file: str, round_num: int = None): |
|
|
"""๋จ์ผ summary ํ์ผ ์ ์ฅ""" |
|
|
|
|
|
with open(summary_file, 'w', encoding='utf-8') as f: |
|
|
problem_id = problem.get('task_id', 'unknown') |
|
|
|
|
|
summary = { |
|
|
'problem_id': problem_id, |
|
|
'benchmark': problem.get('benchmark_name', 'unknown'), |
|
|
'success': True, |
|
|
'timestamp': datetime.now().strftime('%Y%m%d_%H%M%S'), |
|
|
'initial_solution_correct': False, |
|
|
'ipo_extraction_success': True, |
|
|
'reasoning_task_results': {} |
|
|
} |
|
|
|
|
|
|
|
|
if round_num is not None: |
|
|
summary['round'] = round_num |
|
|
|
|
|
|
|
|
if baseline_results.get('success_count', 0) > 0: |
|
|
summary['initial_solution_correct'] = True |
|
|
|
|
|
|
|
|
for task_type, evaluations in task_evaluations.items(): |
|
|
if evaluations: |
|
|
correct_count = sum(1 for eval_data in evaluations if eval_data.get('basic_accuracy', 0) > 0.5) |
|
|
total_count = len(evaluations) |
|
|
|
|
|
summary['reasoning_task_results'][task_type] = { |
|
|
'correct': correct_count, |
|
|
'total': total_count, |
|
|
'accuracy': correct_count / total_count if total_count > 0 else 0 |
|
|
} |
|
|
|
|
|
json.dump(summary, f, indent=2, ensure_ascii=False) |
|
|
|
|
|
self.logger.log_info(f"๐ Summary saved: {summary_file}") |
|
|
|
|
|
def _generate_diverse_programs_and_ipo(self, problem: Dict[str, Any]) -> Dict[str, Any]: |
|
|
"""๋ค์ํ ํ๋ก๊ทธ๋จ ์์ฑ ๋ฐ IPO ์ถ์ถ (VLLM ๋ฐฐ์น ์ฒ๋ฆฌ)""" |
|
|
|
|
|
|
|
|
return self._generate_programs_batch_vllm(problem) |
|
|
|
|
|
def _generate_programs_batch_vllm(self, problem: Dict[str, Any]) -> Dict[str, Any]: |
|
|
"""VLLM ๋ฐฐ์น ์ฒ๋ฆฌ๋ก ํ๋ก๊ทธ๋จ ์์ฑ""" |
|
|
|
|
|
problem_id = problem.get('task_id', 'unknown') |
|
|
batch_size = getattr(self.config, 'parallel_batch_size', 4) |
|
|
|
|
|
self.logger.log_info(f"๐จ Generating {self.config.num_program_variations} diverse programs for {problem_id} (BATCH)") |
|
|
self.logger.log_info(f"๐ Using batch size: {batch_size} (concurrent prompts)") |
|
|
|
|
|
diverse_results = { |
|
|
'success': False, |
|
|
'total_programs': self.config.num_program_variations, |
|
|
'valid_programs': 0, |
|
|
'programs': [], |
|
|
'total_ipo_triples': 0, |
|
|
'error': None, |
|
|
'batch_processing': True, |
|
|
'batch_size': batch_size |
|
|
} |
|
|
|
|
|
try: |
|
|
|
|
|
all_programs = [] |
|
|
|
|
|
for batch_idx, batch_start in enumerate(range(0, self.config.num_program_variations, batch_size)): |
|
|
batch_end = min(batch_start + batch_size, self.config.num_program_variations) |
|
|
batch_ids = list(range(batch_start, batch_end)) |
|
|
|
|
|
self.logger.log_info(f" ๐ฏ Processing batch {batch_idx + 1}: programs {batch_start}-{batch_end-1}") |
|
|
|
|
|
|
|
|
batch_prompts = [] |
|
|
for variation_id in batch_ids: |
|
|
prompt = self._create_diverse_generation_prompt(problem, variation_id) |
|
|
batch_prompts.append(prompt) |
|
|
|
|
|
|
|
|
batch_solutions = self.solution_generator.generate_batch( |
|
|
batch_prompts, |
|
|
temperature=self.config.diverse_generation_temperature |
|
|
) |
|
|
|
|
|
self.logger.log_info(f" ๐ Generated {len(batch_solutions)} solutions") |
|
|
|
|
|
|
|
|
batch_program_results = [] |
|
|
for i, (variation_id, solution) in enumerate(zip(batch_ids, batch_solutions)): |
|
|
program_result = self._process_single_program_basic(problem, solution, variation_id) |
|
|
batch_program_results.append(program_result) |
|
|
|
|
|
|
|
|
successful_programs = [p for p in batch_program_results if p.get('success', False)] |
|
|
|
|
|
self.logger.log_info(f" ๐ Batch results: {len(batch_program_results)} programs, {len(successful_programs)} successful") |
|
|
for i, prog in enumerate(batch_program_results): |
|
|
self.logger.log_info(f" Program {i}: success={prog.get('success')}, IPO triples={prog.get('num_ipo_triples', 0)}") |
|
|
|
|
|
if successful_programs: |
|
|
self.logger.log_info(f" ๐ฒ Generating inputs for {len(successful_programs)} valid programs (BATCH)") |
|
|
|
|
|
|
|
|
input_generation_pairs = [] |
|
|
for program_result in successful_programs: |
|
|
for round_num in range(getattr(self.config, 'input_generation_rounds', 3)): |
|
|
|
|
|
existing_examples = [(triple['full_input_str'], triple['actual_output']) |
|
|
for triple in program_result['ipo_triples']] |
|
|
|
|
|
|
|
|
for prev_input in program_result.get('all_generated_inputs', []): |
|
|
if 'input_args' in prev_input and 'expected_output' in prev_input: |
|
|
existing_examples.append(( |
|
|
str(prev_input['input_args']), |
|
|
str(prev_input['expected_output']) |
|
|
)) |
|
|
|
|
|
input_generation_pairs.append({ |
|
|
'problem': problem, |
|
|
'solution': program_result['extracted_function_code'], |
|
|
'existing_examples': existing_examples, |
|
|
'program_result': program_result, |
|
|
'round_num': round_num |
|
|
}) |
|
|
|
|
|
|
|
|
if input_generation_pairs: |
|
|
self.logger.log_info(f" ๐ Total input generation pairs: {len(input_generation_pairs)}") |
|
|
batch_input_results, batch_generation_info = self.ipo_extractor.generate_diverse_inputs_batch(input_generation_pairs) |
|
|
self.logger.log_info(f" ๐ Batch input results: {len(batch_input_results)} responses") |
|
|
|
|
|
|
|
|
pair_idx = 0 |
|
|
for program_result in successful_programs: |
|
|
program_result['all_generated_inputs'] = [] |
|
|
program_result['input_generation_info'] = [] |
|
|
input_generation_rounds = getattr(self.config, 'input_generation_rounds', 3) |
|
|
|
|
|
for round_num in range(input_generation_rounds): |
|
|
if pair_idx < len(batch_input_results): |
|
|
round_inputs = batch_input_results[pair_idx] |
|
|
program_result['all_generated_inputs'].extend(round_inputs) |
|
|
|
|
|
|
|
|
if pair_idx < len(batch_generation_info) and batch_generation_info[pair_idx]: |
|
|
program_result['input_generation_info'].append(batch_generation_info[pair_idx]) |
|
|
|
|
|
|
|
|
for new_input in round_inputs: |
|
|
new_triple = self.ipo_extractor.create_ipo_from_input( |
|
|
problem, program_result['extracted_function_code'], new_input |
|
|
) |
|
|
if new_triple: |
|
|
program_result['ipo_triples'].append(new_triple) |
|
|
|
|
|
pair_idx += 1 |
|
|
|
|
|
|
|
|
program_result['num_generated_inputs'] = len(program_result['all_generated_inputs']) |
|
|
program_result['num_ipo_triples'] = len(program_result['ipo_triples']) |
|
|
program_result['input_generation_rounds'] = input_generation_rounds |
|
|
|
|
|
|
|
|
for prog_idx, program_result in enumerate(batch_program_results): |
|
|
all_programs.append(program_result) |
|
|
|
|
|
if program_result.get('success', False): |
|
|
diverse_results['valid_programs'] += 1 |
|
|
diverse_results['total_ipo_triples'] += program_result.get('num_ipo_triples', 0) |
|
|
|
|
|
|
|
|
program_id = f'program_{batch_idx * batch_size + prog_idx}' |
|
|
for ipo_idx, triple in enumerate(program_result.get('ipo_triples', [])): |
|
|
|
|
|
triple['source_program_id'] = program_id |
|
|
triple['ipo_index'] = ipo_idx |
|
|
self.ipo_buffer.add(problem_id, triple) |
|
|
|
|
|
if program_result.get('ipo_triples'): |
|
|
self.logger.log_info(f" ๐ฅ Added {len(program_result['ipo_triples'])} IPO triples to buffer from {program_id}") |
|
|
|
|
|
diverse_results['programs'] = all_programs |
|
|
diverse_results['success'] = diverse_results['valid_programs'] > 0 |
|
|
|
|
|
self.logger.log_info(f"โ
Batch processing completed:") |
|
|
self.logger.log_info(f" - Valid programs: {diverse_results['valid_programs']}/{diverse_results['total_programs']}") |
|
|
self.logger.log_info(f" - Total IPO triples: {diverse_results['total_ipo_triples']}") |
|
|
|
|
|
|
|
|
self._save_diverse_programs_evaluation(problem, diverse_results) |
|
|
|
|
|
except Exception as e: |
|
|
self.logger.log_error(f"Batch processing failed: {e}") |
|
|
diverse_results['error'] = str(e) |
|
|
|
|
|
self.logger.log_info("๐ Falling back to sequential processing") |
|
|
return self._generate_programs_sequential(problem) |
|
|
|
|
|
return diverse_results |
|
|
|
|
|
def _create_diverse_generation_prompt(self, problem: Dict[str, Any], variation_id: int) -> str: |
|
|
"""๋ค์ํ ํ๋ก๊ทธ๋จ ์์ฑ์ฉ ํ๋กฌํํธ ์์ฑ""" |
|
|
|
|
|
|
|
|
problem_description = problem.get('prompt', '') |
|
|
if not problem_description: |
|
|
problem_description = problem.get('description', '') |
|
|
|
|
|
|
|
|
diversity_prompts = [ |
|
|
"Please generate a complete, self-contained Python script that solves the following problem.", |
|
|
"Write a Python solution for this problem using a different approach.", |
|
|
"Create an alternative Python implementation for the given problem.", |
|
|
"Solve this problem with a unique Python solution approach." |
|
|
] |
|
|
|
|
|
base_prompt = diversity_prompts[variation_id % len(diversity_prompts)] |
|
|
|
|
|
prompt = f"""{base_prompt} |
|
|
|
|
|
Problem statement: |
|
|
\"\"\" |
|
|
{problem_description} |
|
|
\"\"\" |
|
|
|
|
|
Please provide a complete solution with proper function implementation.""" |
|
|
|
|
|
return prompt |
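    # Example: with variation_id=1 the rendered prompt looks roughly like:
    #
    #   Write a Python solution for this problem using a different approach.
    #
    #   Problem statement:
    #   """
    #   <problem prompt text>
    #   """
    #
    #   Please provide a complete solution with proper function implementation.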
|
|
|
|
|
def _process_single_program(self, problem: Dict[str, Any], solution: str, variation_id: int) -> Dict[str, Any]: |
|
|
"""๋จ์ผ ํ๋ก๊ทธ๋จ ์ฒ๋ฆฌ (๊ฒ์ฆ + IPO ์ถ์ถ)""" |
|
|
|
|
|
|
|
|
is_valid, syntax_error = self.solution_generator.validate_syntax(solution) |
|
|
|
|
|
program_result = { |
|
|
'variation_id': variation_id, |
|
|
'solution': solution, |
|
|
'syntax_valid': is_valid, |
|
|
'syntax_error': syntax_error, |
|
|
'ipo_triples': [], |
|
|
'num_ipo_triples': 0, |
|
|
'generated_inputs': [], |
|
|
'num_generated_inputs': 0, |
|
|
'input_generation_rounds': 0, |
|
|
'success': False |
|
|
} |
|
|
|
|
|
if is_valid: |
|
|
try: |
|
|
|
|
|
extracted_function_code = self.solution_generator._extract_function_code(solution) |
|
|
ipo_triples = self.ipo_extractor.extract_triples(problem, extracted_function_code) |
|
|
|
|
|
if ipo_triples: |
|
|
program_result['ipo_triples'] = ipo_triples |
|
|
program_result['num_ipo_triples'] = len(ipo_triples) |
|
|
|
|
|
|
|
|
all_generated_inputs = [] |
|
|
input_generation_rounds = getattr(self.config, 'input_generation_rounds', 3) |
|
|
|
|
|
for round_num in range(input_generation_rounds): |
|
|
|
|
|
existing_examples = [(triple['full_input_str'], triple['actual_output']) |
|
|
for triple in ipo_triples] |
|
|
|
|
|
|
|
|
for prev_input in all_generated_inputs: |
|
|
if 'input_args' in prev_input and 'expected_output' in prev_input: |
|
|
existing_examples.append(( |
|
|
str(prev_input['input_args']), |
|
|
str(prev_input['expected_output']) |
|
|
)) |
|
|
|
|
|
|
|
|
diverse_inputs = self.ipo_extractor.generate_diverse_inputs( |
|
|
problem, extracted_function_code, existing_examples |
|
|
) |
|
|
|
|
|
if diverse_inputs: |
|
|
|
|
|
for new_input in diverse_inputs: |
|
|
new_triple = self.ipo_extractor.create_ipo_from_input( |
|
|
problem, extracted_function_code, new_input |
|
|
) |
|
|
if new_triple: |
|
|
ipo_triples.append(new_triple) |
|
|
|
|
|
all_generated_inputs.extend(diverse_inputs) |
|
|
|
|
|
program_result['generated_inputs'] = all_generated_inputs |
|
|
program_result['num_generated_inputs'] = len(all_generated_inputs) |
|
|
program_result['ipo_triples'] = ipo_triples |
|
|
program_result['num_ipo_triples'] = len(ipo_triples) |
|
|
program_result['input_generation_rounds'] = input_generation_rounds |
|
|
program_result['success'] = True |
|
|
|
|
|
except Exception as e: |
|
|
program_result['error'] = str(e) |
|
|
self.logger.log_error(f"IPO extraction failed for variation {variation_id}: {e}") |
|
|
|
|
|
return program_result |
|
|
|
|
|
def _process_single_program_basic(self, problem: Dict[str, Any], solution: str, variation_id: int) -> Dict[str, Any]: |
|
|
"""๋จ์ผ ํ๋ก๊ทธ๋จ ๊ธฐ๋ณธ ์ฒ๋ฆฌ (IPO ์ถ์ถ๋ง, input generation ์ ์ธ)""" |
|
|
|
|
|
|
|
|
is_valid, syntax_error = self.solution_generator.validate_syntax(solution) |
|
|
|
|
|
program_result = { |
|
|
'variation_id': variation_id, |
|
|
'solution': solution, |
|
|
'syntax_valid': is_valid, |
|
|
'syntax_error': syntax_error, |
|
|
'ipo_triples': [], |
|
|
'num_ipo_triples': 0, |
|
|
'all_generated_inputs': [], |
|
|
'num_generated_inputs': 0, |
|
|
'input_generation_rounds': 0, |
|
|
'extracted_function_code': None, |
|
|
'success': False |
|
|
} |
|
|
|
|
|
if is_valid: |
|
|
try: |
|
|
|
|
|
extracted_function_code = self.solution_generator._extract_function_code(solution) |
|
|
program_result['extracted_function_code'] = extracted_function_code |
|
|
ipo_triples = self.ipo_extractor.extract_triples(problem, extracted_function_code) |
|
|
|
|
|
if ipo_triples: |
|
|
program_result['ipo_triples'] = ipo_triples |
|
|
program_result['num_ipo_triples'] = len(ipo_triples) |
|
|
program_result['success'] = True |
|
|
|
|
|
except Exception as e: |
|
|
program_result['error'] = str(e) |
|
|
self.logger.log_error(f"IPO extraction failed for variation {variation_id}: {e}") |
|
|
|
|
|
return program_result |
|
|
|
|
|
|
|
|
def _generate_programs_sequential(self, problem: Dict[str, Any]) -> Dict[str, Any]: |
|
|
"""๋ค์ํ ํ๋ก๊ทธ๋จ ์์ฑ ๋ฐ ๊ฐ๊ฐ์์ IPO ์ถ์ถ""" |
|
|
|
|
|
self.logger.log_info(f"๐จ Generating {self.config.num_program_variations} diverse programs for {problem.get('task_id', 'unknown')}") |
|
|
|
|
|
diverse_results = { |
|
|
'success': False, |
|
|
'total_programs': self.config.num_program_variations, |
|
|
'valid_programs': 0, |
|
|
'programs': [], |
|
|
'total_ipo_triples': 0, |
|
|
'error': None |
|
|
} |
|
|
|
|
|
try: |
|
|
for variation_id in range(self.config.num_program_variations): |
|
|
self.logger.log_info(f" ๐ฏ Generating program variation {variation_id + 1}/{self.config.num_program_variations}") |
|
|
|
|
|
|
|
|
diverse_solution = self.solution_generator.generate_diverse( |
|
|
problem, |
|
|
temperature=self.config.diverse_generation_temperature, |
|
|
variation_id=variation_id |
|
|
) |
|
|
|
|
|
|
|
|
is_valid, syntax_error = self.solution_generator.validate_syntax(diverse_solution) |
|
|
|
|
|
program_result = { |
|
|
'variation_id': variation_id, |
|
|
'solution': diverse_solution, |
|
|
'syntax_valid': is_valid, |
|
|
'syntax_error': syntax_error, |
|
|
'ipo_triples': [], |
|
|
'num_ipo_triples': 0, |
|
|
'generated_inputs': [], |
|
|
'num_generated_inputs': 0 |
|
|
} |
|
|
|
|
|
if is_valid: |
|
|
diverse_results['valid_programs'] += 1 |
|
|
|
|
|
|
|
|
extracted_function_code = self.solution_generator._extract_function_code(diverse_solution) |
|
|
ipo_triples = self.ipo_extractor.extract_triples(problem, extracted_function_code) |
|
|
|
|
|
if ipo_triples: |
|
|
program_result['ipo_triples'] = ipo_triples |
|
|
program_result['num_ipo_triples'] = len(ipo_triples) |
|
|
|
|
|
|
|
|
all_generated_inputs = [] |
|
|
input_generation_rounds = getattr(self.config, 'input_generation_rounds', 3) |
|
|
|
|
|
for round_num in range(input_generation_rounds): |
|
|
self.logger.log_info(f" ๐ฏ Input generation round {round_num + 1}/{input_generation_rounds}") |
|
|
|
|
|
|
|
|
existing_examples = [(triple['full_input_str'], triple['actual_output']) |
|
|
for triple in ipo_triples] |
|
|
|
|
|
|
|
|
for prev_input in all_generated_inputs: |
|
|
if 'input_args' in prev_input and 'expected_output' in prev_input: |
|
|
existing_examples.append(( |
|
|
str(prev_input['input_args']), |
|
|
str(prev_input['expected_output']) |
|
|
)) |
|
|
|
|
|
|
|
|
diverse_inputs = self.ipo_extractor.generate_diverse_inputs( |
|
|
problem, extracted_function_code, existing_examples |
|
|
) |
|
|
|
|
|
if not diverse_inputs: |
|
|
self.logger.log_warning(f" โ ๏ธ Round {round_num + 1}: No valid inputs generated") |
|
|
continue |
|
|
|
|
|
self.logger.log_info(f" โ
Round {round_num + 1}: Generated {len(diverse_inputs)} new inputs") |
|
|
|
|
|
|
|
|
if round_num == 0 and hasattr(self.ipo_extractor, 'last_input_generation_info'): |
|
|
self._save_input_generation_details( |
|
|
problem, |
|
|
variation_id + 1, |
|
|
self.ipo_extractor.last_input_generation_info |
|
|
) |
|
|
|
|
|
|
|
|
round_ipo_count = 0 |
|
|
for new_input in diverse_inputs: |
|
|
new_triple = self.ipo_extractor.create_ipo_from_input( |
|
|
problem, extracted_function_code, new_input |
|
|
) |
|
|
if new_triple: |
|
|
ipo_triples.append(new_triple) |
|
|
round_ipo_count += 1 |
|
|
|
|
|
self.logger.log_info(f" ๐ Round {round_num + 1}: Created {round_ipo_count} IPO triples") |
|
|
all_generated_inputs.extend(diverse_inputs) |
|
|
|
|
|
program_result['generated_inputs'] = all_generated_inputs |
|
|
program_result['input_generation_rounds'] = input_generation_rounds |
|
|
program_result['num_generated_inputs'] = len(diverse_inputs) |
|
|
program_result['num_ipo_triples'] = len(ipo_triples) |
|
|
|
|
|
|
|
|
if hasattr(self.ipo_extractor, 'last_input_generation_info'): |
|
|
program_result['input_generation_info'] = self.ipo_extractor.last_input_generation_info |
|
|
|
|
|
|
|
|
problem_id = problem.get('task_id', 'unknown') |
|
|
program_id = f'program_{variation_id}' |
|
|
for ipo_idx, triple in enumerate(ipo_triples): |
|
|
|
|
|
triple['source_program_id'] = program_id |
|
|
triple['ipo_index'] = ipo_idx |
|
|
self.ipo_buffer.add(problem_id, triple) |
|
|
|
|
|
diverse_results['total_ipo_triples'] += len(ipo_triples) |
|
|
|
|
|
self.logger.log_info(f" โ
Program {variation_id + 1}: {len(ipo_triples)} IPO triples generated") |
|
|
else: |
|
|
self.logger.log_warning(f" โ ๏ธ Program {variation_id + 1}: No IPO triples extracted") |
|
|
else: |
|
|
self.logger.log_warning(f" โ Program {variation_id + 1}: Syntax error - {syntax_error}") |
|
|
|
|
|
diverse_results['programs'].append(program_result) |
|
|
|
|
|
|
|
|
diverse_results['success'] = diverse_results['valid_programs'] > 0 |
|
|
|
|
|
self.logger.log_info(f" ๐ Diverse programs: {diverse_results['valid_programs']}/{diverse_results['total_programs']} valid, {diverse_results['total_ipo_triples']} total IPO triples") |
|
|
|
|
|
except Exception as e: |
|
|
self.logger.log_error(f"โ Diverse program generation failed: {e}") |
|
|
diverse_results['error'] = str(e) |
|
|
|
|
|
return diverse_results |
|
|
|
|
|
    def _save_input_generation_details(self, problem: Dict[str, Any], program_id: int,
                                       input_gen_info: Dict[str, Any]) -> None:
        """Save detailed input-generation information to report files."""
        try:
            problem_id = problem.get('task_id', 'unknown')

            # Keep a copy in memory as well, keyed by problem and program, so the
            # diverse-programs report can attach it later.
            if not hasattr(self, '_input_generation_infos'):
                self._input_generation_infos = {}

            key = f"{problem_id}_program_{program_id}"
            self._input_generation_infos[key] = input_gen_info

            self.logger.log_info(f"Input generation info collected for {problem_id} program {program_id}")

            # Per-program detail file; the directory follows the same
            # log_dir convention as the diverse-programs report.
            input_gen_dir = Path(self.logger.log_dir) / "input_generation"
            input_gen_dir.mkdir(parents=True, exist_ok=True)

            detail_file = input_gen_dir / f"program_{program_id}_details.txt"

            with open(detail_file, 'w', encoding='utf-8') as f:
                f.write("Input Generation Details\n")
                f.write(f"Problem ID: {problem_id}\n")
                f.write(f"Program ID: {program_id}\n")
                f.write(f"Generated: {self.timestamp}\n")
                f.write("=" * 80 + "\n\n")

                f.write("1. FUNCTION INFO:\n")
                f.write("=" * 80 + "\n")
                func_info = input_gen_info.get('function_info', {})
                f.write(f"Function Name: {func_info.get('name', 'N/A')}\n")
                f.write(f"Parameters: {func_info.get('args', 'N/A')}\n")
                f.write(f"Return Type: {func_info.get('return_type', 'N/A')}\n\n")

                f.write("2. ARGUMENT TYPE INFO:\n")
                f.write("=" * 80 + "\n")
                f.write(input_gen_info.get('arg_type_info', 'N/A') + "\n\n")

                f.write("3. EXISTING EXAMPLES:\n")
                f.write("=" * 80 + "\n")
                for i, (inp, out) in enumerate(input_gen_info.get('existing_examples', [])):
                    f.write(f"Example {i+1}: Input: {inp} → Output: {out}\n")
                f.write("\n")

                f.write("4. LLM PROMPT:\n")
                f.write("=" * 80 + "\n")
                f.write(input_gen_info.get('prompt', 'N/A') + "\n\n")

                f.write("5. LLM RESPONSE:\n")
                f.write("=" * 80 + "\n")
                f.write(input_gen_info.get('llm_response', 'N/A') + "\n\n")

                f.write("6. EXTRACTED INPUTS:\n")
                f.write("=" * 80 + "\n")
                extracted = input_gen_info.get('extracted_inputs', [])
                if extracted:
                    for i, inp_data in enumerate(extracted):
                        f.write(f"Input {i+1}: {inp_data}\n")
                else:
                    f.write("No inputs extracted\n")

            # Append a one-line entry to the running summary for this problem.
            summary_file = input_gen_dir / "input_generation_summary.txt"
            mode = 'a' if summary_file.exists() else 'w'

            with open(summary_file, mode, encoding='utf-8') as f:
                if mode == 'w':
                    f.write("Input Generation Summary\n")
                    f.write(f"Problem ID: {problem_id}\n")
                    f.write(f"Generated: {self.timestamp}\n")
                    f.write("=" * 80 + "\n\n")

                f.write(f"Program {program_id}: {len(extracted)} inputs generated\n")

        except Exception as e:
            self.logger.log_warning(f"Failed to save input generation details: {e}")
    def _save_diverse_programs_evaluation(self, problem: Dict[str, Any],
                                          diverse_results: Dict[str, Any]) -> None:
        """Save evaluation results for the diverse programs (same format as batch evaluation)."""
        try:
            problem_id = problem.get('task_id', 'unknown')

            diverse_dir = os.path.join(self.logger.log_dir, "diverse_programs")
            os.makedirs(diverse_dir, exist_ok=True)

            # Top-level summary: one validity/IPO block per generated program.
            summary_file = os.path.join(diverse_dir, "diverse_summary.txt")
            with open(summary_file, 'w', encoding='utf-8') as f:
                f.write("Diverse Programs Summary\n")
                f.write(f"Problem ID: {problem_id}\n")
                f.write(f"Generated: {datetime.now().strftime('%Y%m%d_%H%M%S')}\n")
                f.write(f"Total Programs: {len(diverse_results.get('programs', []))}\n")
                f.write("=" * 50 + "\n\n")

                for i, program in enumerate(diverse_results.get('programs', []), 1):
                    f.write(f"Program {i}: {'✅ Valid' if program.get('syntax_valid', False) else '❌ Invalid'}\n")
                    f.write(f"IPO Triples: {program.get('num_ipo_triples', 0)}\n")
                    f.write(f"Generated Inputs: {program.get('num_generated_inputs', 0)}\n\n")
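
            # Per-program artifacts are written next. Illustrative layout of the
            # resulting directory tree (not a fixed schema):
            #
            #   diverse_programs/
            #       diverse_summary.txt
            #       program_<i>/
            #           generation_details.txt        # problem, prompt, response, verdict
            #           solution.py                   # generated program + header
            #           ipo_triples/triple_<j>.json   # extracted IPO triples
            #           input_generation_details.txt  # per-round input-gen trace
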
            for i, program in enumerate(diverse_results.get('programs', []), 1):
                program_dir = os.path.join(diverse_dir, f"program_{i}")
                os.makedirs(program_dir, exist_ok=True)

                # Generation details: original problem, diversity prompt,
                # raw LLM response, and the evaluation verdict.
                details_file = os.path.join(program_dir, "generation_details.txt")
                with open(details_file, 'w', encoding='utf-8') as f:
                    f.write(f"Diverse Program {i} - Generation Details\n")
                    f.write(f"Problem ID: {problem_id}\n")
                    f.write(f"Generated: {datetime.now().strftime('%Y%m%d_%H%M%S')}\n")
                    f.write("=" * 80 + "\n\n")

                    f.write("1. ORIGINAL PROBLEM:\n")
                    f.write("=" * 80 + "\n")
                    f.write(problem.get('prompt', 'No prompt available') + "\n")
                    f.write("=" * 80 + "\n\n")

                    f.write("2. DIVERSITY PROMPT USED:\n")
                    f.write("=" * 80 + "\n")
                    f.write(program.get('diversity_instruction', 'Standard generation') + "\n")
                    f.write("=" * 80 + "\n\n")

                    f.write("3. LLM RESPONSE:\n")
                    f.write("=" * 80 + "\n")
                    f.write(program.get('solution', 'N/A') + "\n")
                    f.write("=" * 80 + "\n\n")

                    f.write("4. EVALUATION RESULTS:\n")
                    f.write("=" * 80 + "\n")
                    f.write(f"Syntax Valid: {'✅ YES' if program.get('syntax_valid', False) else '❌ NO'}\n")
                    f.write(f"IPO Triples Generated: {program.get('num_ipo_triples', 0)}\n")
                    f.write(f"Input Generation: {program.get('num_generated_inputs', 0)} new inputs\n")
                    f.write("=" * 80 + "\n")

                # The generated program itself, with a short metadata header.
                solution_file = os.path.join(program_dir, "solution.py")
                with open(solution_file, 'w', encoding='utf-8') as f:
                    f.write(f"# Diverse Program {i}\n")
                    f.write(f"# Problem ID: {problem_id}\n")
                    f.write(f"# Generated: {datetime.now().strftime('%Y%m%d_%H%M%S')}\n")
                    f.write(f"# Syntax Valid: {program.get('syntax_valid', False)}\n")
                    f.write(f"# IPO Triples: {program.get('num_ipo_triples', 0)}\n\n")
                    f.write(program.get('solution', '# No solution available'))

                # Each extracted IPO triple goes to its own JSON file.
                if program.get('ipo_triples'):
                    ipo_dir = os.path.join(program_dir, "ipo_triples")
                    os.makedirs(ipo_dir, exist_ok=True)

                    for j, triple in enumerate(program['ipo_triples'], 1):
                        triple_file = os.path.join(ipo_dir, f"triple_{j}.json")
                        with open(triple_file, 'w', encoding='utf-8') as f:
                            json.dump(triple, f, indent=2)

                # Per-round input-generation trace, if the extractor recorded one.
                if program.get('input_generation_info'):
                    details_file = os.path.join(program_dir, "input_generation_details.txt")
                    with open(details_file, 'w', encoding='utf-8') as f:
                        f.write(f"Input Generation Details - Program {i}\n")
                        f.write(f"Problem ID: {problem_id}\n")
                        f.write(f"Generated: {datetime.now().strftime('%Y%m%d_%H%M%S')}\n")
                        f.write("=" * 80 + "\n\n")

                        for round_idx, gen_info in enumerate(program['input_generation_info'], 1):
                            if 'error' in gen_info:
                                f.write(f"ROUND {round_idx} - ERROR:\n")
                                f.write("=" * 80 + "\n")
                                f.write(f"Error: {gen_info.get('error', 'Unknown error')}\n")
                                f.write(f"Traceback:\n{gen_info.get('traceback', 'No traceback')}\n")
                                f.write("\n")
                                continue

                            f.write(f"ROUND {round_idx}:\n")
                            f.write("=" * 80 + "\n\n")

                            func_info = gen_info.get('function_info', {})
                            f.write("1. FUNCTION INFO:\n")
                            f.write("=" * 80 + "\n")
                            f.write(f"Function Name: {func_info.get('name', 'N/A')}\n")
                            f.write(f"Parameters: {func_info.get('args', 'N/A')}\n")
                            f.write(f"Parameters String: {func_info.get('signature', 'N/A')}\n\n")

                            f.write("2. ARGUMENT TYPE INFO:\n")
                            f.write("=" * 80 + "\n")
                            arg_types = gen_info.get('arg_type_info', {})
                            if arg_types:
                                f.write("Argument types:\n")
                                for arg, arg_type in arg_types.items():
                                    f.write(f"- {arg}: {arg_type}\n")
                            else:
                                f.write("No argument type information available\n")
                            f.write("\n")

                            f.write("3. EXISTING EXAMPLES:\n")
                            f.write("=" * 80 + "\n")
                            existing_examples = gen_info.get('existing_examples', [])
                            if existing_examples:
                                for idx, example in enumerate(existing_examples, 1):
                                    f.write(f"Example {idx}: Input: {example[0]} → Output: {example[1]}\n")
                            else:
                                f.write("No existing examples\n")
                            f.write("\n")

                            f.write("4. LLM PROMPT:\n")
                            f.write("=" * 80 + "\n")
                            f.write(gen_info.get('prompt', 'No prompt available'))
                            f.write("\n" * 2 + "=" * 80 + "\n\n")

                            f.write("5. LLM RESPONSE:\n")
                            f.write("=" * 80 + "\n")
                            f.write(gen_info.get('llm_response', 'No response available'))
                            f.write("\n" * 2 + "=" * 80 + "\n\n")

                            f.write("6. EXTRACTED INPUTS:\n")
                            f.write("=" * 80 + "\n")
                            extracted = gen_info.get('extracted_inputs', [])
                            if extracted:
                                for idx, inp in enumerate(extracted, 1):
                                    f.write(f"Input {idx}: {inp}\n")
                            else:
                                f.write("No inputs extracted\n")

                            # Only list validated inputs separately when validation
                            # actually filtered something out.
                            valid = gen_info.get('valid_inputs', [])
                            if valid != extracted:
                                f.write("\n7. VALID INPUTS (after validation):\n")
                                f.write("=" * 80 + "\n")
                                if valid:
                                    for idx, inp in enumerate(valid, 1):
                                        f.write(f"Input {idx}: {inp}\n")
                                else:
                                    f.write("No valid inputs after validation\n")

                            f.write("\n")
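
            # Each triple_<j>.json written above carries at least the 'input',
            # 'output', and 'program' fields, plus the 'source_program_id' and
            # 'ipo_index' tags added when the triple was buffered.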
            self.logger.log_info(f"💾 Diverse programs saved to {diverse_dir} (batch evaluation format)")

        except Exception as e:
            self.logger.log_error(f"Failed to save diverse programs evaluation: {e}")

    def _save_azr_training_data(self, all_tasks: Dict[str, List[Dict[str, Any]]],
                                problem_id: str, round_num: int,
                                output_dir: str) -> Dict[str, str]:
        """Save the generated tasks as AZR training data in parquet format."""
        try:
            import pandas as pd

            azr_dir = os.path.join(output_dir, 'azr_training_data')
            os.makedirs(azr_dir, exist_ok=True)

            saved_files = {}
            total_tasks = 0

            # One parquet file per task type.
            for task_type, tasks in all_tasks.items():
                if not tasks:
                    continue

                azr_data = []
                for task in tasks:
                    # VeRL expects chat-style prompts: a list of role/content
                    # dicts rather than a raw prompt string.
                    prompt_dict_list = [{"role": "user", "content": task['prompt']}]
                    print("[DEBUG AZR DATA SAVE] Converting prompt to dict list format")
                    print(f"[DEBUG] Original prompt type: {type(task['prompt'])}, length: {len(task['prompt']) if isinstance(task['prompt'], str) else 'N/A'}")
                    print(f"[DEBUG] Converted prompt type: {type(prompt_dict_list)}, first elem: {type(prompt_dict_list[0])}")

                    azr_record = {
                        'prompt': prompt_dict_list,
                        'uid': task['uid'],
                        'ipo_group_id': task['ipo_group_id'],
                        'source_program_id': task['source_program_id'],
                        'ipo_index': task['ipo_index'],
                        'problem': {
                            'input': task['ipo_triple']['input'],
                            'output': task['ipo_triple']['output'],
                            'snippet': task['ipo_triple']['program']
                        },
                        'ground_truth': task['ground_truth'],
                        'extra_info': task['extra_info'],
                        'basic_accuracy': task['basic_accuracy'],
                        'original_problem_id': task['original_problem_id'],
                        'round': task['round']
                    }
                    azr_data.append(azr_record)

                # Keep records from the same IPO group adjacent so batches stay aligned.
                azr_data.sort(key=lambda x: x['ipo_group_id'])

                df = pd.DataFrame(azr_data)
                file_path = os.path.join(azr_dir, f'{task_type}.parquet')
                df.to_parquet(file_path, index=False)
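
                # Note: 'prompt' and 'problem' hold nested lists/dicts; when the
                # parquet is written (pyarrow engine), these are typically stored
                # as list/struct columns and come back as object columns on read.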
                print(f"[DEBUG] Saved {task_type}.parquet with {len(df)} records")
                if len(df) > 0:
                    saved_prompt = df.iloc[0]['prompt']
                    print(f"[DEBUG] First saved prompt type: {type(saved_prompt)}")

                saved_files[task_type] = file_path
                total_tasks += len(tasks)

                self.logger.log_info(f"💾 Saved {len(tasks)} {task_type} tasks to {file_path}")

            # Aggregate statistics, written alongside the parquet files.
            stats = {
                'problem_id': problem_id,
                'round': round_num,
                'total_tasks': total_tasks,
                'tasks_by_type': {k: len(v) for k, v in all_tasks.items()},
                'files': saved_files,
                'batch_groups': len(set(task['ipo_group_id'] for tasks in all_tasks.values() for task in tasks))
            }

            stats_file = os.path.join(azr_dir, 'training_stats.json')
            with open(stats_file, 'w') as f:
                json.dump(stats, f, indent=2)

            self.logger.log_info(f"✅ AZR training data saved: {total_tasks} tasks in {len(saved_files)} files")
            self.logger.log_info(f"📊 Batch groups: {stats['batch_groups']} (for batch alignment)")

            return saved_files

        except Exception as e:
            self.logger.log_error(f"Failed to save AZR training data: {e}")
            return {}
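
    # A quick sanity check for the saved files (illustrative sketch; assumes
    # pandas with a parquet engine such as pyarrow, and uses a hypothetical
    # <output_dir>/<task_type> placeholder):
    #
    #     import pandas as pd
    #     df = pd.read_parquet('<output_dir>/azr_training_data/<task_type>.parquet')
    #     assert isinstance(df.iloc[0]['prompt'][0], dict)   # chat-format prompt
    #     print(df.iloc[0]['problem'])                       # input/output/snippet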