ChatTime / tsqa_adapter /tsqa_common.py
a12354's picture
Add files using upload-large-folder tool
8d2b389 verified
Raw
History Blame Contribute Delete
6.82 kB
"""Time-MQA TSQA dataset adapter for ChatTime.

This mirrors ``rats40k_adapter`` but targets the multi-task Time-MQA TSQA
benchmark (``/mnt/share01/sqk/datasets/Time-MQA_TSQA/tmp``), which contains four
source groups -- anomaly_detection, classification, forecasting and open_ended
-- with free-text answers.

Design choices:

* ChatTime has no chat template, so prompts use ChatTime's native analysis
  template (``utils.prompt.getPrompt``): the natural-language ``question`` goes
  into ``### Instruction`` (kept verbatim, including any embedded numbers so
  forecasting keeps its absolute scale) and the ChatTime-serialised series goes
  into ``### Input`` so the model still receives its native time-series tokens.
* The free-text ``answer`` is the ``### Response`` target.
* Some context-enhanced forecasting questions are far longer than ChatTime's
  4096-token context. Callers left-truncate the prompt (keeping the series, the
  trailing ask and the full response) rather than failing.
* Evaluation reuses the canonical Time-MQA evaluator
  (``MQA/data_utils.compute_group_metrics``) so ChatTime numbers are directly
  comparable to the other TSQA baselines.
"""
import json
import os
import sys
from collections import defaultdict
from pathlib import Path
import numpy as np
# Directory one level above the folder that contains this file (presumably the
# ChatTime checkout root -- it is the directory the ``utils`` package imports
# below are expected to resolve from).
CHAT_TIME_DIR = Path(__file__).resolve().parents[1]
# Directory that ``compute_tsqa_metrics`` puts on ``sys.path`` before importing
# ``data_utils``; override the default with the ``MQA_DIR`` environment variable.
MQA_DIR = Path(os.environ.get("MQA_DIR", "/mnt/share01/sqk/MQA"))
# Must run before the ``utils.*`` imports that follow; inserted at the front of
# ``sys.path`` and only once.
if str(CHAT_TIME_DIR) not in sys.path:
    sys.path.insert(0, str(CHAT_TIME_DIR))
from utils.prompt import getPrompt # noqa: E402
from utils.tools import Discretizer, Serializer # noqa: E402
# Default benchmark directory (the same path the module docstring names);
# presumably what callers pass as ``data_root`` to ``resolve_data_file``.
DEFAULT_DATA_ROOT = "/mnt/share01/sqk/datasets/Time-MQA_TSQA/tmp"
# Source groups in the order ``balanced_limit`` visits them; any group not
# listed here is visited afterwards, in sorted name order.
SOURCE_ORDER = ("anomaly_detection", "classification", "forecasting", "open_ended")
# Fields carried into the saved prediction rows so the Time-MQA evaluator and
# manual inspection have everything they need.
# ``build_result_row`` copies each of these keys from the input record.
RESULT_FIELDS = (
    "id",
    "figure_path",
    "application_domain",
    "task_type",
    "source_type",
    "question_format",
)
def resolve_data_file(data_root, data_file):
    """Return the location of a TSQA split file as a string.

    An absolute ``data_file`` is used as given; a relative one is joined onto
    ``data_root``.

    Raises:
        FileNotFoundError: if the resulting path is not an existing regular file.
    """
    candidate = Path(data_file)
    resolved = candidate if candidate.is_absolute() else Path(data_root) / candidate
    if resolved.is_file():
        return str(resolved)
    raise FileNotFoundError(f"TSQA JSONL not found: {resolved}")
def _has_text(value):
    """Return True when ``value`` is present and non-blank once stringified."""
    # Only ``None`` means "missing": a bare ``value or ""`` would also discard
    # legitimate falsy payloads such as the number 0 or the boolean False.
    return value is not None and bool(str(value).strip())


