File size: 15,420 Bytes
ecadbd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
import os
# os.environ["VLLM_USE_V1"] = "0"
# os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
os.environ["VLLM_NO_USAGE_STATS"] = "1"
import json
import re
import sys
import gc
import random
import argparse
import traceback
from datetime import datetime
from typing import List, Dict, Optional, Any
from multiprocessing import Process, Queue

import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from smpeft import PeftModel
from vllm import LLM, SamplingParams, EngineArgs
from datasets import load_dataset
import draccus
from tqdm import tqdm

try:
    from inference_math.grader import math_equal
except ImportError:
    raise ValueError("[Warning] 'grader.py' not found. GSM8k evaluation might fail.")

try:
    from inference_math import util
except ImportError:
    raise ValueError("[Warning] 'util.py' not found. MATH evaluation might fail.")

from .config import MainConfig 
from .utils import set_seed_all

# Generation budget per sample; used as SamplingParams.max_tokens in main().
MAX_NEW_TOKENS = 1536
MAX_INT = sys.maxsize

# Chain-of-thought prompt for math problems. "{instruction}" is filled with the
# problem text; the response header already starts the step-by-step reasoning.
PROMPT_TEMPLATE = "".join([
    "Below is an instruction that describes a task. ",
    "Write a response that appropriately completes the request.\n\n",
    "### Instruction:\n{instruction}\n\n",
    "### Response: Let's think step by step.",
])

# --- Helper Functions for GSM8k ---
from fraction import Fraction

def is_number(s):
    """
    Return True if ``s`` can be interpreted as a number.

    Accepts anything ``float()`` parses (e.g. "3", "-2.5", "1e3") and single
    Unicode characters that carry a numeric value (e.g. "½").

    Returns False (never raises) for other strings and for objects that
    neither ``float()`` nor ``unicodedata.numeric()`` accept, such as None.
    """
    import unicodedata

    try:
        float(s)
        return True
    except (TypeError, ValueError):
        # TypeError: non-string/non-number input such as None.
        pass
    try:
        unicodedata.numeric(s)
        return True
    except (TypeError, ValueError):
        return False

def extract_answer_number_gsm8k(completion):
    """
    Extract the integer answer that follows the last 'The answer is: ' marker.

    Follows the gsm8k_infer.py extraction rules:
      - the first number-like token after the marker is used;
      - a single comma is treated as a thousands separator ("1,234" -> 1234);
      - "a/b" is evaluated as a fraction;
      - "a/0" falls back to the numerator;
      - the value is rounded with Python's round().

    Unlike a literal port, malformed tokens never raise. An all-zero
    denominator such as "5/00" is handled like "5/0", and values that
    overflow a float (positive or negative) yield None.

    Args:
        completion: Generated model text.

    Returns:
        The rounded integer answer, or None when there is no marker, no
        number-like token, a non-numeric fraction side, or an overflow.
    """
    import math

    segments = completion.split('The answer is: ')
    if len(segments) < 2:
        return None
    match = re.search(r'[\-+]?\d*[\.,/]?\d+', segments[-1].strip())
    if match is None:
        return None
    token = match.group()

    if '/' in token:
        numerator, denominator = token.split('/')
        # The regex guarantees each side is an optional sign plus digits, so
        # int() fails exactly when a side is empty or a bare sign.
        try:
            num, den = int(numerator), int(denominator)
        except ValueError:
            return None
        if den == 0:
            # Degenerate "x/0" (also "x/00"): fall back to the numerator.
            # Parse the string (not the int) so huge values become inf
            # instead of raising.
            value = float(numerator)
        else:
            try:
                # int/int true division is correctly rounded, i.e. the same
                # value as float(Fraction(num, den)).
                value = num / den
            except OverflowError:
                return None
    else:
        value = float(token.replace(',', ''))

    # Very long digit strings parse to +/-inf; round() would raise on them.
    if math.isinf(value):
        return None
    return round(value)
    
# --- Helper Functions: MATH (Ported strictly from MATH_infer.py) ---
def remove_boxed(s):
    """
    Return the content of a ``\\boxed{...}`` wrapper, e.g. ``"\\boxed{5}" -> "5"``.

    The whole string must start with ``\\boxed{`` and end with ``}``.

    Returns:
        The inner text, or None when ``s`` is not a string or is not of
        that form.
    """
    left = "\\boxed{"
    # Explicit checks instead of `assert`: asserts are stripped under
    # `python -O`, which would make every input appear to match.
    if not isinstance(s, str) or not s.startswith(left) or not s.endswith("}"):
        return None
    return s[len(left):-1]
    
