| """Evaluate ChatTime on the Time-MQA TSQA benchmark. |
| |
| Adapted from ``rats40k_adapter/eval_rats40k.py``. Generation is sharded across |
| ranks; each rank writes a JSONL shard, then rank 0 merges, computes per-group |
| metrics with the canonical Time-MQA evaluator and writes ``predictions.jsonl`` |
| plus ``metrics.json``. |
| |
| Over-long prompts are LEFT-truncated to ``--max_input_tokens`` so the series and |
| the trailing ask survive (ChatTime's context is only 4096 tokens while some |
| context-enhanced forecasting questions are much longer). |
| """ |
|
|
| import argparse |
| import copy |
| import json |
| import os |
| from pathlib import Path |
|
|
| import torch |
| import torch.distributed as dist |
| from tqdm import tqdm |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
| from tsqa_common import ( |
| atomic_write_json, |
| balanced_limit, |
| build_prompt, |
| build_result_row, |
| compute_tsqa_metrics, |
| load_tsqa_records, |
| resolve_data_file, |
| write_jsonl, |
| ) |
|
|
|
|
def parse_args():
    """Define the command-line interface of this evaluation script and parse it.

    Returns:
        argparse.Namespace holding model/adapter/data locations, the output
        directory, batching and length limits, decoding settings, the model
        dtype choice and whether Hugging Face downloads are allowed.
    """
    cli = argparse.ArgumentParser(description="Evaluate ChatTime on Time-MQA TSQA.")

    # Locations of the model, optional adapter, dataset and outputs.
    cli.add_argument("--model_path", required=True)
    cli.add_argument("--adapter_path", default=None)
    cli.add_argument("--data_root", default="/mnt/share01/sqk/datasets/Time-MQA_TSQA/tmp")
    cli.add_argument("--eval_file", default="eval.jsonl")
    cli.add_argument("--output_dir", required=True)

    # Numeric knobs: (flag, converter, default), registered in this exact order.
    numeric_options = (
        ("--max_eval_samples", int, None),
        ("--eval_batch_size", int, 4),
        ("--max_input_tokens", int, 3840),
        ("--max_new_tokens", int, 256),
        ("--temperature", float, 0.0),
        ("--top_p", float, 1.0),
        ("--top_k", int, 50),
        ("--repetition_penalty", float, 1.0),
    )
    for flag, converter, fallback in numeric_options:
        cli.add_argument(flag, type=converter, default=fallback)

    cli.add_argument("--torch_dtype", choices=["auto", "bf16", "fp16", "fp32"], default="fp16")
    cli.add_argument("--allow_hf_download", action="store_true")
    return cli.parse_args()
