| import os |
| |
| |
| os.environ["VLLM_NO_USAGE_STATS"] = "1" |
| import json |
| import re |
| import sys |
| import gc |
| import random |
| import argparse |
| import traceback |
| from datetime import datetime |
| from typing import List, Dict, Optional, Any |
| from multiprocessing import Process, Queue |
|
|
| import torch |
| import numpy as np |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| from smpeft import PeftModel |
| from vllm import LLM, SamplingParams, EngineArgs |
| from datasets import load_dataset |
| import draccus |
| from tqdm import tqdm |
|
|
| try: |
| from inference_math.grader import math_equal |
| except ImportError: |
| raise ValueError("[Warning] 'grader.py' not found. GSM8k evaluation might fail.") |
|
|
| try: |
| from inference_math import util |
| except ImportError: |
| raise ValueError("[Warning] 'util.py' not found. MATH evaluation might fail.") |
|
|
| from .config import MainConfig |
| from .utils import set_seed_all |
|
|
# Generation budget per prompt; passed as ``max_tokens`` to the vLLM
# SamplingParams built in main().
MAX_NEW_TOKENS = 1536
# Largest native index value on this platform. Not referenced elsewhere in the
# visible portion of this file.
MAX_INT = sys.maxsize


# Instruction/response prompt. ``{instruction}`` is filled in with str.format
# by format_prompt(); the trailing sentence primes step-by-step reasoning.
PROMPT_TEMPLATE = (
    "Below is an instruction that describes a task. Write a response that "
    "appropriately completes the request.\n\n"
    "### Instruction:\n"
    "{instruction}\n\n"
    "### Response: Let's think step by step."
)
|
|
| |
| from fraction import Fraction |
|
|
def is_number(s):
    """Report whether ``s`` can be read as a number.

    ``s`` counts as numeric when ``float(s)`` accepts it. If ``float`` rejects
    it with ``ValueError``, ``unicodedata.numeric(s)`` is tried as a fallback,
    which recognises single numeric Unicode characters.

    Args:
        s: Value to test (a string in every call site of this file).

    Returns:
        ``True`` if either conversion succeeds, otherwise ``False``.

    Raises:
        TypeError: Propagated from ``float`` for inputs it cannot handle at
            all (e.g. ``None``); only ``ValueError`` triggers the fallback.
    """
    import unicodedata

    try:
        float(s)
    except ValueError:
        # Not a float literal: fall back to the Unicode numeric property.
        try:
            unicodedata.numeric(s)
        except (TypeError, ValueError):
            return False
    return True
|
|
def extract_answer_number_gsm8k(completion):
    """Extract the final numeric answer from a GSM8K-style completion.

    The text after the last ``'The answer is: '`` marker is searched for the
    first number-like token (optional sign, digits, and at most one ``.``,
    ``,`` or ``/`` separator). The token is converted to a number and rounded
    to the nearest integer.

    NOTE: the pattern allows a single separator only, so ``"1,234,567"`` is
    read as ``1234``. The pattern is deliberately left unchanged so scores
    stay comparable with the reference ``gsm8k_infer.py`` logic this was
    taken from.

    Args:
        completion: Full generated text.

    Returns:
        The rounded integer answer, or ``None`` when the marker is missing, no
        number follows it, a fraction has a non-numeric part, or the value is
        too large in magnitude to be represented as a finite float.
    """
    parts = completion.split('The answer is: ')
    if len(parts) <= 1:
        return None

    match = re.search(r'[\-+]?\d*[\.,/]?\d+', parts[-1].strip())
    if match is None:
        return None
    token = match.group()

    if '/' in token:
        # The pattern admits one separator at most, so a token containing '/'
        # has exactly one slash and no ',' or '.'.
        numerator, _, denominator = token.partition('/')
        if not (is_number(numerator) and is_number(denominator)):
            return None
        # Compare numerically rather than against the literal '0' so that
        # spellings such as "00" also take the zero-denominator fallback
        # instead of being handed to Fraction.
        if float(denominator) == 0:
            return round(float(numerator))
        frac = Fraction(token)
        return round(float(frac.numerator / frac.denominator))

    value = float(token.replace(',', ''))
    # round() raises OverflowError on either infinity; a very long digit
    # string can overflow to +inf or, with a leading '-', to -inf.
    if abs(value) == float('inf'):
        return None
    return round(value)
