| """Time-MQA TSQA dataset adapter for ChatTime. |
| |
| This mirrors ``rats40k_adapter`` but targets the multi-task Time-MQA TSQA |
| benchmark (``/mnt/share01/sqk/datasets/Time-MQA_TSQA/tmp``), which contains four |
| source groups -- anomaly_detection, classification, forecasting and open_ended |
| -- with free-text answers. |
| |
| Design choices: |
| |
| * ChatTime has no chat template, so prompts use ChatTime's native analysis |
| template (``utils.prompt.getPrompt``): the natural-language ``question`` goes |
| into ``### Instruction`` (kept verbatim, including any embedded numbers so |
| forecasting keeps its absolute scale) and the ChatTime-serialised series goes |
| into ``### Input`` so the model still receives its native time-series tokens. |
| * The free-text ``answer`` is the ``### Response`` target. |
| * Some context-enhanced forecasting questions are far longer than ChatTime's |
| 4096-token context. Callers left-truncate the prompt (keeping the series, the |
| trailing ask and the full response) rather than failing. |
| * Evaluation reuses the canonical Time-MQA evaluator |
| (``MQA/data_utils.compute_group_metrics``) so ChatTime numbers are directly |
| comparable to the other TSQA baselines. |
| """ |
|
|
| import json |
| import os |
| import sys |
| from collections import defaultdict |
| from pathlib import Path |
|
|
| import numpy as np |
|
|
|
|
# Two levels up from this file's resolved location (the parent of the directory
# that contains this module); presumably the ChatTime repository root -- TODO
# confirm. It is put on sys.path so the ``utils`` imports that follow resolve.
CHAT_TIME_DIR = Path(__file__).resolve().parents[1]
# Directory that provides ``data_utils`` for ``compute_tsqa_metrics``; can be
# overridden through the MQA_DIR environment variable.
MQA_DIR = Path(os.environ.get("MQA_DIR", "/mnt/share01/sqk/MQA"))
# Insert at the front, and only once, so repeated imports do not grow sys.path.
if str(CHAT_TIME_DIR) not in sys.path:
    sys.path.insert(0, str(CHAT_TIME_DIR))
|
|
| from utils.prompt import getPrompt |
| from utils.tools import Discretizer, Serializer |
|
|
|
|
# Default directory holding the Time-MQA TSQA split files.
DEFAULT_DATA_ROOT = "/mnt/share01/sqk/datasets/Time-MQA_TSQA/tmp"
# Known source groups, in the order ``balanced_limit`` visits them; groups not
# listed here are visited afterwards in sorted order.
SOURCE_ORDER = ("anomaly_detection", "classification", "forecasting", "open_ended")
|
|
| |
| |
# Record keys that ``build_result_row`` copies (stringified) into every
# prediction row, alongside the question, answer and prediction.
RESULT_FIELDS = (
    "id",
    "figure_path",
    "application_domain",
    "task_type",
    "source_type",
    "question_format",
)
|
|
|
|
def resolve_data_file(data_root, data_file):
    """Return the location of a TSQA split file as a string.

    An absolute ``data_file`` is used as given; a relative one is looked up
    underneath ``data_root``.

    Raises:
        FileNotFoundError: if the resolved location is not an existing file.
    """
    candidate = Path(data_file)
    resolved = candidate if candidate.is_absolute() else Path(data_root) / candidate
    if resolved.is_file():
        return str(resolved)
    raise FileNotFoundError(f"TSQA JSONL not found: {resolved}")