|
|
|
|
def init_distributed():
    """Read launcher environment variables, pin the CUDA device and join the process group.

    ``WORLD_SIZE``, ``RANK`` and ``LOCAL_RANK`` default to ``1``, ``0`` and ``0``
    when unset, so a plain single-process run needs no setup. The process group
    is only initialised for ``WORLD_SIZE > 1`` and only if it is not already up.

    Returns:
        Tuple ``(rank, local_rank, world_size)``.
    """
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    rank = int(os.environ.get("RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        torch.cuda.set_device(local_rank)
    if world_size > 1 and not dist.is_initialized():
        # NCCL only works with CUDA devices; fall back to gloo so a CPU-only
        # multi-process run can still synchronise instead of crashing.
        dist.init_process_group(backend="nccl" if has_cuda else "gloo")
    return rank, local_rank, world_size
|
|
|
|
def dtype_from_arg(value):
    """Map a ``--torch_dtype`` choice to the ``torch_dtype`` passed to ``from_pretrained``.

    ``"auto"`` passes through as the string ``"auto"``, ``"bf16"`` and ``"fp16"``
    map to their half-precision dtypes, and any other value (including
    ``"fp32"``) yields ``torch.float32``.
    """
    lookup = {
        "auto": "auto",
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
    }
    return lookup.get(value, torch.float32)
|
|
|
|
def load_model_and_tokenizer(args, local_rank):
    """Load the tokenizer and causal LM, optionally wrapping the LM in a PEFT adapter.

    The tokenizer falls back to its EOS token for padding when no pad token is
    defined, and both padding and truncation are done on the left. Left
    truncation means over-long prompts lose their beginning rather than their end.

    Args:
        args: Parsed CLI namespace. Reads ``model_path``, ``adapter_path``,
            ``torch_dtype`` and ``allow_hf_download``.
        local_rank: CUDA device index for this process (ignored without CUDA).

    Returns:
        Tuple ``(model, tokenizer, device)``. The model is in eval mode and
        ``device`` is ``"cuda:<local_rank>"`` or ``"cpu"``.
    """
    offline_only = not args.allow_hf_download
    cuda_ok = torch.cuda.is_available()
    target = f"cuda:{local_rank}" if cuda_ok else "cpu"

    tok = AutoTokenizer.from_pretrained(
        args.model_path,
        trust_remote_code=True,
        local_files_only=offline_only,
    )
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.padding_side = "left"
    tok.truncation_side = "left"

    placement = {"": target} if cuda_ok else None
    lm = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        trust_remote_code=True,
        torch_dtype=dtype_from_arg(args.torch_dtype),
        low_cpu_mem_usage=True,
        device_map=placement,
        local_files_only=offline_only,
    )

    if args.adapter_path:
        # Imported lazily so peft is only required when an adapter is requested.
        from peft import PeftModel

        lm = PeftModel.from_pretrained(
            lm,
            args.adapter_path,
            local_files_only=offline_only,
        )

    lm.eval()
    return lm, tok, target
|
|
|
|
def _effective_max_input_tokens(model_context, max_input_tokens, max_new_tokens):
    """Clamp the prompt-token budget so prompt plus generation fits in the model context.

    Args:
        model_context: The model's maximum sequence length, or ``None``/``0``
            when unknown. No clamping is done in that case.
        max_input_tokens: Requested prompt budget (``--max_input_tokens``).
        max_new_tokens: Generation budget (``--max_new_tokens``).

    Returns:
        ``max_input_tokens`` if it already fits. Otherwise
        ``model_context - max_new_tokens``, never below 1.
    """
    if model_context and max_input_tokens + max_new_tokens > model_context:
        return max(1, model_context - max_new_tokens)
    return max_input_tokens


def _build_generation_config(model, tokenizer, args):
    """Return a copy of ``model.generation_config`` with this run's decoding settings.

    Sampling is enabled only when ``args.temperature > 0``. Otherwise
    temperature, top_p and top_k are cleared to ``None``, presumably so
    transformers does not complain about sampling knobs under greedy decoding
    (confirm against the installed version).
    """
    config = copy.deepcopy(model.generation_config)
    sampling = args.temperature > 0
    config.do_sample = sampling
    config.pad_token_id = tokenizer.pad_token_id
    config.eos_token_id = tokenizer.eos_token_id
    config.temperature = args.temperature if sampling else None
    config.top_p = args.top_p if sampling else None
    config.top_k = args.top_k if sampling else None
    return config


def generate_responses(model, tokenizer, device, prompts, args):
    """Generate one completion per prompt as a single padded batch.

    Prompts are tokenized with padding and truncation using the tokenizer's
    configured sides. The truncation length is ``args.max_input_tokens``,
    clamped so that prompt plus ``args.max_new_tokens`` fits within the model's
    ``max_position_embeddings`` when the config exposes it.

    Args:
        model: Causal LM (optionally PEFT-wrapped) exposing ``config``,
            ``generation_config`` and ``generate``.
        tokenizer: Tokenizer with a pad token set.
        device: Device the input tensors are moved to.
        prompts: List of prompt strings. May be empty.
        args: Parsed CLI namespace with the length and decoding settings.

    Returns:
        List of stripped completions aligned with ``prompts``, with prompt
        tokens and special tokens removed. An empty ``prompts`` returns ``[]``
        without touching the tokenizer or the model.
    """
    if not prompts:
        # Nothing to do; an empty batch would otherwise reach the tokenizer/generate.
        return []

    max_input_tokens = _effective_max_input_tokens(
        getattr(model.config, "max_position_embeddings", None),
        args.max_input_tokens,
        args.max_new_tokens,
    )
    encoded = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_input_tokens,
    )
    inputs = {key: value.to(device) for key, value in encoded.items()}
    generation_config = _build_generation_config(model, tokenizer, args)

    with torch.inference_mode():
        output = model.generate(
            **inputs,
            generation_config=generation_config,
            max_new_tokens=args.max_new_tokens,
            repetition_penalty=args.repetition_penalty,
        )
    # The returned sequences start with the (padded) prompt columns; drop them
    # so only newly generated tokens are decoded.
    prompt_width = inputs["input_ids"].shape[-1]
    new_tokens = output[:, prompt_width:]
    return [
        text.strip()
        for text in tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
    ]