| |
| |
def remove_boxed(s):
    r"""Return the contents of a string of the form ``\boxed{...}``.

    Args:
        s: Candidate string. Anything that is not a ``str`` is rejected.

    Returns:
        The text between the leading ``\boxed{`` and the final ``}``, or
        ``None`` when ``s`` is not a string or does not have that exact shape.
    """
    left = "\\boxed{"
    # Explicit checks instead of assert + bare except: asserts are stripped
    # under ``python -O``, which would make every input look "boxed".
    if not isinstance(s, str):
        return None
    if not s.startswith(left) or not s.endswith("}"):
        return None
    return s[len(left):-1]
| |
def process_results_math(completion, answer):
    """Grade a MATH-style completion against a reference answer.

    The candidate answer is the text after the last ``'The answer is: '``
    marker, cut at the first ``'.\\n'``, stripped, with one trailing period
    removed. It is compared to ``answer`` with ``util.is_equiv``.

    Args:
        completion: Full generated text.
        answer: Reference answer passed through to ``util.is_equiv``.

    Returns:
        ``(is_correct, extracted_answer)``. When the marker is absent the
        result is ``(False, None)``.
    """
    pieces = completion.split('The answer is: ')
    if len(pieces) <= 1:
        return False, None

    candidate = pieces[-1].split('.\n')[0].strip()
    if candidate.endswith('.'):
        # Drop a single sentence-ending period.
        candidate = candidate[:-1]
    candidate = candidate.strip()

    matched = True if util.is_equiv(candidate, answer) else False
    return matched, candidate
| |
| |
def format_prompt(examples):
    """Render a batch of instructions into model prompts.

    Args:
        examples: Dict-like batch. The list under ``'question'`` is used when
            present, otherwise the list under ``'instruction'``, otherwise an
            empty list.

    Returns:
        ``{"prompt": [...]}`` with one ``PROMPT_TEMPLATE`` rendering per
        instruction, in input order.
    """
    fallback = examples.get('instruction', [])
    questions = examples.get('question', fallback)
    return {
        "prompt": [PROMPT_TEMPLATE.format(instruction=q) for q in questions]
    }
