"""Time-MQA TSQA dataset adapter for ChatTime. This mirrors ``rats40k_adapter`` but targets the multi-task Time-MQA TSQA benchmark (``/mnt/share01/sqk/datasets/Time-MQA_TSQA/tmp``), which contains four source groups -- anomaly_detection, classification, forecasting and open_ended -- with free-text answers. Design choices: * ChatTime has no chat template, so prompts use ChatTime's native analysis template (``utils.prompt.getPrompt``): the natural-language ``question`` goes into ``### Instruction`` (kept verbatim, including any embedded numbers so forecasting keeps its absolute scale) and the ChatTime-serialised series goes into ``### Input`` so the model still receives its native time-series tokens. * The free-text ``answer`` is the ``### Response`` target. * Some context-enhanced forecasting questions are far longer than ChatTime's 4096-token context. Callers left-truncate the prompt (keeping the series, the trailing ask and the full response) rather than failing. * Evaluation reuses the canonical Time-MQA evaluator (``MQA/data_utils.compute_group_metrics``) so ChatTime numbers are directly comparable to the other TSQA baselines. """ import json import os import sys from collections import defaultdict from pathlib import Path import numpy as np CHAT_TIME_DIR = Path(__file__).resolve().parents[1] MQA_DIR = Path(os.environ.get("MQA_DIR", "/mnt/share01/sqk/MQA")) if str(CHAT_TIME_DIR) not in sys.path: sys.path.insert(0, str(CHAT_TIME_DIR)) from utils.prompt import getPrompt # noqa: E402 from utils.tools import Discretizer, Serializer # noqa: E402 DEFAULT_DATA_ROOT = "/mnt/share01/sqk/datasets/Time-MQA_TSQA/tmp" SOURCE_ORDER = ("anomaly_detection", "classification", "forecasting", "open_ended") # Fields carried into the saved prediction rows so the Time-MQA evaluator and # manual inspection have everything they need. 
RESULT_FIELDS = (
    "id",
    "figure_path",
    "application_domain",
    "task_type",
    "source_type",
    "question_format",
)


def _as_text(value):
    """Return ``value`` as a string, mapping only ``None`` to ``""``.

    Unlike ``str(value or "")`` this keeps legitimate falsy values such as an
    integer id or answer of ``0`` instead of silently blanking them.
    """
    return "" if value is None else str(value)


def resolve_data_file(data_root, data_file):
    """Resolve a split file: absolute override, or relative to ``data_root``.

    Args:
        data_root: Directory that a relative ``data_file`` is joined onto.
        data_file: Absolute path, or a path relative to ``data_root``.

    Returns:
        The resolved path as a string.

    Raises:
        FileNotFoundError: If the resolved path is not an existing file.
    """
    path = Path(data_file)
    if not path.is_absolute():
        path = Path(data_root) / path
    if not path.is_file():
        raise FileNotFoundError(f"TSQA JSONL not found: {path}")
    return str(path)