|
|
|
|
def _generate_shard(model, tokenizer, device, shard_records, args, rank):
    """Run generation over this rank's records in mini-batches and return result rows."""
    results = []
    batch_size = max(1, args.eval_batch_size)
    for start in tqdm(
        range(0, len(shard_records), batch_size),
        total=(len(shard_records) + batch_size - 1) // batch_size,
        desc=f"rank {rank}",
        disable=rank != 0,
    ):
        batch = shard_records[start:start + batch_size]
        prompts = [
            build_prompt(record["question"], record.get("time_series") or [])
            for record in batch
        ]
        responses = generate_responses(model, tokenizer, device, prompts, args)
        results.extend(
            build_result_row(record, response)
            for record, response in zip(batch, responses)
        )
    return results


def _read_jsonl(path):
    """Return the JSON objects stored one per non-blank line in *path*."""
    with open(path, "r", encoding="utf-8") as handle:
        return [json.loads(line) for line in handle if line.strip()]


def _merge_shards(shard_dir, world_size, total):
    """Reassemble all per-rank shard files into the original dataset order.

    Rank ``r`` processed ``records[r::world_size]``, so its ``j``-th row belongs
    at dataset index ``r + j * world_size``. Restoring that position replaces
    the previous sort on ``str(id)``. That sort put numeric ids in lexicographic
    order ("10" before "2"), and with missing or duplicate ids it made the
    output order depend on ``world_size``.

    Raises:
        RuntimeError: If a shard does not contain exactly the number of rows its
            rank was assigned, e.g. a truncated or stale file.
    """
    merged = [None] * total
    for shard_rank in range(world_size):
        path = shard_dir / f"predictions.rank{shard_rank}.jsonl"
        rows = _read_jsonl(path)
        expected = len(range(shard_rank, total, world_size))
        if len(rows) != expected:
            raise RuntimeError(
                f"Shard {path} has {len(rows)} rows, expected {expected}."
            )
        merged[shard_rank::world_size] = rows
    return merged


def main():
    """Evaluate the model on TSQA across ranks, then merge and score on rank 0.

    Each rank generates answers for ``records[rank::world_size]`` and writes
    ``shards/predictions.rank<rank>.jsonl``. After a barrier, rank 0 merges the
    shards back into dataset order and writes ``predictions.jsonl`` and
    ``metrics.json`` under ``--output_dir``.
    """
    args = parse_args()
    rank, local_rank, world_size = init_distributed()
    output_dir = Path(args.output_dir)
    shard_dir = output_dir / "shards"
    shard_dir.mkdir(parents=True, exist_ok=True)

    data_path = resolve_data_file(args.data_root, args.eval_file)
    records = load_tsqa_records(data_path)
    records = balanced_limit(records, args.max_eval_samples)
    shard_records = records[rank::world_size]

    if rank == 0:
        print(f"Dataset: {data_path}")
        print(f"Total samples: {len(records)}")
        print(f"World size: {world_size}")
        print(f"Per-device eval batch size: {args.eval_batch_size}")
        print(f"Maximum global eval batch size: {args.eval_batch_size * world_size}")
        print(f"Output dir: {output_dir}")

    model, tokenizer, device = load_model_and_tokenizer(args, local_rank)
    results = _generate_shard(model, tokenizer, device, shard_records, args, rank)

    shard_path = shard_dir / f"predictions.rank{rank}.jsonl"
    write_jsonl(results, shard_path)

    # Every shard must be on disk before rank 0 starts merging.
    if world_size > 1:
        dist.barrier()

    if rank == 0:
        merged = _merge_shards(shard_dir, world_size, len(records))
        predictions_path = output_dir / "predictions.jsonl"
        metrics_path = output_dir / "metrics.json"
        write_jsonl(merged, predictions_path)
        metrics = compute_tsqa_metrics(merged)
        atomic_write_json(metrics, metrics_path)
        print(json.dumps(metrics, ensure_ascii=False, indent=2))
        print(f"Saved predictions: {predictions_path}")
        print(f"Saved metrics: {metrics_path}")

    # Keep the other ranks alive until rank 0 has finished writing outputs.
    if world_size > 1:
        dist.barrier()
        dist.destroy_process_group()
|
|
|
|
if __name__ == "__main__":
    # Script entry point. For multi-GPU runs, launch through a tool that sets
    # RANK / LOCAL_RANK / WORLD_SIZE (e.g. torchrun); each process runs main().
    main()
|
|