def process_results_math(completion, answer):
    """
    Extract the final answer from a MATH completion and check it against ``answer``.

    Ported from MATH_infer.py. The candidate is built as follows:
      - take the text after the last "The answer is: " marker;
      - truncate it at the first ".\\n";
      - strip it and drop a single trailing period.

    Equivalence is decided by ``util.is_equiv``.

    Returns:
        ``(is_correct, extracted_answer)``; ``(False, None)`` when the
        completion contains no marker.
    """
    _, marker, tail = completion.rpartition('The answer is: ')
    if not marker:
        return False, None

    candidate = tail.split('.\n', 1)[0].strip()
    if candidate.endswith('.'):
        candidate = candidate[:-1]
    candidate = candidate.strip()

    verdict = True if util.is_equiv(candidate, answer) else False
    return verdict, candidate
    
# --- Prompt Formatting ---
def format_prompt(examples):
    """
    Batched ``datasets.map`` callback that wraps each problem in PROMPT_TEMPLATE.

    The problem text is read from the 'question' column (GSM8k). When that
    column is absent, the 'instruction' column (MATH) is used instead. If
    neither exists, no prompts are produced.

    Returns:
        ``{"prompt": [...]}`` with one formatted prompt per input row.
    """
    instructions = examples.get('question', examples.get('instruction', []))
    return {"prompt": [PROMPT_TEMPLATE.format(instruction=text) for text in instructions]}

# --- Core Logic ---
def merge_process(queue, mainCfg: MainConfig, force_to_merge: bool = False):
    """
    Merge a PEFT adapter into its base model and save the result.

    Intended to run in a separate process so the merge's memory is released
    when that process exits.

    Paths:
      - Adapter weights are read from ``<merge_adapter_path>/ft2``, or
        ``<adapter_path>/ft2`` if ``merge_adapter_path`` is None.
      - The merged model and tokenizer are written to ``<base>/merge``.
        ``<base>`` is ``merge_output_path`` if set, otherwise ``adapter_path``,
        otherwise ``merge_adapter_path``.

    The merge is skipped when ``<base>/merge`` already contains ``.bin`` or
    ``.safetensors`` files, unless ``force_to_merge`` is True. Both the base
    model and the adapter are loaded on CPU.

    Args:
        queue: Multiprocessing queue used to report the outcome.
        mainCfg: Configuration; only ``mainCfg.model`` fields are read.
        force_to_merge: Re-merge even if merged weights already exist.

    Reports on ``queue`` (exceptions are never propagated):
        ``(output_path, base)`` on success, or the formatted traceback
        string on failure.
    """
    try:
        model_cfg = mainCfg.model
        model_name = model_cfg.model_name

        # Determine adapter path
        if model_cfg.merge_adapter_path is not None:
            adapter_root = model_cfg.merge_adapter_path
            adapter = os.path.join(adapter_root, "ft2")  # Adjust subfolder if needed
            print(f'Merging from merge_adapter_path: {adapter}')
        elif model_cfg.adapter_path is not None:
            adapter_root = model_cfg.adapter_path
            adapter = os.path.join(adapter_root, "ft2")
            print(f'Merging from adapter_path: {adapter}')
        else:
            raise KeyError('No adapter path provided in config.')

        # Determine output path. When only merge_adapter_path is configured,
        # fall back to it; joining onto a None adapter_path would raise TypeError.
        if model_cfg.merge_output_path is not None:
            out_json = model_cfg.merge_output_path
        elif model_cfg.adapter_path is not None:
            out_json = model_cfg.adapter_path
        else:
            out_json = adapter_root
        output_path = os.path.join(out_json, "merge")

        # Check if merge is needed
        has_weights = os.path.isdir(output_path) and any(
            f.endswith((".bin", ".safetensors")) for f in os.listdir(output_path)
        )

        if not has_weights or force_to_merge:
            print(f"Loading base model: {model_name}")
            model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu", low_cpu_mem_usage=True)
            # Tokenizers have no device placement, so no device_map is passed.
            tokenizer = AutoTokenizer.from_pretrained(model_name)

            print(f"Loading adapter: {adapter}")
            model = PeftModel.from_pretrained(model, adapter)
            print("Merging model...")
            model = model.merge_and_unload()

            print(f"Saving merged model to: {output_path}")
            model.save_pretrained(output_path, safe_serialization=True, max_shard_size="10GB")
            tokenizer.save_pretrained(output_path)

            del model
            del tokenizer
            gc.collect()
            torch.cuda.empty_cache()
            print('Merge complete.')
        else:
            print("Merged weights found. Skipping merge step.")

        queue.put((output_path, out_json))

    except Exception as e:
        # Report the failure to the parent instead of raising; the parent
        # treats a string result as an error.
        error_msg = traceback.format_exc()
        print(error_msg)
        queue.put(error_msg)
        print(f"Error in merge_process: {e}")