|
|
|
|
def load_tsqa_records(path):
    """Load TSQA JSONL records that carry both a question and an answer.

    Blank lines are skipped. A row is kept only when it is a JSON object whose
    ``question`` and ``answer`` are both non-blank once stringified. Numeric
    and boolean values (including ``0`` and ``False``) count as present, so a
    legitimate answer of ``0`` is not discarded.

    Args:
        path: location of the JSONL file; it is read as UTF-8.

    Returns:
        The kept records, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON. The
            message is prefixed with the file path and 1-based line number.
    """

    def _has_text(value):
        # bool/int/float are always real content; ``0 or ""`` would wrongly
        # classify them as missing.
        if isinstance(value, (bool, int, float)):
            return True
        return bool(str(value or "").strip())

    records = []
    with open(path, "r", encoding="utf-8") as handle:
        for line_number, raw_line in enumerate(handle, start=1):
            text = raw_line.strip()
            if not text:
                continue
            try:
                record = json.loads(text)
            except json.JSONDecodeError as exc:
                # Same exception type, but now pointing at the offending row.
                raise json.JSONDecodeError(
                    f"{path}:{line_number}: {exc.msg}", exc.doc, exc.pos
                ) from exc
            if not isinstance(record, dict):
                # A bare list/number/string row has no question or answer.
                continue
            if _has_text(record.get("question")) and _has_text(record.get("answer")):
                records.append(record)
    return records
|
|
|
|
def balanced_limit(records, max_samples):
    """Pick at most ``max_samples`` records, cycling evenly over source groups.

    ``records`` is returned untouched when ``max_samples`` is falsy or
    non-positive, or when there are no more records than requested. Otherwise
    records are grouped by ``source_type`` (``"unknown"`` when missing) and
    taken one per group per pass: groups named in ``SOURCE_ORDER`` first, the
    remaining groups in sorted order.
    """
    no_cap = not max_samples or max_samples <= 0
    if no_cap or len(records) <= max_samples:
        return records

    by_source = defaultdict(list)
    for item in records:
        by_source[str(item.get("source_type") or "unknown")].append(item)

    ordered = [source for source in SOURCE_ORDER if by_source.get(source)]
    ordered += sorted(source for source in by_source if source not in ordered)

    picked = []
    deepest = max(len(by_source[source]) for source in ordered)
    for level in range(deepest):
        # One pass per level; stop before a pass once the quota is met.
        if len(picked) >= max_samples:
            break
        for source in ordered:
            bucket = by_source[source]
            if level >= len(bucket):
                continue
            picked.append(bucket[level])
            if len(picked) == max_samples:
                break
    return picked
|
|
|
|
def _to_float_array(series):
    """Coerce ``series`` into a flat float array with no NaN/inf entries.

    Args:
        series: array-like of numbers (any nesting is flattened), or ``None``
            for a missing series.

    Returns:
        A 1-D float array. ``None`` and empty input give an empty array.
        Non-finite entries are replaced by the median of the finite entries,
        or by ``0.0`` when no entry is finite.
    """
    if series is None:
        # np.asarray(None, dtype=float) is a 0-d NaN; flattening and imputing
        # it would fabricate a one-point series [0.0] for a missing series.
        return np.asarray([], dtype=float)

    arr = np.asarray(series, dtype=float).reshape(-1)
    finite = np.isfinite(arr)
    if finite.all():
        # Also covers the empty array, for which ``all()`` is True.
        return arr
    fill = float(np.median(arr[finite])) if finite.any() else 0.0
    return np.where(finite, arr, fill)
|
|
|
|
def serialize_series(series):
    """Turn a series into ChatTime's serialised time-token form.

    The series is first cleaned by ``_to_float_array``; an empty result yields
    ``""``. Otherwise the values are passed through a fresh ``Discretizer``
    and the outcome through a fresh ``Serializer``.
    """
    values = _to_float_array(series)
    if values.size == 0:
        return ""
    discretizer, serializer = Discretizer(), Serializer()
    discretized = discretizer.discretize(values)
    return serializer.serialize(discretized)
|
|
|
|
def build_prompt(question, series, response=None):
    """Build a ChatTime prompt from a question, a series and an optional answer.

    The stripped question is passed as the instruction. When the series
    serialises to a non-empty string, the ``"analysis"`` template is requested
    with that string as input; otherwise the ``"general"`` template is
    requested with an empty input. A ``None`` response becomes ``""``.
    """
    instruction = str(question or "").strip()
    target = str(response) if response is not None else ""
    tokens = serialize_series(series)

    if not tokens:
        # No usable series: fall back to the series-free template.
        return getPrompt(
            flag="general",
            instruction=instruction,
            input="",
            response=target,
        )

    return getPrompt(
        flag="analysis",
        instruction=instruction,
        input=tokens,
        response=target,
    )