|
|
| |
def merge_process(queue, mainCfg: MainConfig, force_to_merge: bool = False):
    """Merge a PEFT adapter into its base model and report where it was saved.

    Used as the target of a child ``Process`` in ``main``. The base model is
    loaded on CPU, the adapter found under ``<adapter dir>/ft2`` is merged in,
    and the merged model plus tokenizer are written to ``<output dir>/merge``.
    The merge is skipped when that directory already contains ``.bin`` or
    ``.safetensors`` files, unless ``force_to_merge`` is set.

    Args:
        queue: Object with a ``put`` method used to hand the result back.
        mainCfg: Run configuration; ``mainCfg.model`` supplies ``model_name``,
            ``merge_adapter_path``, ``adapter_path`` and ``merge_output_path``.
        force_to_merge: Redo the merge even if merged weights already exist.

    Result (delivered through ``queue``, nothing is returned):
        On success a tuple ``(merged_model_dir, results_root_dir)``.
        On any exception the formatted traceback as a ``str``.
    """
    try:
        model_cfg = mainCfg.model
        base_name = model_cfg.model_name

        # ``merge_adapter_path`` takes precedence over ``adapter_path``.
        if model_cfg.merge_adapter_path is not None:
            adapter = model_cfg.merge_adapter_path + "/ft2"
            print(f'Merging from merge_adapter_path: {adapter}')
        elif model_cfg.adapter_path is not None:
            adapter = model_cfg.adapter_path + "/ft2"
            print(f'Merging from adapter_path: {adapter}')
        else:
            raise KeyError('No adapter path provided in config.')

        # Results live next to the merged weights: prefer the explicit output
        # location and fall back to the adapter directory.
        out_json = model_cfg.merge_output_path
        if out_json is None:
            out_json = model_cfg.adapter_path
        output_path = os.path.join(out_json, "merge")

        has_weights = os.path.exists(output_path) and any(
            entry.endswith((".bin", ".safetensors"))
            for entry in os.listdir(output_path)
        )

        if has_weights and not force_to_merge:
            print("Merged weights found. Skipping merge step.")
        else:
            print(f"Loading base model: {base_name}")
            model = AutoModelForCausalLM.from_pretrained(
                base_name, device_map="cpu", low_cpu_mem_usage=True
            )
            tokenizer = AutoTokenizer.from_pretrained(base_name, device_map='cpu')

            print(f"Loading adapter: {adapter}")
            model = PeftModel.from_pretrained(model, adapter)
            print("Merging model...")
            model = model.merge_and_unload()

            print(f"Saving merged model to: {output_path}")
            model.save_pretrained(
                output_path, safe_serialization=True, max_shard_size="10GB"
            )
            tokenizer.save_pretrained(output_path)

            # Release the weights before the result is handed back.
            del model, tokenizer
            gc.collect()
            torch.cuda.empty_cache()
            print('Merge complete.')

        queue.put((output_path, out_json))

    except Exception as e:
        # A plain string on the queue is the failure signal main() checks for.
        error_msg = traceback.format_exc()
        print(error_msg)
        queue.put(error_msg)
        print(f"Error in merge_process: {e}")
|
|
| |
def _grade_gsm8k(prediction_text, gt_raw):
    """Grade one GSM8K completion.

    Returns ``(is_correct, extracted_pred, clean_gt, is_invalid)``.
    """
    # Reference answers look like "... #### 1,234"; otherwise the raw value is
    # treated as the number itself. Unparseable references are kept as-is.
    try:
        gt_text = str(gt_raw)
        if '####' in gt_text:
            clean_gt_str = gt_text.split('#### ')[1].replace(',', '').strip()
        else:
            clean_gt_str = gt_text.replace(',', '').strip()
        clean_gt = float(clean_gt_str)
    except (IndexError, ValueError):
        clean_gt = gt_raw

    extracted_pred = extract_answer_number_gsm8k(prediction_text)
    if extracted_pred is None:
        return False, None, clean_gt, True

    try:
        is_correct = (
            float(extracted_pred) == float(clean_gt)
            or math_equal(extracted_pred, clean_gt)
        )
    except Exception:
        # Best effort: an ungradable pair counts as wrong, but interrupts
        # (KeyboardInterrupt / SystemExit) are no longer swallowed.
        is_correct = False
    return is_correct, extracted_pred, clean_gt, False


def _grade_math(prediction_text, gt_raw):
    """Grade one MATH completion.

    Returns ``(is_correct, extracted_pred, clean_gt, is_invalid)``.
    """
    try:
        clean_gt = remove_boxed(util.last_boxed_only_string(str(gt_raw)))
    except Exception:
        clean_gt = gt_raw

    is_correct, extracted_pred = process_results_math(prediction_text, clean_gt)
    # No (or empty) extracted answer on a wrong sample counts as invalid.
    is_invalid = not extracted_pred and not is_correct
    return is_correct, extracted_pred, clean_gt, is_invalid


