"""Evaluate ChatTime on the Time-MQA TSQA benchmark.

Adapted from ``rats40k_adapter/eval_rats40k.py``.

Generation is sharded across ranks; each rank writes a JSONL shard, then
rank 0 merges the shards, computes per-group metrics with the canonical
Time-MQA evaluator and writes ``predictions.jsonl`` plus ``metrics.json``.

Over-long prompts are LEFT-truncated to ``--max_input_tokens`` so the series
and the trailing ask survive (ChatTime's context is only 4096 tokens, while
some context-enhanced forecasting questions are much longer).
"""

import argparse
import copy
import json
import os
from pathlib import Path

import torch
import torch.distributed as dist
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from tsqa_common import (
    atomic_write_json,
    balanced_limit,
    build_prompt,
    build_result_row,
    compute_tsqa_metrics,
    load_tsqa_records,
    resolve_data_file,
    write_jsonl,
)


def parse_args():
    """Parse the command-line options for a TSQA evaluation run."""
    parser = argparse.ArgumentParser(description="Evaluate ChatTime on Time-MQA TSQA.")
    parser.add_argument("--model_path", required=True)
    parser.add_argument("--adapter_path", default=None)
    parser.add_argument("--data_root", default="/mnt/share01/sqk/datasets/Time-MQA_TSQA/tmp")
    parser.add_argument("--eval_file", default="eval.jsonl")
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--max_eval_samples", type=int, default=None)
    parser.add_argument("--eval_batch_size", type=int, default=4)
    parser.add_argument("--max_input_tokens", type=int, default=3840)
    parser.add_argument("--max_new_tokens", type=int, default=256)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--repetition_penalty", type=float, default=1.0)
    parser.add_argument("--torch_dtype", choices=["auto", "bf16", "fp16", "fp32"], default="fp16")
    parser.add_argument("--allow_hf_download", action="store_true")
    return parser.parse_args()


def init_distributed():
    """Read the launcher environment, pin the device and init the process group.

    Reads ``WORLD_SIZE``, ``RANK`` and ``LOCAL_RANK`` from the environment
    (defaulting to a single process). The process group is only created when
    more than one process is running.

    Returns:
        ``(rank, local_rank, world_size)``.
    """
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    rank = int(os.environ.get("RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(local_rank)
    if world_size > 1 and not dist.is_initialized():
        # NCCL requires CUDA; fall back to Gloo so a multi-process CPU run can
        # still synchronise at the barriers instead of failing at init.
        dist.init_process_group(backend="nccl" if use_cuda else "gloo")
    return rank, local_rank, world_size


def dtype_from_arg(value):
    """Map a ``--torch_dtype`` choice to a value accepted by ``from_pretrained``.

    ``"auto"`` is passed through as the string ``"auto"``; any value other
    than ``"auto"``, ``"bf16"`` or ``"fp16"`` maps to ``torch.float32``.
    """
    if value == "auto":
        return "auto"
    if value == "bf16":
        return torch.bfloat16
    if value == "fp16":
        return torch.float16
    return torch.float32


def load_model_and_tokenizer(args, local_rank):
    """Load the tokenizer and causal LM, optionally wrapping a PEFT adapter.

    The tokenizer pads and truncates on the LEFT: left padding keeps every
    prompt flush against the generated tokens for batched decoding, and left
    truncation drops the start of over-long prompts rather than the end.

    Args:
        args: Parsed CLI namespace from ``parse_args``.
        local_rank: Index of the CUDA device this process should use.

    Returns:
        ``(model, tokenizer, device)``, where ``device`` is the string the
        model was placed on (``"cuda:<local_rank>"`` or ``"cpu"``).
    """
    local_files_only = not args.allow_hf_download
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_path,
        trust_remote_code=True,
        local_files_only=local_files_only,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    # Keep the series and the trailing ask when a prompt is too long.
    tokenizer.truncation_side = "left"

    device = f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu"
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        trust_remote_code=True,
        torch_dtype=dtype_from_arg(args.torch_dtype),
        low_cpu_mem_usage=True,
        device_map={"": device} if torch.cuda.is_available() else None,
        local_files_only=local_files_only,
    )
    if args.adapter_path:
        from peft import PeftModel

        model = PeftModel.from_pretrained(
            model,
            args.adapter_path,
            local_files_only=local_files_only,
        )
    model.eval()
    return model, tokenizer, device


def generate_responses(model, tokenizer, device, prompts, args):
    """Generate one response per prompt for a single batch.

    The input budget is clamped so that ``input + max_new_tokens`` fits within
    the model's ``max_position_embeddings`` (when the config exposes it).
    Greedy decoding is used when ``--temperature`` is 0; otherwise sampling
    with the given temperature / top-p / top-k.

    Args:
        model: Causal LM returned by ``load_model_and_tokenizer``.
        tokenizer: Matching tokenizer (left padding/truncation expected).
        device: Device string the inputs are moved to.
        prompts: List of prompt strings.
        args: Parsed CLI namespace (generation hyper-parameters).

    Returns:
        List of decoded, whitespace-stripped responses, aligned with
        ``prompts``.
    """
    model_context = getattr(model.config, "max_position_embeddings", None)
    max_input_tokens = args.max_input_tokens
    if model_context and max_input_tokens + args.max_new_tokens > model_context:
        max_input_tokens = max(1, model_context - args.max_new_tokens)

    inputs = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_input_tokens,
    )
    inputs = {key: value.to(device) for key, value in inputs.items()}

    do_sample = args.temperature > 0
    # Copy so per-run overrides never leak into the model's own config.
    generation_config = copy.deepcopy(model.generation_config)
    generation_config.do_sample = do_sample
    generation_config.pad_token_id = tokenizer.pad_token_id
    generation_config.eos_token_id = tokenizer.eos_token_id
    if do_sample:
        generation_config.temperature = args.temperature
        generation_config.top_p = args.top_p
        generation_config.top_k = args.top_k
    else:
        # Clear sampling knobs so greedy decoding does not trigger warnings
        # about unused sampling parameters.
        generation_config.temperature = None
        generation_config.top_p = None
        generation_config.top_k = None

    with torch.inference_mode():
        output = model.generate(
            **inputs,
            generation_config=generation_config,
            max_new_tokens=args.max_new_tokens,
            repetition_penalty=args.repetition_penalty,
        )
    # With left padding every row's prompt ends at the same column, so the
    # generated continuation starts right after the padded input width.
    new_tokens = output[:, inputs["input_ids"].shape[-1]:]
    return [
        response.strip()
        for response in tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
    ]


