""" |
|
|
并行评测脚本:针对 Qwen2.5-7B-Math(或你微调后的 HF 格式权重) |
|
|
|
|
|
单卡示例: |
|
|
python simple_valid.py \ |
|
|
--model_path /pfs/lichenyi/work/finetune_output_train1/checkpoint-300 \ |
|
|
--data_path /pfs/lichenyi/work/evaluation/valid.json \ |
|
|
--dtype bf16 \ |
|
|
--use_system \ |
|
|
--temperature 0.0 |
|
|
|
|
|
多卡示例(4 卡): |
|
|
torchrun --nproc_per_node 4 simple_valid.py \ |
|
|
--model_path /pfs/lichenyi/work/finetune_output_train1/checkpoint-300 \ |
|
|
--data_path /pfs/lichenyi/work/evaluation/valid.json \ |
|
|
--dtype bf16 \ |
|
|
--use_system \ |
|
|
--temperature 0.0 |
|
|
|
|
|
若不显式传 --out_path,将自动写入: |
|
|
/pfs/lichenyi/work/evaluation/predictions/predictions_<basename(model_path)>.json |
|
|
例如: |
|
|
/pfs/lichenyi/work/evaluation/predictions/predictions_checkpoint-300.json |
|
|
""" |
import argparse
import json
import os
from typing import Any, Dict, List

import torch
import torch.distributed as dist
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


def load_model(
    model_path: str,
    load_in_8bit: bool,
    load_in_4bit: bool,
    dtype: str,
    device_map="auto",
):
    kwargs = {}
    if load_in_4bit:
        # 4-bit quantized loading (bitsandbytes), via the supported config
        # object rather than the deprecated bare from_pretrained flags.
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
        )
    elif load_in_8bit:
        kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
    else:
        # Full-precision weights: bf16 when CUDA is available, fp16 otherwise.
        if dtype == "bf16" and torch.cuda.is_available():
            kwargs["torch_dtype"] = torch.bfloat16
        else:
            kwargs["torch_dtype"] = torch.float16

    if torch.cuda.is_available():
        # Allow TF32 matmuls for faster inference on Ampere and newer GPUs.
        torch.backends.cuda.matmul.allow_tf32 = True

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map=device_map,
        trust_remote_code=True,
        **kwargs,
    )
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True,
        use_fast=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer
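
# Note: the 4-/8-bit paths require the bitsandbytes package to be installed;
# the bf16/fp16 paths need only torch. Typical call, mirroring main() below:
#   model, tokenizer = load_model(path, False, False, "bf16", device_map="auto")

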
def canonicalize_human(value: str) -> str:
    # Drop anything after ":::" (trailing annotations in the human turn).
    return value.split(":::")[0].strip()
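
# Example (the annotation here is illustrative):
#   canonicalize_human("Compute 12 * 7. ::: source=valid") -> "Compute 12 * 7."

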
def decode_only_new(gen_ids: torch.Tensor, prompt_len: int, tokenizer) -> str:
    # Decode only the newly generated tokens (everything after the prompt).
    new_tokens = gen_ids[0, prompt_len:]
    text = tokenizer.decode(new_tokens, skip_special_tokens=False)

    # Cut at the first end-of-turn / end-of-text marker, if any.
    stop_markers = []
    if getattr(tokenizer, "eos_token", None):
        stop_markers.append(tokenizer.eos_token)
    stop_markers.extend([
        "<|im_end|>",
        "<|endoftext|>",
        "<end_of_text>",
    ])
    for m in stop_markers:
        if m and m in text:
            text = text.split(m)[0]
            break

    # Strip outer whitespace first so a leading newline does not empty the
    # result, then keep only the first non-empty block of lines.
    text = text.strip()
    lines = text.splitlines()
    block = []
    for ln in lines:
        if ln.strip() == "":
            break
        block.append(ln)
    text = "\n".join(block).strip()

    return text
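
# Example of the truncation above: a raw decode of "84<|im_end|>\n\nextra text"
# is cut at the end-of-turn marker and reduced to "84".

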
def build_model_inputs(messages, tokenizer, device):
    """
    Handle Qwen2.5-7B-Math with or without a chat_template:
    - prefer tokenizer.apply_chat_template
    - if the checkpoint ships no chat_template, fall back to plain string
      concatenation + tokenizer()
    """
    try:
        model_inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,  # return {input_ids, attention_mask}, not a bare tensor
        )
        model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
        return model_inputs
    except Exception:
        # No usable chat template: fall back to a simple tagged-text prompt.
        text_parts = []
        for m in messages:
            role = m["role"]
            content = m["content"]
            if role == "system":
                text_parts.append(f"[SYSTEM]\n{content}\n")
            elif role == "user":
                text_parts.append(f"[USER]\n{content}\n")

        text = "\n".join(text_parts) + "\n[ASSISTANT]\n"

        enc = tokenizer(
            text,
            return_tensors="pt",
        )
        enc = {k: v.to(device) for k, v in enc.items()}
        return enc
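
# For reference: Qwen2.5 chat templates render messages in ChatML form, roughly
#   <|im_start|>system\n...<|im_end|>\n<|im_start|>user\n...<|im_end|>\n<|im_start|>assistant\n
# The [SYSTEM]/[USER]/[ASSISTANT] fallback above is only a plain-text
# approximation for checkpoints that ship without a chat_template.

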
def setup_distributed():
    """
    Initialize torch.distributed when launched via torchrun; otherwise fall
    back to a single process.
    Returns: (distributed, rank, world_size, local_rank)
    """
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    distributed = world_size > 1

    if not distributed:
        return False, 0, 1, 0

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    # Bind this process to its own GPU before any CUDA work happens.
    torch.cuda.set_device(local_rank)

    return True, rank, world_size, local_rank
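
# torchrun exports WORLD_SIZE, RANK and LOCAL_RANK to every process it spawns;
# e.g. `torchrun --nproc_per_node 4 simple_valid.py ...` starts 4 processes
# with WORLD_SIZE=4, RANK in 0..3 and LOCAL_RANK in 0..3 (one per local GPU).

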
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--model_path", type=str, required=True,
                    help="Local model path or HF model name, e.g. /pfs/.../Qwen2.5-7B-Math")
    ap.add_argument("--data_path", type=str, required=True, help="Path to the test-set JSON")
    ap.add_argument(
        "--out_path",
        type=str,
        default="",
        help="Output directory for the predictions JSON; leave empty to auto-generate from model_path",
    )
    ap.add_argument("--max_new_tokens", type=int, default=128)
    ap.add_argument("--temperature", type=float, default=0.1)
    ap.add_argument("--top_p", type=float, default=0.95)
    ap.add_argument("--load_in_8bit", action="store_true")
    ap.add_argument("--load_in_4bit", action="store_true")
    ap.add_argument("--dtype", choices=["bf16", "fp16"], default="bf16")
    ap.add_argument("--use_system", action="store_true",
                    help="Also include each sample's system field in the conversation")
    args = ap.parse_args()

    distributed, rank, world_size, local_rank = setup_distributed()

    if distributed and rank == 0:
        print(f"[INFO] Distributed inference, world_size={world_size}")

    if distributed:
        # Pin each rank's full model replica to its own GPU.
        device_map = {"": local_rank}
    else:
        device_map = "auto"

    # Resolve the output path. When --out_path is given it is treated as a
    # directory, and predictions_<basename>.json is created inside it.
    base_name = os.path.basename(os.path.normpath(args.model_path))
    if not base_name:
        base_name = os.path.basename(args.model_path.rstrip("/"))

    out_dir = args.out_path or "/pfs/lichenyi/work/evaluation/predictions"
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"predictions_{base_name}.json")

    if rank == 0:
        print(f"[INFO] Output path: {out_path}")

    if rank == 0:
        print(f"[INFO] Loading model from {args.model_path} ...")

    model, tokenizer = load_model(
        args.model_path,
        args.load_in_8bit,
        args.load_in_4bit,
        args.dtype,
        device_map=device_map,
    )

    # Collect every plausible end-of-sequence id so generation stops at either
    # the tokenizer's own EOS or common chat end-of-turn markers.
    extra_eos_tokens = ["<|im_end|>", "<|endoftext|>", "<end_of_text>"]
    eos_ids = set()

    if getattr(tokenizer, "eos_token_id", None) is not None:
        if isinstance(tokenizer.eos_token_id, int):
            eos_ids.add(tokenizer.eos_token_id)
        else:
            eos_ids.update(tokenizer.eos_token_id)

    vocab = tokenizer.get_vocab()
    for tok in extra_eos_tokens:
        if tok in vocab:
            eos_ids.add(vocab[tok])

    if len(eos_ids) == 0:
        eos_token_id = None
    elif len(eos_ids) == 1:
        eos_token_id = next(iter(eos_ids))
    else:
        # transformers accepts a list of eos ids; generation stops at any one.
        eos_token_id = list(eos_ids)

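    # For Qwen2.5-style chat checkpoints, <|im_end|> marks end-of-turn and
    # <|endoftext|> is the generic end-of-text token; which of these exist
    # depends on the tokenizer, hence the vocab lookup above.
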
    if rank == 0:
        print(f"[INFO] Loading dataset from {args.data_path} ...")

    with open(args.data_path, "r", encoding="utf-8") as f:
        dataset: List[Dict[str, Any]] = json.load(f)
    num_samples = len(dataset)

    # Round-robin sharding: rank r handles samples r, r + world_size, ...
    indices = list(range(rank, num_samples, world_size))

    if rank == 0:
        iter_indices = tqdm(indices, desc="Running inference")
    else:
        iter_indices = indices

    results = []

    for idx in iter_indices:
        item = dataset[idx]

        system_text = item.get("system", "").strip()
        prompt_text = ""
        gt_text = ""

        # Each sample carries one human turn (the prompt) and one gpt turn
        # (the reference answer).
        for turn in item.get("conversations", []):
            if turn.get("from") == "human":
                prompt_text = canonicalize_human(turn.get("value", ""))
            elif turn.get("from") == "gpt":
                gt_text = turn.get("value", "").strip()

        messages = []
        if args.use_system and system_text:
            messages.append({"role": "system", "content": system_text})
        messages.append({"role": "user", "content": prompt_text})

        model_inputs = build_model_inputs(messages, tokenizer, model.device)

        gen_kwargs = dict(
            max_new_tokens=args.max_new_tokens,
            do_sample=args.temperature > 0,
            pad_token_id=tokenizer.pad_token_id,
        )
        if args.temperature > 0:
            # Sampling parameters are only meaningful when do_sample=True;
            # passing them with greedy decoding triggers transformers warnings.
            gen_kwargs["temperature"] = args.temperature
            gen_kwargs["top_p"] = args.top_p
        if eos_token_id is not None:
            gen_kwargs["eos_token_id"] = eos_token_id

        with torch.no_grad():
            output_ids = model.generate(
                **model_inputs,
                **gen_kwargs,
            )

        prompt_len = model_inputs["input_ids"].shape[-1]
        pred = decode_only_new(output_ids, prompt_len, tokenizer)

        results.append({
            "id": idx,
            "system": system_text if args.use_system else "",
            "prompt": prompt_text,
            "ground_truth": gt_text,
            "model_output": pred,
        })

    if distributed:
        # Gather each rank's partial results onto every rank, then let rank 0
        # merge, sort by original sample id, and write the file.
        all_results = [None for _ in range(world_size)]
        dist.all_gather_object(all_results, results)

        if rank == 0:
            merged = []
            for part in all_results:
                merged.extend(part)
            merged.sort(key=lambda x: x["id"])

            with open(out_path, "w", encoding="utf-8") as f:
                json.dump(merged, f, ensure_ascii=False, indent=2)
            print(f"[OK] Wrote {out_path} ({len(merged)} samples)")

        dist.barrier()
        dist.destroy_process_group()
    else:
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
        print(f"[OK] Wrote {out_path} ({len(results)} samples)")


if __name__ == "__main__":
    main()
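
# A minimal downstream scoring sketch (not part of this script; exact string
# match is only a rough proxy for mathematical correctness):
#
#   import json
#   with open("predictions_checkpoint-300.json", encoding="utf-8") as f:
#       preds = json.load(f)
#   correct = sum(p["model_output"].strip() == p["ground_truth"].strip()
#                 for p in preds)
#   print(f"exact match: {correct}/{len(preds)} = {correct / len(preds):.3f}")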