File size: 4,016 Bytes
a80f6e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
import json
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np

# Hugging Face model id that is loaded and used to score the questions.
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
# MMLU subsets (``cais/mmlu`` config names) to process.
DATASETS = ["abstract_algebra", "college_mathematics"]
# Number of lowest-probability content tokens kept per question.
TOP_K = 5
# Output root; one sub-directory per subset is created beneath it.
SAVE_DIR = "data/prob_method"

# Lower-cased function words excluded from the top-k token ranking.
# (A duplicate "that" entry was removed; the set's contents are unchanged.)
STOPWORDS = {
    "that", "of", "the", "and", "to", "in", "for", "with", "on", "at", "by", "from", "as", "is", "are", "was", "were",
    "be", "been", "being", "a", "an", "but", "or", "if", "because", "so", "do", "does", "did", "not", "no", "yes",
    "can", "could", "should", "would", "will", "shall", "may", "might", "must", "this", "these", "those",
    "it", "its", "he", "she", "they", "them", "his", "her", "their", "you", "your", "we", "our", "i", "me", "my",
    "mine", "who", "whom", "which", "what", "when", "where", "why", "how", "also", "than", "then", "there", "here",
    "such", "other", "some", "any", "all", "each", "every", "either", "neither", "both", "few", "more", "most", "much", "many",
}

def ensure_dir(path):
    """Create directory ``path``, including any missing parents, if absent.

    ``exist_ok=True`` makes the call idempotent. It also avoids the race in
    an ``os.path.exists`` check followed by ``os.makedirs``: if another
    process creates the directory in between, no error is raised.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)

def get_per_token_prob(model, tokenizer, prompt):
    """Return each prompt token together with the model's probability for it.

    The prompt is tokenized and run through ``model`` in a single forward
    pass. The logits at position ``t`` are normalised over the vocabulary,
    and the probability of the actual token at position ``t + 1`` is read
    off. The first prompt token has no preceding context and is not scored.
    As a result, ``tokens[i]`` and ``probs[i]`` both refer to prompt token
    ``i + 1``.

    Args:
        model: causal LM. It is called as ``model(input_ids)`` and must
            return an object with a ``.logits`` tensor. The shape is assumed
            to be ``(1, seq_len, vocab)``; TODO confirm for other model
            classes.
        tokenizer: tokenizer matching ``model``.
        prompt: text to score.

    Returns:
        A tuple ``(tokens, probs)``:
        - ``tokens``: list of decoded, whitespace-stripped token strings.
        - ``probs``: float32 NumPy array of the same length.
        Both are empty when the prompt has fewer than two tokens.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs.input_ids.to(model.device)
    token_ids = input_ids[0].tolist()
    if len(token_ids) < 2:
        # Nothing can be scored without at least one context token, and an
        # empty prompt would hand the model an empty sequence.
        return [], np.zeros(0, dtype=np.float32)
    with torch.no_grad():
        logits = model(input_ids).logits
        # The last position predicts a token beyond the prompt, so drop it
        # before normalising. Computing log_softmax in float32 is numerically
        # stable. It also lets ``.numpy()`` succeed when the model runs in
        # bfloat16, a dtype NumPy cannot represent.
        log_probs = torch.log_softmax(logits[:, :-1, :].float(), dim=-1)
        targets = input_ids[:, 1:].unsqueeze(-1)
        token_probs = log_probs.gather(-1, targets).squeeze(-1).exp()
        token_probs = token_probs[0].cpu().numpy()
    # Decode each scored token to text instead of only stripping the "Ġ"
    # marker, so other raw tokenizer symbols do not leak into the output.
    clean_tokens = [tokenizer.decode([tid]).strip() for tid in token_ids[1:]]
    return clean_tokens, token_probs

def format_prompt(q):
    """Build the text fed to the model for one question record.

    Only the ``"question"`` field is used. Any options or answer stored in
    ``q`` are intentionally left out of the prompt.
    """
    question_text = q["question"]
    return question_text

def main():
    """Score each MMLU test question and save its least-likely content tokens.

    For every subset in ``DATASETS`` this function:
    1. Loads the ``cais/mmlu`` test split.
    2. Runs ``MODEL_NAME`` over the prompt built by ``format_prompt``.
    3. Writes one JSON file per question to
       ``SAVE_DIR/<subset>/question_NNNN.json``.

    Each file holds the question, its letter-labelled options, the answer
    letter, and the ``TOP_K`` lowest-probability content tokens. The
    probabilities themselves are stored under the ``"uncertainties"`` key.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # NOTE(review): no torch_dtype is passed, so weights load in the
    # transformers default dtype. Confirm the model fits on `device`.
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
    model.eval()
    for subset in DATASETS:
        print(f"Processing {subset}...")
        ds = load_dataset("cais/mmlu", subset, split="test")
        save_path = os.path.join(SAVE_DIR, subset)
        ensure_dir(save_path)
        for idx, item in enumerate(ds):
            # Label the choices "A", "B", "C", ... and map the answer index
            # to its letter.
            options = {chr(65 + i): choice for i, choice in enumerate(item['choices'])}
            answer = chr(65 + item['answer'])
            qdict = {"question": item['question'], "options": options, "answer": answer}
            prompt = format_prompt(qdict)
            # These are per-token probabilities, not perplexities: a lower
            # value means the model found the token more surprising.
            tokens, token_probs = get_per_token_prob(model, tokenizer, prompt)
            # Keep only content tokens: drop single-character tokens and
            # stopwords.
            candidates = [
                (tok, prob)
                for tok, prob in zip(tokens, token_probs)
                if len(tok) > 1 and tok.lower() not in STOPWORDS
            ]
            # Lowest probability first (stable sort, so ties keep prompt
            # order). The slice already copes with fewer than TOP_K
            # candidates, so no separate branch is needed.
            topk = sorted(candidates, key=lambda pair: pair[1])[:TOP_K]
            out = {
                "question": qdict["question"],
                "options": qdict["options"],
                "answer": qdict["answer"],
                "topk_tokens": [tok for tok, _ in topk],
                "uncertainties": [float(prob) for _, prob in topk],
            }
            fname = os.path.join(save_path, f"question_{idx:04d}.json")
            with open(fname, 'w', encoding='utf-8') as f:
                json.dump(out, f, ensure_ascii=False, indent=2)
            if idx % 20 == 0:
                print(f"Saved {fname}")

# Run the scoring pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()