# --- Scoring Logic ---
def _parse_gsm8k_ground_truth(gt_raw):
    """Return the GSM8k reference answer as a float, or ``gt_raw`` unchanged if it cannot be parsed."""
    text = str(gt_raw)
    if '####' in text:
        # Text after the first "####" marker. partition() also accepts
        # "####18" (no space), which made `split('#### ')[1]` raise IndexError.
        text = text.partition('####')[2]
    try:
        return float(text.replace(',', '').strip())
    except ValueError:
        return gt_raw  # Keep raw if parsing fails


def _gsm8k_is_correct(pred, gt):
    """Compare a GSM8k prediction with the reference: numeric equality first, then math_equal."""
    try:
        if float(pred) == float(gt):
            return True
    except (TypeError, ValueError):
        pass  # gt may be unparsed raw text; still give math_equal a chance
    try:
        return bool(math_equal(pred, gt))
    except Exception:  # external grader: treat any failure as "not equal"
        return False


def _parse_math_ground_truth(gt_raw):
    """Return the last ``\\boxed{...}`` answer in a MATH solution, or the stripped raw text if none."""
    try:
        boxed = remove_boxed(util.last_boxed_only_string(str(gt_raw)))
    except Exception:  # external helper; treat any failure as "no boxed answer"
        boxed = None
    # remove_boxed signals failure by returning None (it does not raise), so
    # the raw-text fallback must check for None explicitly.
    return boxed if boxed is not None else str(gt_raw).strip()


def score_outputs(outputs, test_target_name, ground_truths, out_json):
    """
    Score generations against ground truths and write a per-sample JSON report.

    The scorer is chosen from ``test_target_name``:
      - names containing "gsm8k" use numeric answer extraction;
      - otherwise, names containing "math" use MATH-style equivalence;
      - any other name has no scorer, so every sample counts as incorrect
        and a warning is printed.

    Args:
        outputs: Generation results; ``outputs[i].outputs[0].text`` is the
            completion for sample ``i``.
        test_target_name: Dataset name; selects the scorer and names the
            report file.
        ground_truths: Reference answers, index-aligned with ``outputs``.
        out_json: Directory for ``<test_target_name>.json`` (created if missing).

    Returns:
        Accuracy in percent over ``len(ground_truths)`` samples (0.0 if empty).
    """
    results = []
    total_correct = 0
    total_samples = len(ground_truths)
    invalid_count = 0

    print(f"Calculating scores for {test_target_name}...")

    # Identify dataset type
    name = test_target_name.lower()
    is_gsm8k = 'gsm8k' in name
    is_math = 'math' in name
    if not (is_gsm8k or is_math):
        print(f"[Warning] Unrecognized dataset type for '{test_target_name}'; "
              "no scorer applies, so every sample will be marked incorrect.")

    for i, output in enumerate(tqdm(outputs, desc="Scoring")):
        prediction_text = output.outputs[0].text
        gt_raw = ground_truths[i]

        is_correct = False
        extracted_pred = None
        clean_gt = None

        if is_gsm8k:
            clean_gt = _parse_gsm8k_ground_truth(gt_raw)
            extracted_pred = extract_answer_number_gsm8k(prediction_text)
            if extracted_pred is None:
                invalid_count += 1
            else:
                is_correct = _gsm8k_is_correct(extracted_pred, clean_gt)

        elif is_math:
            clean_gt = _parse_math_ground_truth(gt_raw)
            is_correct, extracted_pred = process_results_math(prediction_text, clean_gt)
            if not extracted_pred and not is_correct:
                invalid_count += 1

        results.append({
            "id": i,
            "prediction_full": prediction_text,
            "extracted_pred": extracted_pred,
            "ground_truth_raw": gt_raw,
            "ground_truth_clean": clean_gt,
            "is_correct": is_correct,
        })

        if is_correct:
            total_correct += 1

    avg_acc = 100.0 * total_correct / total_samples if total_samples > 0 else 0.0

    print("\n" + "="*40)
    print(f"FINAL RESULTS: {test_target_name}")
    print("="*40)
    print(f"Total Samples: {total_samples}")
    print(f"Invalid/No Answer Found: {invalid_count}")
    print(f"Accuracy: {avg_acc:.2f}%")
    print("="*40)

    os.makedirs(out_json, exist_ok=True)
    save_file = os.path.join(out_json, f'{test_target_name}.json')
    with open(save_file, "w", encoding="utf-8") as f:
        json.dump({
            "metrics": {
                "accuracy": avg_acc,
                "total": total_samples,
                "invalid": invalid_count
            },
            "details": results
        }, f, indent=2, ensure_ascii=False)

    return avg_acc