def _merge_shards(shard_dir, world_size):
    """Read every rank's JSONL shard and return all rows sorted by ``id``.

    Blank lines are skipped. Ids are compared as strings, so numeric ids sort
    lexicographically.
    """
    merged = []
    for shard_rank in range(world_size):
        path = shard_dir / f"predictions.rank{shard_rank}.jsonl"
        with open(path, "r", encoding="utf-8") as handle:
            for line in handle:
                line = line.strip()
                if line:
                    merged.append(json.loads(line))
    merged.sort(key=lambda row: str(row.get("id")))
    return merged


def main():
    """Run sharded generation, then merge shards and score on rank 0."""
    args = parse_args()
    rank, local_rank, world_size = init_distributed()

    output_dir = Path(args.output_dir)
    shard_dir = output_dir / "shards"
    shard_dir.mkdir(parents=True, exist_ok=True)

    data_path = resolve_data_file(args.data_root, args.eval_file)
    records = load_tsqa_records(data_path)
    records = balanced_limit(records, args.max_eval_samples)
    # Strided split: rank r handles records r, r + W, r + 2W, ...
    shard_records = records[rank::world_size]

    if rank == 0:
        print(f"Dataset: {data_path}")
        print(f"Total samples: {len(records)}")
        print(f"World size: {world_size}")
        print(f"Per-device eval batch size: {args.eval_batch_size}")
        print(f"Maximum global eval batch size: {args.eval_batch_size * world_size}")
        print(f"Output dir: {output_dir}")

    model, tokenizer, device = load_model_and_tokenizer(args, local_rank)

    results = []
    batch_size = max(1, args.eval_batch_size)
    batch_starts = range(0, len(shard_records), batch_size)
    for start in tqdm(
        batch_starts,
        total=len(batch_starts),
        desc=f"rank {rank}",
        disable=rank != 0,
    ):
        batch = shard_records[start:start + batch_size]
        prompts = [
            build_prompt(record["question"], record.get("time_series") or [])
            for record in batch
        ]
        responses = generate_responses(model, tokenizer, device, prompts, args)
        for record, response in zip(batch, responses):
            results.append(build_result_row(record, response))

    shard_path = shard_dir / f"predictions.rank{rank}.jsonl"
    write_jsonl(results, shard_path)

    # Every shard must be on disk before rank 0 starts reading them.
    if world_size > 1:
        dist.barrier()

    if rank == 0:
        merged = _merge_shards(shard_dir, world_size)
        predictions_path = output_dir / "predictions.jsonl"
        metrics_path = output_dir / "metrics.json"
        write_jsonl(merged, predictions_path)
        metrics = compute_tsqa_metrics(merged)
        atomic_write_json(metrics, metrics_path)
        print(json.dumps(metrics, ensure_ascii=False, indent=2))
        print(f"Saved predictions: {predictions_path}")
        print(f"Saved metrics: {metrics_path}")

    # Keep the other ranks alive until rank 0 has finished writing outputs.
    if world_size > 1:
        dist.barrier()
        dist.destroy_process_group()


if __name__ == "__main__":
    main()