# linalg_zero/sft/scripts/evaluate_merged_model.py
import unsloth  # noqa: I001, F401  (must be imported before transformers/peft so its patches apply)
import torch
from datasets import load_dataset
from peft import PeftModel
from transformers import AutoTokenizer
from unsloth import FastLanguageModel
from linalg_zero.sft.tool_calling_accuracy import ToolCallingAccuracyCallback


def load_unmerged():
    """Load the Qwen2.5-3B base model and apply the best LoRA adapter on top."""
path = "results/LinalgZero-SFT-LoRA/checkpoint-400-best"
# path = "results/LinalgZero-SFT-LoRA-110/checkpoint-110"
tokenizer = AutoTokenizer.from_pretrained(path)
print(f"Tokenizer vocab size: {len(tokenizer)}")
model, _ = FastLanguageModel.from_pretrained(
model_name="Qwen/Qwen2.5-3B",
max_seq_length=8192,
load_in_4bit=False,
fast_inference=False,
)
model = PeftModel.from_pretrained(
model,
path,
is_trainable=False,
)
# tokenizer.push_to_hub("atomwalk12/LinalgZero-SFT-LoRA")
# model.push_to_hub("atomwalk12/LinalgZero-SFT-LoRA")
FastLanguageModel.for_inference(model)
return model, tokenizer
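

# A minimal sketch (not part of the original workflow) of how a merged
# checkpoint such as "results/LinalgZero-SFT/checkpoint-300-best" could be
# produced from the LoRA adapter above, using peft's standard
# merge_and_unload() API. The function name and output path are illustrative.
def merge_lora_sketch(adapter_path: str, out_dir: str) -> None:
    base, _ = FastLanguageModel.from_pretrained(
        model_name="Qwen/Qwen2.5-3B",
        max_seq_length=8192,
        load_in_4bit=False,  # merging needs full-precision base weights
    )
    peft_model = PeftModel.from_pretrained(base, adapter_path)
    merged = peft_model.merge_and_unload()  # fold LoRA deltas into the base weights
    merged.save_pretrained(out_dir)
    AutoTokenizer.from_pretrained(adapter_path).save_pretrained(out_dir)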


def load_merged():
    """Load a fully merged SFT checkpoint (LoRA weights folded into the base model)."""
    # Best models.
    # Note: the best LoRA run is checkpoint 400, while the best merged run is checkpoint 300.
    checkpoint_path = "results/LinalgZero-SFT/checkpoint-300-best"
# checkpoint_path = "results/LinalgZero-SFT-merged"
# checkpoint_path = "atomwalk12/LinalgZero-SFT-merged"
# checkpoint_path = "atomwalk12/LinalgZero-SFT"
# GRPO prep.
# DONE
# checkpoint_path = "results/LinalgZero-SFT-110/checkpoint-110"
# checkpoint_path = "results/LinalgZero-SFT-105/checkpoint-105"
# DONE
# checkpoint_path = "results/LinalgZero-SFT-110-checkpoint-300/checkpoint-300"
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
print(f"Tokenizer vocab size: {len(tokenizer)}")
model, tok2 = FastLanguageModel.from_pretrained(
model_name=checkpoint_path,
max_seq_length=8192,
load_in_4bit=False,
fast_inference=False,
)
    # The merged checkpoint must ship the same (possibly extended) vocabulary.
    assert len(tok2) == len(tokenizer)
    # Upload the best merged checkpoint to the Hub.
    model.push_to_hub("atomwalk12/LinalgZero-SFT")
    tokenizer.push_to_hub("atomwalk12/LinalgZero-SFT")
# model.push_to_hub("atomwalk12/LinalgZero-SFT-merged")
# tokenizer.push_to_hub("atomwalk12/LinalgZero-SFT-merged")
# GRPO prep.
# DONE
# model.push_to_hub("atomwalk12/LinalgZero-SFT-105")
# tokenizer.push_to_hub("atomwalk12/LinalgZero-SFT-105")
# DONE
# model.push_to_hub("atomwalk12/LinalgZero-SFT-110")
# tokenizer.push_to_hub("atomwalk12/LinalgZero-SFT-110")
# DONE
# model.push_to_hub("atomwalk12/LinalgZero-SFT-110-checkpoint-300")
# tokenizer.push_to_hub("atomwalk12/LinalgZero-SFT-110-checkpoint-300")
# model.push_to_hub("atomwalk12/LinalgZero-SFT")
# tokenizer.push_to_hub("atomwalk12/LinalgZero-SFT")
FastLanguageModel.for_inference(model)
return model, tokenizer
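

# Hedged sanity check (an assumption, not part of the original script): after a
# correct merge, the merged checkpoint should reproduce the LoRA model's logits
# on a probe prompt up to numerical noise. The probe text is illustrative.
def compare_merged_vs_lora_sketch(probe: str = "Compute the determinant of [[1, 2], [3, 4]].") -> None:
    lora_model, tok = load_unmerged()
    merged_model, _ = load_merged()  # note: load_merged() as written also pushes to the Hub
    ids = tok(probe, return_tensors="pt").input_ids.to(lora_model.device)
    with torch.no_grad():
        lora_logits = lora_model(input_ids=ids).logits
        merged_logits = merged_model(input_ids=ids.to(merged_model.device)).logits
    print("max |logit diff|:", (lora_logits - merged_logits.to(lora_logits.device)).abs().max().item())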


# Choose which variant to evaluate: the LoRA adapter (load_unmerged) or the
# merged checkpoint (load_merged).
model, tokenizer = load_unmerged()
eval_ds = load_dataset("atomwalk12/linalgzero-sft", split="test")  # use the same split as during SFT evaluation
cb = ToolCallingAccuracyCallback(
model_name="atomwalk12/LinAlgZero-SFT-merged",
dataset_name="atomwalk12/linalgzero",
eval_dataset=eval_ds,
)
gen_config = cb.get_generation_config(max_new_tokens=800, tokenizer=tokenizer)
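
# If the callback's config is unavailable, a plain dict of standard
# transformers generate() kwargs would be a reasonable stand-in (the exact
# settings used by ToolCallingAccuracyCallback are an assumption here):
# gen_config = {
#     "max_new_tokens": 800,
#     "do_sample": False,
#     "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
# }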


def generate_like_sft_eval(sample_idx: int = 0):
    """Generate a completion for a single eval sample, mirroring the SFT eval callback."""
sample = eval_ds[sample_idx]
context = list(sample["messages"])
tools = sample["tools"]
print(f"Query is: {sample['messages'][-1]['content']}")
prompt = tokenizer.apply_chat_template(
context,
tools=tools,
tokenize=False,
add_generation_prompt=True,
)
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True)
inputs = {k: v.to(model.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
tokenizer=tokenizer,
**gen_config,
)
# Decode only the generated continuation (optional: mimic callback's decoding)
prompt_len = inputs["input_ids"].shape[1]
gen_tokens = outputs[:, prompt_len:]
text = tokenizer.decode(gen_tokens[0], skip_special_tokens=False)
print(text)
return text
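

# Hedged helper (not in the original script): Qwen2.5's chat template wraps
# tool calls in <tool_call>...</tool_call> tags containing JSON, so the raw
# completion above can be parsed roughly as follows. The lenient error
# handling is an assumption.
import json
import re


def extract_tool_calls(completion: str) -> list[dict]:
    calls = []
    for block in re.findall(r"<tool_call>\s*(.*?)\s*</tool_call>", completion, re.DOTALL):
        try:
            calls.append(json.loads(block))
        except json.JSONDecodeError:
            pass  # skip malformed JSON rather than failing the whole parse
    return calls

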
result = generate_like_sft_eval(0)
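
# Illustrative extension (not part of the original script): repeat the
# generation over the first few test samples and count how many completions
# contain at least one well-formed tool call.
# n = 10
# hits = sum(bool(extract_tool_calls(generate_like_sft_eval(i))) for i in range(n))
# print(f"{hits}/{n} completions contained a parseable tool call")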