def _wait_for_merge_result(result_queue, proc, poll_seconds=10.0):
    """
    Wait for the merge subprocess's report without hanging forever.

    A bare ``result_queue.get()`` blocks indefinitely if the child dies
    before calling ``put`` (e.g. it is killed by a signal). This function
    polls with a timeout and checks between polls whether the child is
    still alive.

    Returns:
        Whatever the child put on the queue.

    Raises:
        RuntimeError: The child exited without putting anything on the queue.
    """
    from queue import Empty

    while True:
        try:
            return result_queue.get(timeout=poll_seconds)
        except Empty:
            if proc.is_alive():
                continue
        # The child has exited. Give an item flushed just before exit one
        # last chance to arrive before declaring failure.
        try:
            return result_queue.get(timeout=1.0)
        except Empty:
            raise RuntimeError(
                f"Merge process exited with code {proc.exitcode} "
                "without reporting a result."
            ) from None


def _evaluate_dataset(llm, sampling_params, test_target_name, out_json):
    """
    Generate completions for one local dataset and score them.

    Steps:
      - read ``./dataset/<test_target_name>/test.jsonl``;
      - rename instruction/output columns to question/answer;
      - format prompts and generate with ``llm``;
      - delegate scoring to ``score_outputs``.

    Returns:
        The accuracy in percent, or None if the dataset file does not exist.
    """
    # Load Local Dataset
    dataset_path = f'./dataset/{test_target_name}/test.jsonl'
    if not os.path.exists(dataset_path):
        print(f"[Error] Local file not found: {dataset_path}. Skipping.")
        return None

    print(f"Loading local file: {dataset_path}")
    test_dataset = load_dataset("json", data_files=dataset_path, split='train')

    # Standardize Column Names
    # MATH: instruction -> question, output -> answer
    if 'instruction' in test_dataset.column_names:
        test_dataset = test_dataset.rename_column('instruction', 'question')
    if 'output' in test_dataset.column_names:
        test_dataset = test_dataset.rename_column('output', 'answer')

    ground_truths = test_dataset['answer']

    # Format Prompts
    print("Formatting prompts...")
    test_dataset = test_dataset.map(
        format_prompt,
        batched=True,
        batch_size=1000,
        desc="Formatting prompts"
    )
    prompts = test_dataset['prompt']

    # Generate
    print(f"Generating responses for {len(prompts)} samples...")
    start_time_task = datetime.now()
    outputs = llm.generate(prompts, sampling_params)
    end_time_task = datetime.now()
    print(f"Task {test_target_name} duration: {end_time_task - start_time_task}")

    # Score
    return score_outputs(
        outputs=outputs,
        test_target_name=test_target_name,
        ground_truths=ground_truths,
        out_json=out_json
    )