def score_outputs(outputs, test_target_name, ground_truths, out_json):
    """Score generated completions for one dataset and write a JSON report.

    The grader is chosen from the dataset name (case-insensitive): names
    containing ``gsm8k`` use the GSM8K numeric grader, otherwise names
    containing ``math`` use the MATH grader. For any other name every sample
    is recorded as incorrect and a warning is printed.

    Args:
        outputs: Generation results; each item must expose
            ``.outputs[0].text``.
        test_target_name: Dataset name; also the report file stem.
        ground_truths: Reference answers, indexable in step with ``outputs``.
        out_json: Directory the report ``<test_target_name>.json`` is written
            to (created if missing).

    Returns:
        Accuracy in percent over ``len(ground_truths)`` samples (``0`` when
        there are none).
    """
    results = []
    total_correct = 0
    total_samples = len(ground_truths)
    invalid_count = 0

    print(f"Calculating scores for {test_target_name}...")

    name_lower = test_target_name.lower()
    if 'gsm8k' in name_lower:
        grade = _grade_gsm8k
    elif 'math' in name_lower:
        grade = _grade_math
    else:
        grade = None
        print(
            f"[Warning] No grader matches dataset '{test_target_name}'; "
            "all samples will be recorded as incorrect."
        )

    for i, output in enumerate(tqdm(outputs, desc="Scoring")):
        prediction_text = output.outputs[0].text
        gt_raw = ground_truths[i]

        if grade is None:
            is_correct, extracted_pred, clean_gt, is_invalid = False, None, None, False
        else:
            is_correct, extracted_pred, clean_gt, is_invalid = grade(prediction_text, gt_raw)

        if is_invalid:
            invalid_count += 1

        results.append({
            "id": i,
            "prediction_full": prediction_text,
            "extracted_pred": extracted_pred,
            "ground_truth_raw": gt_raw,
            "ground_truth_clean": clean_gt,
            "is_correct": is_correct,
        })

        if is_correct:
            total_correct += 1

    avg_acc = 100.0 * total_correct / total_samples if total_samples > 0 else 0

    print("\n" + "="*40)
    print(f"FINAL RESULTS: {test_target_name}")
    print("="*40)
    print(f"Total Samples: {total_samples}")
    print(f"Invalid/No Answer Found: {invalid_count}")
    print(f"Accuracy: {avg_acc:.2f}%")
    print("="*40)

    os.makedirs(out_json, exist_ok=True)
    save_file = os.path.join(out_json, f'{test_target_name}.json')
    with open(save_file, "w", encoding="utf-8") as f:
        json.dump({
            "metrics": {
                "accuracy": avg_acc,
                "total": total_samples,
                "invalid": invalid_count,
            },
            "details": results,
        }, f, indent=2, ensure_ascii=False)

    return avg_acc
