# Source: ChatTime/tsqa_adapter/eval_tsqa.py
# Uploaded by a12354 ("Add files using upload-large-folder tool"), commit 8d2b389 (verified), 8.09 kB.
"""Evaluate ChatTime on the Time-MQA TSQA benchmark.
Adapted from ``rats40k_adapter/eval_rats40k.py``. Generation is sharded across
ranks; each rank writes a JSONL shard, then rank 0 merges, computes per-group
metrics with the canonical Time-MQA evaluator and writes ``predictions.jsonl``
plus ``metrics.json``.
Over-long prompts are LEFT-truncated to ``--max_input_tokens`` so the series and
the trailing ask survive (ChatTime's context is only 4096 tokens while some
context-enhanced forecasting questions are much longer).
"""
import argparse
import copy
import json
import os
from pathlib import Path
import torch
import torch.distributed as dist
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from tsqa_common import (
atomic_write_json,
balanced_limit,
build_prompt,
build_result_row,
compute_tsqa_metrics,
load_tsqa_records,
resolve_data_file,
write_jsonl,
)
def parse_args():
    """Parse the command-line options of the TSQA evaluation script.

    ``--model_path`` and ``--output_dir`` are mandatory; every other option
    falls back to the default listed below. Arguments are read from
    ``sys.argv`` by ``argparse``.

    Returns:
        argparse.Namespace: The parsed options, one attribute per flag.
    """
    cli = argparse.ArgumentParser(description="Evaluate ChatTime on Time-MQA TSQA.")

    # (flag, add_argument keyword arguments), registered in declaration order
    # so the generated --help output keeps its layout.
    option_specs = (
        ("--model_path", {"required": True}),
        ("--adapter_path", {"default": None}),
        ("--data_root", {"default": "/mnt/share01/sqk/datasets/Time-MQA_TSQA/tmp"}),
        ("--eval_file", {"default": "eval.jsonl"}),
        ("--output_dir", {"required": True}),
        ("--max_eval_samples", {"type": int, "default": None}),
        ("--eval_batch_size", {"type": int, "default": 4}),
        ("--max_input_tokens", {"type": int, "default": 3840}),
        ("--max_new_tokens", {"type": int, "default": 256}),
        ("--temperature", {"type": float, "default": 0.0}),
        ("--top_p", {"type": float, "default": 1.0}),
        ("--top_k", {"type": int, "default": 50}),
        ("--repetition_penalty", {"type": float, "default": 1.0}),
        (
            "--torch_dtype",
            {"choices": ["auto", "bf16", "fp16", "fp32"], "default": "fp16"},
        ),
        ("--allow_hf_download", {"action": "store_true"}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
def init_distributed():
    """Read the launcher environment and set up ``torch.distributed`` if needed.

    ``WORLD_SIZE``, ``RANK`` and ``LOCAL_RANK`` are read from the environment
    and default to a single-process run (``1``, ``0``, ``0``). When CUDA is
    available the current device is set to ``local_rank``. A process group is
    created only when more than one process takes part and none has been
    initialized yet.

    Returns:
        tuple[int, int, int]: ``(rank, local_rank, world_size)``.
    """
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    rank = int(os.environ.get("RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))

    has_cuda = torch.cuda.is_available()
    if has_cuda:
        torch.cuda.set_device(local_rank)

    if world_size > 1 and not dist.is_initialized():
        # NCCL needs CUDA devices. The rest of the script supports CPU-only
        # hosts, so use gloo there instead of failing at start-up.
        dist.init_process_group(backend="nccl" if has_cuda else "gloo")

    return rank, local_rank, world_size
def dtype_from_arg(value):
    """Translate a ``--torch_dtype`` choice into a value for ``torch_dtype``.

    Args:
        value: One of ``"auto"``, ``"bf16"``, ``"fp16"`` or ``"fp32"``.

    Returns:
        The string ``"auto"`` unchanged, ``torch.bfloat16`` for ``"bf16"``,
        ``torch.float16`` for ``"fp16"``, and ``torch.float32`` for anything
        else (including ``"fp32"``).
    """
    known = (
        ("auto", "auto"),
        ("bf16", torch.bfloat16),
        ("fp16", torch.float16),
    )
    for name, dtype in known:
        if value == name:
            return dtype
    # "fp32" and any unrecognized value share the full-precision fallback.
    return torch.float32
def load_model_and_tokenizer(args, local_rank):
    """Load the tokenizer and causal LM (plus optional PEFT adapter) for eval.

    Args:
        args: Parsed CLI options; reads ``model_path``, ``adapter_path``,
            ``torch_dtype`` and ``allow_hf_download``.
        local_rank: Index of the CUDA device this process should use.

    Returns:
        tuple: ``(model, tokenizer, device)`` where ``device`` is
        ``"cuda:<local_rank>"`` when CUDA is available and ``"cpu"`` otherwise.
        The model is returned in eval mode.
    """
    # Unless downloads are explicitly allowed, everything must resolve locally.
    offline = not args.allow_hf_download
    shared_kwargs = {"trust_remote_code": True, "local_files_only": offline}

    tokenizer = AutoTokenizer.from_pretrained(args.model_path, **shared_kwargs)
    if tokenizer.pad_token is None:
        # Batched generation needs a pad token; reuse EOS when none is defined.
        tokenizer.pad_token = tokenizer.eos_token
    # Pad on the left so every prompt in a batch ends at the same position.
    tokenizer.padding_side = "left"
    # Keep the series and the trailing ask when a prompt is too long.
    tokenizer.truncation_side = "left"

    use_cuda = torch.cuda.is_available()
    device = f"cuda:{local_rank}" if use_cuda else "cpu"
    placement = {"": device} if use_cuda else None

    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=dtype_from_arg(args.torch_dtype),
        low_cpu_mem_usage=True,
        device_map=placement,
        **shared_kwargs,
    )

    if args.adapter_path:
        # Imported lazily so peft is only required when an adapter is used.
        from peft import PeftModel

        model = PeftModel.from_pretrained(
            model,
            args.adapter_path,
            local_files_only=offline,
        )

    model.eval()
    return model, tokenizer, device
def generate_responses(model, tokenizer, device, prompts, args):
    """Generate one completion per prompt and return the decoded new text.

    Args:
        model: Causal LM exposing ``config``, ``generation_config`` and
            ``generate``.
        tokenizer: Tokenizer used to encode the prompts and decode the output.
        device: Device the encoded inputs are moved to.
        prompts: List of prompt strings forming one batch.
        args: Parsed CLI options; reads ``max_input_tokens``,
            ``max_new_tokens``, ``temperature``, ``top_p``, ``top_k`` and
            ``repetition_penalty``.

    Returns:
        list[str]: Stripped completions in the same order as ``prompts``.
        An empty ``prompts`` list yields ``[]`` without touching the model.
    """
    if not prompts:
        # The tokenizer cannot build a padded tensor from an empty batch.
        return []

    # Shrink the prompt budget when prompt + completion would exceed the
    # context size reported by the model config (if it reports one).
    model_context = getattr(model.config, "max_position_embeddings", None)
    max_input_tokens = args.max_input_tokens
    if model_context and max_input_tokens + args.max_new_tokens > model_context:
        max_input_tokens = max(1, model_context - args.max_new_tokens)

    # Padding/truncation sides are whatever the tokenizer was configured with.
    inputs = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_input_tokens,
    )
    inputs = {key: value.to(device) for key, value in inputs.items()}

    # Work on a copy so the model's own generation config is never mutated.
    do_sample = args.temperature > 0
    generation_config = copy.deepcopy(model.generation_config)
    generation_config.do_sample = do_sample
    generation_config.pad_token_id = tokenizer.pad_token_id
    generation_config.eos_token_id = tokenizer.eos_token_id
    if do_sample:
        generation_config.temperature = args.temperature
        generation_config.top_p = args.top_p
        generation_config.top_k = args.top_k
    else:
        # Greedy decoding: clear the sampling-only knobs.
        generation_config.temperature = None
        generation_config.top_p = None
        generation_config.top_k = None

    with torch.inference_mode():
        output = model.generate(
            **inputs,
            generation_config=generation_config,
            max_new_tokens=args.max_new_tokens,
            repetition_penalty=args.repetition_penalty,
        )

    # Drop the prompt tokens. NOTE(review): this assumes the tokenizer pads on
    # the left so every prompt ends at the same column -- confirm at the caller.
    prompt_length = inputs["input_ids"].shape[-1]
    new_tokens = output[:, prompt_length:]
    return [
        response.strip()
        for response in tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
    ]
def _read_shard_rows(shard_dir, world_size):
    """Load the rows of every per-rank prediction shard, in rank order."""
    rows = []
    for shard_rank in range(world_size):
        shard_file = shard_dir / f"predictions.rank{shard_rank}.jsonl"
        with open(shard_file, "r", encoding="utf-8") as handle:
            # Blank lines are skipped so stray newlines never reach json.loads.
            rows.extend(json.loads(line) for line in map(str.strip, handle) if line)
    return rows


def main():
    """Run the sharded evaluation and, on rank 0, merge and score the results.

    Every rank loads the evaluation records, takes every ``world_size``-th
    record starting at its own rank, generates responses in batches and writes
    them to ``<output_dir>/shards/predictions.rank<rank>.jsonl``. After a
    barrier, rank 0 reads all shards, sorts the rows by the string form of
    their ``id``, writes ``predictions.jsonl`` and ``metrics.json`` to the
    output directory and prints the metrics. In multi-process runs the process
    group is destroyed at the end.
    """
    args = parse_args()
    rank, local_rank, world_size = init_distributed()

    output_dir = Path(args.output_dir)
    shard_dir = output_dir / "shards"
    shard_dir.mkdir(parents=True, exist_ok=True)

    data_path = resolve_data_file(args.data_root, args.eval_file)
    records = load_tsqa_records(data_path)
    records = balanced_limit(records, args.max_eval_samples)
    shard_records = records[rank::world_size]

    # Clamp before logging so the reported sizes match what the loop uses.
    batch_size = max(1, args.eval_batch_size)

    if rank == 0:
        print(f"Dataset: {data_path}")
        print(f"Total samples: {len(records)}")
        print(f"World size: {world_size}")
        print(f"Per-device eval batch size: {batch_size}")
        print(f"Maximum global eval batch size: {batch_size * world_size}")
        print(f"Output dir: {output_dir}")

    model, tokenizer, device = load_model_and_tokenizer(args, local_rank)

    results = []
    batch_starts = range(0, len(shard_records), batch_size)
    for start in tqdm(
        batch_starts,
        total=(len(shard_records) + batch_size - 1) // batch_size,
        desc=f"rank {rank}",
        disable=rank != 0,  # only rank 0 shows a progress bar
    ):
        batch = shard_records[start:start + batch_size]
        prompts = [
            build_prompt(record["question"], record.get("time_series") or [])
            for record in batch
        ]
        responses = generate_responses(model, tokenizer, device, prompts, args)
        results.extend(
            build_result_row(record, response)
            for record, response in zip(batch, responses)
        )

    shard_path = shard_dir / f"predictions.rank{rank}.jsonl"
    write_jsonl(results, shard_path)

    if world_size > 1:
        # Every shard must be on disk before rank 0 starts merging.
        dist.barrier()

    if rank == 0:
        merged = _read_shard_rows(shard_dir, world_size)
        merged.sort(key=lambda row: str(row.get("id")))

        predictions_path = output_dir / "predictions.jsonl"
        metrics_path = output_dir / "metrics.json"
        write_jsonl(merged, predictions_path)
        metrics = compute_tsqa_metrics(merged)
        atomic_write_json(metrics, metrics_path)

        print(json.dumps(metrics, ensure_ascii=False, indent=2))
        print(f"Saved predictions: {predictions_path}")
        print(f"Saved metrics: {metrics_path}")

    if world_size > 1:
        # Keep the other ranks alive until rank 0 has finished writing.
        dist.barrier()
        dist.destroy_process_group()
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()