@draccus.wrap()
def main(mainCfg: MainConfig):
    """
    Merge the configured adapter into the base model, then evaluate it with vLLM.

    Steps:
      1. Run ``merge_process`` in a child process and wait for its result.
      2. Load the merged model into vLLM with temperature-0 sampling.
      3. Evaluate every dataset in ``mainCfg.infer.datasets``. A failure in
         one dataset is reported and the loop moves on to the next.
      4. Write per-task accuracies and their mean to ``FINAL.json`` in the
         ``results`` directory under the reported output base.

    Raises:
        RuntimeError: Merging failed, or the merge process died silently.
        FileNotFoundError: The merged model directory is missing.
    """
    print('='*120)
    set_seed_all(mainCfg.seed)

    # --- 1. Merge Step ---
    queue = Queue()
    p = Process(target=merge_process, args=(queue, mainCfg, False))
    p.start()
    # Read the result before join(): joining a process that still has queued
    # data to flush can deadlock.
    merge_result = _wait_for_merge_result(queue, p)
    p.join()

    if merge_result is None or isinstance(merge_result, str):  # Handle error string
        raise RuntimeError(f"Model merging failed: {merge_result}")

    model_path, out_json = merge_result

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model directory does not exist: {model_path}")

    print(f"Verified model path: {os.path.abspath(model_path)}")
    out_json = os.path.join(out_json, "results")
    print('Output JSON path: ', out_json)

    # --- 2. Initialize vLLM ---
    print("Initializing vLLM...")
    llm = LLM(
        model=model_path,
        dtype="bfloat16",
        gpu_memory_utilization=0.9,
        max_model_len=mainCfg.infer.infer_max_seq_length,
        tensor_parallel_size=1,
    )

    # Stop tokens for CoT: end generation before the model starts a new
    # Instruction/Response section of the prompt template.
    stop_tokens = ["Instruction:", "Instruction", "Response:", "Response",]
    sampling_params = SamplingParams(
        temperature=0,
        top_p=1,
        max_tokens=MAX_NEW_TOKENS,
        stop=stop_tokens
    )

    start_time_total = datetime.now()
    final_res = {}
    all_task_acc = []

    # --- 3. Evaluation Loop ---
    # Errors are caught per dataset so one failure does not skip the rest.
    for test_target_name in mainCfg.infer.datasets:
        print(f"Processing dataset: {test_target_name}")
        try:
            avg_acc = _evaluate_dataset(llm, sampling_params, test_target_name, out_json)
        except Exception as e:
            print(f"Error during evaluation loop ({test_target_name}): {e}")
            traceback.print_exc()
            continue
        if avg_acc is None:  # dataset file missing; already reported
            continue
        final_res[test_target_name] = avg_acc
        all_task_acc.append(avg_acc)

    # --- 4. Final Report ---
    print('Accuracies per task:', all_task_acc)
    avg_score = sum(all_task_acc) / len(all_task_acc) if all_task_acc else 0.0

    final_res['average_score'] = avg_score

    os.makedirs(out_json, exist_ok=True)
    save_file = os.path.join(out_json, 'FINAL.json')

    with open(save_file, "w", encoding="utf-8") as f:
        json.dump(final_res, f, indent=2, ensure_ascii=False)

    print(f"All Results saved to {save_file}, Overall Score: {avg_score:.2f}")

    end_time_total = datetime.now()
    print(f"Total execution time: {end_time_total - start_time_total}")

if __name__ == "__main__":
    # ------------------------------------------------------------------
    # NOTE: the 'spawn' start-method override below is currently DISABLED.
    # main() therefore launches the merge subprocess with the platform's
    # default multiprocessing start method (historically 'fork' on Linux).
    #
    # Why it exists: with 'fork', the child inherits the parent's state.
    # Suppose the parent has already initialized CUDA, even by just checking
    # availability or by running inside a debugger such as VS Code's. The
    # forked child then inherits that CUDA context, and CUDA use in the child
    # fails with a RuntimeError. 'spawn' starts a fresh interpreter instead,
    # avoiding this. Re-enable the block if you hit that error.
    # ------------------------------------------------------------------
    # try:
    #     import torch.multiprocessing as mp
    #     mp.set_start_method('spawn', force=True)
    # except RuntimeError as e:
    #     # If it's already set, just print a warning (safe to ignore usually)
    #     print(f"Warning: context already set: {e}")
    # main is wrapped by draccus.wrap(), which supplies the MainConfig
    # argument (presumably parsed from CLI/config-file arguments; confirm
    # against the draccus docs).
    main()