|
|
@draccus.wrap()
def main(mainCfg: MainConfig):
    """Merge the adapter, run vLLM inference on each dataset, and save scores.

    Steps:
      1. Run ``merge_process`` in a child process and wait for its result.
      2. Load the merged model directory into vLLM.
      3. For every name in ``mainCfg.infer.datasets``: load
         ``./dataset/<name>/test.jsonl``, build prompts, generate with
         temperature 0, and score with ``score_outputs``. Datasets whose file
         is missing are skipped.
      4. Write per-dataset accuracies and their mean to
         ``<results dir>/FINAL.json``.

    An exception inside the dataset loop is printed and ends the loop; the
    results gathered up to that point are still written.

    Args:
        mainCfg: Run configuration (``seed``, ``model.*``, ``infer.*``).

    Raises:
        RuntimeError: The merge step reported an error or its process exited
            without reporting anything.
        FileNotFoundError: The merged model directory does not exist.
    """
    # Raised by multiprocessing.Queue.get() when the timeout elapses.
    from queue import Empty

    print('='*120)
    set_seed_all(mainCfg.seed)

    result_queue = Queue()
    p = Process(target=merge_process, args=(result_queue, mainCfg, False))
    p.start()

    # Read the result BEFORE join() so a large payload cannot block the child,
    # but poll with a timeout: if the child is killed (e.g. out of memory
    # while loading the base model) it never puts anything, and a plain
    # blocking get() would hang forever.
    merge_result = None
    while True:
        try:
            merge_result = result_queue.get(timeout=1.0)
            break
        except Empty:
            if p.is_alive():
                continue
            # The child may have queued its result just before exiting.
            try:
                merge_result = result_queue.get(timeout=1.0)
            except Empty:
                print(f"Merge process exited with code {p.exitcode} without reporting a result.")
            break
    p.join()

    if merge_result is None or isinstance(merge_result, str):
        raise RuntimeError(f"Model merging failed: {merge_result}")

    model_path, out_json = merge_result

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model directory does not exist: {model_path}")

    print(f"Verified model path: {os.path.abspath(model_path)}")
    out_json = os.path.join(out_json, "results")
    print('Output JSON path: ', out_json)

    print("Initializing vLLM...")
    llm = LLM(
        model=model_path,
        dtype="bfloat16",
        gpu_memory_utilization=0.9,
        max_model_len=mainCfg.infer.infer_max_seq_length,
        tensor_parallel_size=1,
    )

    # Stop as soon as the model starts echoing the prompt scaffolding.
    stop_tokens = ["Instruction:", "Instruction", "Response:", "Response",]
    sampling_params = SamplingParams(
        temperature=0,
        top_p=1,
        max_tokens=MAX_NEW_TOKENS,
        stop=stop_tokens
    )

    start_time_total = datetime.now()
    final_res = {}
    all_task_acc = []

    try:
        for test_target_name in mainCfg.infer.datasets:
            print(f"Processing dataset: {test_target_name}")

            dataset_path = f'./dataset/{test_target_name}/test.jsonl'
            if not os.path.exists(dataset_path):
                print(f"[Error] Local file not found: {dataset_path}. Skipping.")
                continue

            print(f"Loading local file: {dataset_path}")
            test_dataset = load_dataset("json", data_files=dataset_path, split='train')

            # Normalise column names so both dataset layouts expose
            # 'question' / 'answer'.
            if 'instruction' in test_dataset.column_names:
                test_dataset = test_dataset.rename_column('instruction', 'question')
            if 'output' in test_dataset.column_names:
                test_dataset = test_dataset.rename_column('output', 'answer')

            ground_truths = test_dataset['answer']

            print("Formatting prompts...")
            test_dataset = test_dataset.map(
                format_prompt,
                batched=True,
                batch_size=1000,
                desc="Formatting prompts"
            )
            prompts = test_dataset['prompt']

            print(f"Generating responses for {len(prompts)} samples...")
            start_time_task = datetime.now()

            outputs = llm.generate(prompts, sampling_params)

            end_time_task = datetime.now()
            print(f"Task {test_target_name} duration: {end_time_task - start_time_task}")

            avg_acc = score_outputs(
                outputs=outputs,
                test_target_name=test_target_name,
                ground_truths=ground_truths,
                out_json=out_json
            )

            final_res[test_target_name] = avg_acc
            all_task_acc.append(avg_acc)

            # Drop per-dataset data before the next dataset is loaded.
            del prompts
            del outputs
            del test_dataset

    except Exception as e:
        # Best effort: keep whatever was scored so far and still write FINAL.json.
        print(f"Error during evaluation loop: {e}")
        traceback.print_exc()

    print('Accuracies per task:', all_task_acc)
    if all_task_acc:
        avg_score = sum(all_task_acc) / len(all_task_acc)
    else:
        avg_score = 0.0

    final_res['average_score'] = avg_score

    os.makedirs(out_json, exist_ok=True)
    save_file = os.path.join(out_json, 'FINAL.json')

    with open(save_file, "w", encoding="utf-8") as f:
        json.dump(final_res, f, indent=2, ensure_ascii=False)

    print(f"All Results saved to {save_file}, Overall Score: {avg_score:.2f}")

    end_time_total = datetime.now()
    print(f"Total execution time: {end_time_total - start_time_total}")
|
|
# Script entry point. main() is called without arguments because it is wrapped
# by @draccus.wrap(), which is expected to supply ``mainCfg``.
if __name__ == "__main__":
    main()