import unsloth  # noqa: I001, F401
import torch
from datasets import load_dataset
from peft import PeftModel
from transformers import AutoTokenizer
from unsloth import FastLanguageModel

from linalg_zero.sft.tool_calling_accuracy import ToolCallingAccuracyCallback
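
# This script evaluates SFT checkpoints outside the training loop: load either the
# LoRA adapter (unmerged) or a fully merged checkpoint, then generate on the test
# split to spot-check tool-calling behaviour.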


def load_unmerged():
    """Load the Qwen2.5-3B base model and attach the trained LoRA adapter for inference."""
    path = "results/LinalgZero-SFT-LoRA/checkpoint-400-best"
    # path = "results/LinalgZero-SFT-LoRA-110/checkpoint-110"
    tokenizer = AutoTokenizer.from_pretrained(path)
    print(f"Tokenizer vocab size: {len(tokenizer)}")

    model, _ = FastLanguageModel.from_pretrained(
        model_name="Qwen/Qwen2.5-3B",
        max_seq_length=8192,
        load_in_4bit=False,
        fast_inference=False,
    )

    # Attach the trained LoRA adapter on top of the base weights (kept frozen for inference).
    model = PeftModel.from_pretrained(
        model,
        path,
        is_trainable=False,
    )

    # tokenizer.push_to_hub("atomwalk12/LinalgZero-SFT-LoRA")
    # model.push_to_hub("atomwalk12/LinalgZero-SFT-LoRA")

    FastLanguageModel.for_inference(model)

    return model, tokenizer


def load_merged():
    """Load a fully merged SFT checkpoint (no adapter attachment needed) for inference."""
    # Best models
    # Note that the best LoRA checkpoint is 400, while the best merged checkpoint is 300.
    checkpoint_path = "results/LinalgZero-SFT/checkpoint-300-best"
    # checkpoint_path = "results/LinalgZero-SFT-merged"

    # checkpoint_path = "atomwalk12/LinalgZero-SFT-merged"
    # checkpoint_path = "atomwalk12/LinalgZero-SFT"

    # GRPO prep.
    # DONE
    # checkpoint_path = "results/LinalgZero-SFT-110/checkpoint-110"
    # checkpoint_path = "results/LinalgZero-SFT-105/checkpoint-105"

    # DONE
    # checkpoint_path = "results/LinalgZero-SFT-110-checkpoint-300/checkpoint-300"

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    print(f"Tokenizer vocab size: {len(tokenizer)}")

    model, tok2 = FastLanguageModel.from_pretrained(
        model_name=checkpoint_path,
        max_seq_length=8192,
        load_in_4bit=False,
        fast_inference=False,
    )
    assert len(tok2) == len(tokenizer), "Tokenizer vocab size mismatch between checkpoint load paths"

    # Best models: push the selected checkpoint and tokenizer to the Hub.
    model.push_to_hub("atomwalk12/LinalgZero-SFT")
    tokenizer.push_to_hub("atomwalk12/LinalgZero-SFT")

    # model.push_to_hub("atomwalk12/LinalgZero-SFT-merged")
    # tokenizer.push_to_hub("atomwalk12/LinalgZero-SFT-merged")

    # GRPO prep.
    # DONE
    # model.push_to_hub("atomwalk12/LinalgZero-SFT-105")
    # tokenizer.push_to_hub("atomwalk12/LinalgZero-SFT-105")

    # DONE
    # model.push_to_hub("atomwalk12/LinalgZero-SFT-110")
    # tokenizer.push_to_hub("atomwalk12/LinalgZero-SFT-110")

    # DONE
    # model.push_to_hub("atomwalk12/LinalgZero-SFT-110-checkpoint-300")
    # tokenizer.push_to_hub("atomwalk12/LinalgZero-SFT-110-checkpoint-300")

    # model.push_to_hub("atomwalk12/LinalgZero-SFT")
    # tokenizer.push_to_hub("atomwalk12/LinalgZero-SFT")

    FastLanguageModel.for_inference(model)

    return model, tokenizer


# Evaluate the LoRA adapter directly; switch to load_merged() to evaluate a merged checkpoint.
model, tokenizer = load_unmerged()

eval_ds = load_dataset("atomwalk12/linalgzero-sft", split="test")  # same held-out split used during SFT evaluation

cb = ToolCallingAccuracyCallback(
    model_name="atomwalk12/LinAlgZero-SFT-merged",
    dataset_name="atomwalk12/linalgzero",
    eval_dataset=eval_ds,
)
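# Reuse the callback's generation settings so outputs match the SFT evaluation loop.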
gen_config = cb.get_generation_config(max_new_tokens=800, tokenizer=tokenizer)


def generate_like_sft_eval(sample_idx: int = 0):
    """Generate a completion for one test sample, mirroring the SFT evaluation setup."""
    sample = eval_ds[sample_idx]
    context = list(sample["messages"])
    tools = sample["tools"]
    print(f"Query is: {sample['messages'][-1]['content']}")

    prompt = tokenizer.apply_chat_template(
        context,
        tools=tools,
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            tokenizer=tokenizer,
            **gen_config,
        )

    # Decode only the generated continuation (optional: mimic callback's decoding)
    prompt_len = inputs["input_ids"].shape[1]
    gen_tokens = outputs[:, prompt_len:]
    text = tokenizer.decode(gen_tokens[0], skip_special_tokens=False)
    print(text)
    return text


result = generate_like_sft_eval(0)
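# Change the sample index to spot-check other test examples, e.g. generate_like_sft_eval(5).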