def load_tsqa_records(path):
    """Load TSQA JSONL records, dropping rows without both a question and answer.

    Args:
        path: Location of a UTF-8 JSONL file with one JSON object per line.

    Returns:
        The parsed records, in file order, whose ``question`` and ``answer``
        are both present and non-blank. Blank lines are skipped.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON; the message
            is prefixed with the file path and the 1-based line number.
    """
    records = []
    with open(path, "r", encoding="utf-8") as handle:
        for line_number, raw_line in enumerate(handle, start=1):
            line = raw_line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as exc:
                # Same exception type as before, but say where the bad row is.
                raise json.JSONDecodeError(
                    f"{path}:{line_number}: {exc.msg}", exc.doc, exc.pos
                ) from exc
            if _has_text(record.get("question")) and _has_text(record.get("answer")):
                records.append(record)
    return records
def balanced_limit(records, max_samples):
    """Pick up to ``max_samples`` records, interleaving the source groups.

    Intended for smoke tests. Records are bucketed by ``source_type``
    (``"unknown"`` when absent or empty). Buckets are visited in
    ``SOURCE_ORDER`` first, then any remaining names in sorted order, taking
    one record per bucket per pass until the cap is reached.

    ``records`` itself is returned when ``max_samples`` is falsy or
    non-positive, or when there are no more than ``max_samples`` records.
    """
    uncapped = not max_samples or max_samples <= 0
    if uncapped or len(records) <= max_samples:
        return records

    buckets = defaultdict(list)
    for item in records:
        buckets[str(item.get("source_type") or "unknown")].append(item)

    visit_order = [group for group in SOURCE_ORDER if buckets.get(group)]
    visit_order += sorted(group for group in buckets if group not in visit_order)

    picked = []
    depth = 0  # index taken from every bucket during the current pass
    while len(picked) < max_samples:
        progressed = False
        for group in visit_order:
            members = buckets[group]
            if depth >= len(members):
                continue
            picked.append(members[depth])
            progressed = True
            if len(picked) == max_samples:
                break
        if not progressed:
            break
        depth += 1
    return picked
def _to_float_array(series):
arr = np.asarray(series, dtype=float).reshape(-1)
if arr.size == 0:
return arr
if not np.isfinite(arr).all():
finite = np.isfinite(arr)
fill = float(np.median(arr[finite])) if finite.any() else 0.0
arr = np.where(np.isfinite(arr), arr, fill)
return arr
def serialize_series(series):
    """Turn a univariate series into ChatTime's serialised time-token string.

    The values are cleaned with ``_to_float_array``, run through
    ``Discretizer.discretize`` and the result is handed to
    ``Serializer.serialize``. A series with no values yields ``""``.
    """
    values = _to_float_array(series)
    if values.size > 0:
        binner = Discretizer()
        writer = Serializer()
        discretized = binner.discretize(values)
        return writer.serialize(discretized)
    return ""
def build_prompt(question, series, response=None):
    """Render a ChatTime prompt for one TSQA example via ``getPrompt``.

    The stripped ``question`` is the instruction. When ``series`` serialises
    to a non-empty string, the ``"analysis"`` template is used with that string
    as the input. Otherwise the ``"general"`` template is used with an empty
    input. ``response`` is stringified, with ``None`` becoming ``""``.
    """
    instruction_text = str(question or "").strip()
    target_text = str(response) if response is not None else ""
    series_text = serialize_series(series)

    if not series_text:
        # No usable series: fall back to a plain instruction prompt.
        return getPrompt(
            flag="general",
            instruction=instruction_text,
            input="",
            response=target_text,
        )

    return getPrompt(
        flag="analysis",
        instruction=instruction_text,
        input=series_text,
        response=target_text,
    )
