import argparse
import os

import jsonlines
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

from utils.metrics import qa_f1_score, qa_em_score

THINK_END_ID = 151668  # "</think>" token id for Qwen3

# --------------------------------------------------
def strip_think(token_ids):
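    """Drop everything up to and including the final </think> token.

    Searching the reversed list finds the last occurrence of THINK_END_ID;
    if no </think> token is present, the ids are returned unchanged.
    """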
    try:
        cut = len(token_ids) - token_ids[::-1].index(THINK_END_ID)
        return token_ids[cut:]
    except ValueError:
        return token_ids

def main():
    # ---------- CLI ----------
    parser = argparse.ArgumentParser(
        description="Evaluate HotpotQA JSONL with Transformers + Qwen3-8B"
    )
    parser.add_argument("-i", "--input", required=True,
                        help="Path to input JSONL file")
    parser.add_argument("--model", required=True,
                        help="HF model name, e.g. Qwen/Qwen3-8B")
    parser.add_argument("-d", "--devices", default="0",
                        help="CUDA_VISIBLE_DEVICES (comma-separated)")
    parser.add_argument("-t", "--temperature", type=float, default=0.5,
                        help="Sampling temperature")
    parser.add_argument("-k", "--max_tokens", type=int, default=40,
                        help="max_new_tokens")
    args = parser.parse_args()

    # Restrict visible GPUs before any CUDA context is created so that
    # device_map="auto" only places the model on the requested devices.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.devices

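    # torch_dtype="auto" uses the dtype recorded in the checkpoint config;
    # device_map="auto" spreads the model over the visible GPUs.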
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True
    )
    # Greedy decoding when temperature is 0; sampling otherwise (avoids the
    # transformers warning about temperature being set with do_sample=False).
    do_sample = args.temperature > 0
    gen_cfg = GenerationConfig(
        max_new_tokens=args.max_tokens,
        do_sample=do_sample,
        temperature=args.temperature if do_sample else 1.0,
    )

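    # Each JSONL record is expected to follow the LongBench HotpotQA layout:
    #   {"input": <question>, "context": <passages>, "answers": [<gold>, ...]}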
    with jsonlines.open(args.input) as reader:
        data = list(reader)

    total_f1 = total_em = 0.0
    n_eval = 0  # items actually scored (skipped records are excluded)

    for idx, item in enumerate(data):
        question = item.get("input", "")
        context  = item.get("context", "")
        answers  = item.get("answers", [])
        if not answers:
            print(f"[{idx}] no gold answer, skip")
            continue
        gold = answers[0]

        # ----- Prompt -----
        prompt = (
            "Answer the question based on the given passages. "
            "Only give me your answer and do not output any other words.\n"
            "Passages:\n"
            f"{context}\n"
            f"Question: {question}\n"
            "Answer:"
        )
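        # Qwen3 chat template: enable_thinking=False requests a direct answer
        # without a <think> block; strip_think() remains as a safety net.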
        messages = [{"role": "user", "content": prompt}]
        chat_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False
        )
        inputs = tokenizer([chat_text], return_tensors="pt").to(model.device)

        # ----- Generate -----
        try:
            with torch.no_grad():
                outputs = model.generate(**inputs, generation_config=gen_cfg)
        except ValueError as e:
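            # Best-effort match on the error text for prompts that overflow
            # the model's context window.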
            if "position ids exceed" in str(e).lower() or "sequence length" in str(e).lower():
                print(f"[{idx}] prompt too long – skipped")
                continue
            raise
        print("im here")
        new_ids = outputs[0][len(inputs.input_ids[0]):].tolist() 
        try:
            index = len(new_ids) - new_ids[::-1].index(151668)
        except ValueError:
            index = 0
        answer = tokenizer.decode(new_ids[index:], skip_special_tokens=True).strip("\n")
        answer = answer.strip()  

        # ----- Score -----
        f1 = qa_f1_score(answer, gold)
        em = qa_em_score(answer, gold)
        total_f1 += f1
        total_em += em
        n_eval += 1

        print(f"[{idx}] Q: {question}")
        print(f"    Resp: {answer!r} | Gold: {gold!r}")
        print(f"    F1={f1:.2f}, EM={em:.2f}")

    if n_eval == 0:
        print("\nNo items were scored.")
        return
    print(f"\nScored {n_eval}/{len(data)} items")
    print(f"Overall F1: {total_f1 / n_eval:.4f}")
    print(f"Overall EM: {total_em / n_eval:.4f}")

if __name__ == "__main__":
    main()
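
# Example invocation (script name and paths are illustrative):
#   python eval_qwen3_hotpotqa.py -i data/hotpotqa.jsonl \
#       --model Qwen/Qwen3-8B -d 0 -t 0.0 -k 40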