File size: 3,470 Bytes
b66371b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""๊ณตํ†ต ํ‰๊ฐ€ ์Šคํฌ๋ฆฝํŠธ: vLLM ์„œ๋ฒ„์— ์—ฐ๊ฒฐํ•˜์—ฌ HRM8K ์ „์ฒด 841๋ฌธ์ œ ํ‰๊ฐ€ (temperature=0)"""
import os, json, re, sys, asyncio
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from openai import OpenAI

# Instruction text prepended to every question before it is sent to the model
# (combined with the question in evaluate()'s prompt template). Its Korean text
# appears to ask for a step-by-step solution with the final answer on the last
# line in \boxed{...} form (example: \boxed{42}), the format extract_boxed() parses.
# NOTE(review): the Korean text below looks mis-encoded in this copy (UTF-8 bytes
# shown through a Thai/Windows code page) — confirm the file is saved as UTF-8 so
# the model actually receives readable instructions.
MATH_SYSTEM_PROMPT = """์ฃผ์–ด์ง„ ์ˆ˜ํ•™ ๋ฌธ์ œ๋ฅผ ๋‹จ๊ณ„๋ณ„๋กœ ํ’€๊ณ  ๋‹ต๋ณ€์„ ์ž‘์„ฑํ•˜์„ธ์š”.
๋ฐ˜๋“œ์‹œ ์ตœ์ข… ๋‹ต๋ณ€์„ \\boxed{์ •์ˆ˜} ํ˜•์‹์œผ๋กœ ๋งˆ์ง€๋ง‰ ์ค„์— ์ถœ๋ ฅํ•˜์„ธ์š”.
์˜ˆ์‹œ: \\boxed{42}"""

def extract_boxed(text):
    """Return the stripped contents of the last ``\\boxed{...}`` in *text*.

    Braces are matched, so nested LaTeX such as ``\\boxed{\\frac{1}{2}}``
    yields ``\\frac{1}{2}``; a ``[^}]+`` regex would stop at the first ``}``
    and return the truncated ``\\frac{1``. Empty boxes and boxes whose braces
    never close are ignored.

    Args:
        text: Model response (may be empty or None).

    Returns:
        The last usable box's contents with surrounding whitespace stripped,
        or None if *text* is empty/None or contains no usable box.
    """
    if not text:
        return None
    marker = "\\boxed{"
    last = None
    pos = text.find(marker)
    while pos != -1:
        start = pos + len(marker)
        depth = 1
        i = start
        # Walk forward until the brace opened by the marker is closed.
        while i < len(text) and depth:
            ch = text[i]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
            i += 1
        # depth == 0 -> balanced box; i - 1 is the index of its closing brace.
        if depth == 0 and i - 1 > start:
            last = text[start:i - 1].strip()
        # Resume right after the marker so nested boxes are also considered.
        pos = text.find(marker, start)
    return last

def normalize(a):
    """Canonicalize an answer for exact-match comparison.

    Removes commas (thousands separators) and spaces. If the result is
    numeric, it is rendered canonically: "3.0" -> "3", "1e3" -> "1000".
    Integers are parsed exactly before trying float, so large values are not
    rounded through float (which could make distinct answers compare equal).

    Args:
        a: Answer of any type (converted with ``str``), or None.

    Returns:
        None if ``a`` is None; otherwise the canonical numeric string, or the
        cleaned string when it is not a finite number.
    """
    if a is None:
        return None
    s = str(a).replace(",", "").replace(" ", "").strip()
    try:
        return str(int(s))  # exact for arbitrarily large integers
    except ValueError:
        pass
    try:
        n = float(s)
        # int(n) raises ValueError for NaN and OverflowError for +/-inf;
        # those inputs fall back to the cleaned string.
        return str(int(n)) if n == int(n) else str(n)
    except (ValueError, OverflowError):
        return s

def check(pred, gt):
    """Return True when *pred* and *gt* normalize to the same non-None value."""
    normalized_pred = normalize(pred)
    normalized_gt = normalize(gt)
    # A missing answer on either side never counts as a match.
    if normalized_pred is None or normalized_gt is None:
        return False
    return normalized_pred == normalized_gt

async def evaluate(label="", save_path=None, *, base_url="http://localhost:8000/v1",
                   api_key="token-abc123", data_path="data/HRM8k_eval.json",
                   max_concurrency=400):
    """Evaluate the served model on the HRM8K evaluation set and report accuracy.

    Uses the first model listed by the OpenAI-compatible server. Every
    question from *data_path* is prefixed with MATH_SYSTEM_PROMPT and sent
    with temperature=0 and max_tokens=2048. The last boxed answer of each
    response is compared with the reference answer via check().

    Args:
        label: Name shown in the summary header and stored in the result.
        save_path: Optional JSON path for the summary plus per-item details.
        base_url: OpenAI-compatible API root (defaults to the local server).
        api_key: API key sent to the server.
        data_path: JSON file holding a list of items with "question" and
            "answer" keys and an optional "source" key.
        max_concurrency: Maximum number of concurrent requests.

    Returns:
        dict with "label", "correct", "total", "accuracy" (percent; 0.0 for
        an empty dataset) and "by_source" (per-source correct/total/no_boxed).
    """
    client = OpenAI(base_url=base_url, api_key=api_key)
    # Presumably the server hosts a single model; the first listed one is used.
    model_name = client.models.list().data[0].id
    print(f"๋ชจ๋ธ: {model_name}")

    # Explicit UTF-8: questions may contain non-ASCII (e.g. Korean) text and the
    # platform default encoding is not guaranteed to be UTF-8.
    with open(data_path, encoding="utf-8") as f:
        data = json.load(f)
    print(f"ํ‰๊ฐ€: {len(data)}๊ฐœ (temperature=0, max_tokens=2048)")

    llm = ChatOpenAI(base_url=base_url, api_key=api_key,
                     model=model_name, temperature=0, max_tokens=2048)
    # The instruction is placed in the user turn, ahead of the question.
    prompt = ChatPromptTemplate([("user", "{sp}\n\n{q}")]).partial(sp=MATH_SYSTEM_PROMPT)
    chain = prompt | llm | StrOutputParser()
    inputs = [{"q": item["question"]} for item in data]
    results = await chain.abatch(inputs, config={"max_concurrency": max_concurrency})

    by_src = {}
    details = []
    for item, res in zip(data, results):
        s = item.get("source", "?")
        stats = by_src.setdefault(s, {"correct": 0, "total": 0, "no_boxed": 0})
        stats["total"] += 1
        pred = extract_boxed(res)
        is_correct = False
        if pred is None:
            stats["no_boxed"] += 1
        elif check(pred, item["answer"]):
            stats["correct"] += 1
            is_correct = True
        answer = item["answer"]
        details.append({
            "question": item["question"][:80],
            "source": s,
            # String answers are truncated to their last 30 characters.
            "gt": answer[-30:] if isinstance(answer, str) else str(answer),
            "pred": pred,
            "correct": is_correct,
        })

    tc = sum(v["correct"] for v in by_src.values())
    tt = sum(v["total"] for v in by_src.values())
    # Guard against an empty dataset instead of raising ZeroDivisionError.
    accuracy = tc / tt * 100 if tt else 0.0
    print(f"\n=== {label} ๊ฒฐ๊ณผ (temperature=0) ===")
    for s in sorted(by_src):
        v = by_src[s]
        print(f"  [{s.upper()}] {v['correct']}/{v['total']} ({v['correct']/v['total']*100:.1f}%) | boxed๋ฏธ์ถœ๋ ฅ: {v['no_boxed']}")
    print(f"  [์ „์ฒด] {tc}/{tt} ({accuracy:.1f}%)")

    result_obj = {"label": label, "correct": tc, "total": tt, "accuracy": accuracy, "by_source": by_src}

    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        # ensure_ascii=False emits raw non-ASCII characters, so write UTF-8
        # explicitly rather than relying on the locale's default encoding.
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump({"result": result_obj, "details": details}, f, ensure_ascii=False, indent=2)
        print(f"  ๊ฒฐ๊ณผ ์ €์žฅ: {save_path}")

    return result_obj

if __name__ == "__main__":
    # Usage: python <script> [label] [save_path]
    cli_args = sys.argv[1:]
    cli_label = cli_args[0] if cli_args else "eval"
    cli_save_path = cli_args[1] if len(cli_args) >= 2 else None
    asyncio.run(evaluate(cli_label, cli_save_path))