def build_result_row(record, prediction):
    """Assemble a prediction row in the Time-MQA evaluator schema.

    Args:
        record: Source TSQA record (a mapping supporting ``.get``).
        prediction: Model output, or ``None`` when nothing was generated.

    Returns:
        A dict of strings holding every key in ``RESULT_FIELDS`` plus
        ``question``, ``answer`` and the stripped ``prediction``. Missing
        (``None``) values become ``""``.
    """

    def as_text(value):
        # Only ``None`` counts as missing. ``value or ""`` would also blank out
        # legitimate falsy values such as an integer id of 0.
        return "" if value is None else str(value)

    row = {field: as_text(record.get(field)) for field in RESULT_FIELDS}
    row["question"] = as_text(record.get("question"))
    row["answer"] = as_text(record.get("answer"))
    row["prediction"] = as_text(prediction).strip()
    return row
def compute_tsqa_metrics(rows):
    """Score prediction rows with the Time-MQA evaluator functions.

    ``MQA_DIR`` is put at the front of ``sys.path`` (if not already present)
    so that ``data_utils`` can be imported lazily, on first use.

    Returns:
        A dict with ``by_group`` (``compute_group_metrics`` output),
        ``text_overall`` (``compute_text_metrics`` output), ``num_samples``
        and ``counts_by_group`` (rows per ``source_type``, ``"unknown"`` when
        the field is absent or empty).
    """
    evaluator_dir = str(MQA_DIR)
    if evaluator_dir not in sys.path:
        sys.path.insert(0, evaluator_dir)
    from data_utils import compute_group_metrics, compute_text_metrics

    # Materialise once: the rows are traversed several times below.
    materialised = list(rows)
    per_group = compute_group_metrics(materialised)
    overall = compute_text_metrics(materialised)

    tally = {}
    for entry in materialised:
        group = str(entry.get("source_type") or "unknown")
        tally[group] = tally.get(group, 0) + 1

    return {
        "by_group": per_group,
        "text_overall": overall,
        "num_samples": len(materialised),
        "counts_by_group": tally,
    }
def atomic_write_json(obj, path):
    """Write ``obj`` to ``path`` as indented UTF-8 JSON via a temp file + rename.

    The data is first written to a sibling ``<name>.tmp.<pid>`` file and then
    moved over ``path`` with ``os.replace``, so an existing ``path`` is only
    replaced once the new content has been written in full. Parent
    directories are created as needed.

    Args:
        obj: JSON-serialisable object.
        path: Destination file (``str`` or ``Path``).

    Raises:
        TypeError, ValueError, OSError: propagated from ``json.dump`` or the
            file system; the temp file is removed before re-raising.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(path.suffix + f".tmp.{os.getpid()}")
    try:
        with open(tmp, "w", encoding="utf-8") as handle:
            json.dump(obj, handle, indent=2, ensure_ascii=False)
        os.replace(tmp, path)
    except BaseException:
        # Do not leave a half-written temp file next to the target.
        try:
            os.unlink(tmp)
        except OSError:
            pass
        raise
def write_jsonl(rows, path):
    """Write ``rows`` to ``path`` as UTF-8 JSONL via a temp file + rename.

    Each row is dumped on its own line to a sibling ``<name>.tmp.<pid>`` file,
    which is then moved over ``path`` with ``os.replace``. Parent directories
    are created as needed.

    Args:
        rows: Iterable of JSON-serialisable objects.
        path: Destination file (``str`` or ``Path``).

    Raises:
        TypeError, ValueError, OSError: propagated from ``json.dumps``, the
            ``rows`` iterable or the file system; the temp file is removed
            before re-raising.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(path.suffix + f".tmp.{os.getpid()}")
    try:
        with open(tmp, "w", encoding="utf-8") as handle:
            handle.writelines(
                json.dumps(row, ensure_ascii=False) + "\n" for row in rows
            )
        os.replace(tmp, path)
    except BaseException:
        # Do not leave a half-written temp file next to the target.
        try:
            os.unlink(tmp)
        except OSError:
            pass
        raise