|
|
|
|
def build_result_row(record, prediction):
    """Assemble a prediction row in the Time-MQA evaluator schema.

    Args:
        record: source TSQA record (a mapping supporting ``.get``).
        prediction: model output, or ``None`` when nothing was generated.

    Returns:
        A dict of strings holding every ``RESULT_FIELDS`` key plus
        ``question``, ``answer`` and the stripped ``prediction``. Missing or
        ``None`` values become ``""``.
    """

    def _text(value):
        # Only None means "absent"; ``value or ""`` would also blank valid
        # falsy values such as an id of 0.
        return "" if value is None else str(value)

    row = {field: _text(record.get(field)) for field in RESULT_FIELDS}
    row["question"] = _text(record.get("question"))
    row["answer"] = _text(record.get("answer"))
    row["prediction"] = _text(prediction).strip()
    return row
|
|
|
|
def compute_tsqa_metrics(rows):
    """Score prediction rows with the evaluator found under ``MQA_DIR``.

    ``MQA_DIR`` is put at the front of ``sys.path`` if absent, then
    ``data_utils`` is imported from it. The result holds the per-group metrics
    (``by_group``), the text metrics over all rows (``text_overall``), the row
    count (``num_samples``) and the number of rows per ``source_type``
    (``counts_by_group``, with ``"unknown"`` for rows lacking one).
    """
    mqa_path = str(MQA_DIR)
    if mqa_path not in sys.path:
        sys.path.insert(0, mqa_path)
    from data_utils import compute_group_metrics, compute_text_metrics

    # Materialise once so an iterator can be consumed by every step below.
    materialized = list(rows)
    by_group = compute_group_metrics(materialized)
    text_overall = compute_text_metrics(materialized)

    group_sizes = {}
    for entry in materialized:
        label = str(entry.get("source_type") or "unknown")
        group_sizes[label] = group_sizes.get(label, 0) + 1

    return {
        "by_group": by_group,
        "text_overall": text_overall,
        "num_samples": len(materialized),
        "counts_by_group": group_sizes,
    }
|
|
|
|
def atomic_write_json(obj, path):
    """Write ``obj`` as indented UTF-8 JSON to ``path`` via a temp file.

    Parent directories are created as needed. The data is written to a
    PID-suffixed sibling file and then moved over ``path`` with
    ``os.replace``, so ``path`` is only replaced after a complete write.

    Args:
        obj: JSON-serialisable object.
        path: destination file (string or ``Path``).

    Raises:
        TypeError, ValueError, OSError: propagated from serialisation or file
            operations; the temp file is removed before re-raising.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(path.suffix + f".tmp.{os.getpid()}")
    try:
        with open(tmp, "w", encoding="utf-8") as handle:
            json.dump(obj, handle, indent=2, ensure_ascii=False)
        os.replace(tmp, path)
    except BaseException:
        # Do not leave a half-written temp file next to the target.
        try:
            os.unlink(tmp)
        except OSError:
            pass  # best-effort cleanup; the original error is what matters
        raise
|
|
|
|
def write_jsonl(rows, path):
    """Write ``rows`` as UTF-8 JSON Lines to ``path`` via a temp file.

    Parent directories are created as needed. Each row is dumped on its own
    line (non-ASCII kept as-is) into a PID-suffixed sibling file, which is
    then moved over ``path`` with ``os.replace``.

    Args:
        rows: iterable of JSON-serialisable objects (a generator is fine).
        path: destination file (string or ``Path``).

    Raises:
        TypeError, ValueError, OSError: propagated from serialisation or file
            operations; the temp file is removed before re-raising.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(path.suffix + f".tmp.{os.getpid()}")
    try:
        with open(tmp, "w", encoding="utf-8") as handle:
            handle.writelines(
                json.dumps(row, ensure_ascii=False) + "\n" for row in rows
            )
        os.replace(tmp, path)
    except BaseException:
        # Do not leave a partially written temp file next to the target.
        try:
            os.unlink(tmp)
        except OSError:
            pass  # best-effort cleanup; the original error is what matters
        raise
|
|