def load_tsqa_records(path):
    """Load TSQA JSONL records, dropping rows without both a question and answer.

    Blank lines are skipped. A question/answer counts as missing when it is
    absent, ``None``, or whitespace-only once stringified. Falsy-but-present
    values such as an answer of ``0`` are kept.

    Args:
        path: Path to a JSONL file with one JSON object per line.

    Returns:
        The list of parsed records, in file order.

    Raises:
        json.JSONDecodeError: On a malformed line; the message is prefixed with
            the file path and the 1-based line number.
    """
    records = []
    with open(path, "r", encoding="utf-8") as handle:
        for lineno, line in enumerate(handle, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as err:
                # Same exception type callers may catch, but name the bad line.
                raise json.JSONDecodeError(
                    f"{path}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
            if not _as_text(record.get("question")).strip():
                continue
            if not _as_text(record.get("answer")).strip():
                continue
            records.append(record)
    return records


def balanced_limit(records, max_samples):
    """Round-robin a balanced subset across source groups (for smoke tests).

    Records are bucketed by ``source_type`` (missing/empty -> ``"unknown"``).
    Groups are visited in ``SOURCE_ORDER`` first, then any remaining groups
    alphabetically. Each round takes the next record from each group, keeping
    each group's original order. Selection stops once ``max_samples`` records
    are chosen or every group is exhausted.

    Returns ``records`` unchanged when ``max_samples`` is falsy or
    non-positive, or when it already covers every record.
    """
    if not max_samples or max_samples <= 0 or len(records) <= max_samples:
        return records

    grouped = defaultdict(list)
    for record in records:
        grouped[str(record.get("source_type") or "unknown")].append(record)

    names = [name for name in SOURCE_ORDER if grouped.get(name)]
    names.extend(sorted(name for name in grouped if name not in names))

    selected = []
    offsets = defaultdict(int)
    while len(selected) < max_samples:
        added = False
        for name in names:
            offset = offsets[name]
            if offset < len(grouped[name]):
                selected.append(grouped[name][offset])
                offsets[name] += 1
                added = True
                if len(selected) == max_samples:
                    break
        if not added:
            break
    return selected


def _to_float_array(series):
    """Coerce ``series`` to a flat float array with non-finite values imputed.

    ``None`` maps to an empty array. Without this check ``np.asarray(None,
    dtype=float)`` yields ``[nan]``, which would be imputed into a fake
    one-point series. Any multi-dimensional input is flattened with
    ``reshape(-1)``. NaN/inf entries are replaced by the median of the finite
    entries, or by ``0.0`` when none are finite.
    """
    if series is None:
        return np.empty(0, dtype=float)
    arr = np.asarray(series, dtype=float).reshape(-1)
    if arr.size == 0:
        return arr
    finite = np.isfinite(arr)
    if not finite.all():
        fill = float(np.median(arr[finite])) if finite.any() else 0.0
        arr = np.where(finite, arr, fill)
    return arr


def serialize_series(series):
    """Discretise + serialise a univariate series into ChatTime time tokens.

    Returns ``""`` when there is no usable series (``None`` or empty).
    NOTE(review): multi-dimensional input is flattened, not treated per-channel.
    """
    arr = _to_float_array(series)
    if arr.size == 0:
        return ""
    discretizer = Discretizer()
    serializer = Serializer()
    return serializer.serialize(discretizer.discretize(arr))


def build_prompt(question, series, response=None):
    """Build a ChatTime prompt: question -> Instruction, series -> Input.

    Uses ``getPrompt(flag="analysis")`` when the series serialises to a
    non-empty string; otherwise uses ``getPrompt(flag="general")`` with an
    empty input. ``response`` defaults to ``""`` (an open prompt for inference).
    """
    question = str(question or "").strip()
    response = "" if response is None else str(response)
    serialized = serialize_series(series)
    if serialized:
        return getPrompt(
            flag="analysis",
            instruction=question,
            input=serialized,
            response=response,
        )
    # No usable series: fall back to a plain instruction prompt.
    return getPrompt(
        flag="general",
        instruction=question,
        input="",
        response=response,
    )


def build_result_row(record, prediction):
    """Assemble a prediction row in the Time-MQA evaluator schema.

    Every field in ``RESULT_FIELDS`` plus ``question``/``answer`` is copied as
    a string (``None``/missing -> ``""``; values like ``0`` are preserved).
    ``prediction`` is stringified and stripped (``None`` -> ``""``).
    """
    row = {field: _as_text(record.get(field)) for field in RESULT_FIELDS}
    row["question"] = _as_text(record.get("question"))
    row["answer"] = _as_text(record.get("answer"))
    row["prediction"] = "" if prediction is None else str(prediction).strip()
    return row


def compute_tsqa_metrics(rows):
    """Per-group TSQA metrics, reusing the canonical Time-MQA evaluator.

    Returns a dict with ``by_group`` (``compute_group_metrics``),
    ``text_overall`` (``compute_text_metrics``), ``num_samples`` and
    ``counts_by_group`` (rows per ``source_type``; empty -> ``"unknown"``).
    """
    if str(MQA_DIR) not in sys.path:
        sys.path.insert(0, str(MQA_DIR))
    # NOTE(review): ``data_utils`` is a generic module name; if a different
    # ``data_utils`` is already in sys.modules, that module is returned here.
    from data_utils import compute_group_metrics, compute_text_metrics

    rows = list(rows)
    metrics = {"by_group": compute_group_metrics(rows)}
    metrics["text_overall"] = compute_text_metrics(rows)
    counts = defaultdict(int)
    for row in rows:
        counts[str(row.get("source_type") or "unknown")] += 1
    metrics["num_samples"] = len(rows)
    metrics["counts_by_group"] = dict(counts)
    return metrics


def _atomic_write_text(path, write):
    """Fill a sibling temp file via ``write(handle)``, then ``os.replace`` it onto ``path``.

    Parent directories are created as needed. If writing or renaming fails,
    the temp file is removed (best effort) and the error re-raised. The
    existing ``path``, if any, is then left untouched.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    # Same directory as the target so os.replace is a same-filesystem rename.
    tmp = path.with_name(f"{path.name}.tmp.{os.getpid()}")
    try:
        with open(tmp, "w", encoding="utf-8") as handle:
            write(handle)
        os.replace(tmp, path)
    except BaseException:
        try:
            os.remove(tmp)
        except OSError:
            # Cleanup is best effort; surface the original failure instead.
            pass
        raise


def atomic_write_json(obj, path):
    """Write ``obj`` to ``path`` as indented UTF-8 JSON via a temp-file rename."""
    _atomic_write_text(
        path,
        lambda handle: json.dump(obj, handle, indent=2, ensure_ascii=False),
    )


def write_jsonl(rows, path):
    """Write ``rows`` to ``path`` as JSON Lines (one object per line) via a temp-file rename."""

    def _dump(handle):
        for row in rows:
            handle.write(json.dumps(row, ensure_ascii=False) + "\n")

    _atomic_write_text(path, _dump)