CabinLavatoryPrediction / code /prepare_data.py
sutama's picture
Upload CabinLavatoryPrediction LoRA adapter, checkpoint, code, and evaluation artifacts
e74a796 verified
#!/usr/bin/env python3
import argparse
import json
import math
from collections import defaultdict
from pathlib import Path
# Raw label -> canonical label; consulted by normalize_label().
LABEL_MAP = {"反复折返": "折返"}

# Names of the time-related fields of the assistant JSON.
TIME_FIELDS = [
    "elapsed_seconds_in_current_behavior",
    "estimated_remaining_seconds",
    "full_remaining_seconds",
    "expected_end_time",
]

# Assistant JSON fields copied (in this order) into the QA user prompt by
# make_qa_example(). The time fields sit between the behavior flags and the
# sequence/stage information.
OUTPUT_FIELDS = [
    "current_behavior",
    "is_transition",
    *TIME_FIELDS,
    "next_possible_behavior",
    "stage_index",
    "total_stages",
    "sequence_so_far",
]
def normalize_label(value):
    """Return the canonical form of a behavior label.

    A value that is a key of ``LABEL_MAP`` is replaced by its mapped value;
    every other value is returned unchanged.
    """
    try:
        return LABEL_MAP[value]
    except KeyError:
        return value
def normalize_tree(value):
    """Return a copy of a JSON-like structure with every label normalized.

    * Strings are passed through ``normalize_label``.
    * Lists are rebuilt with each element normalized recursively.
    * Dicts are rebuilt with each value normalized recursively; keys are
      never rewritten.
    * Anything else (numbers, booleans, ``None``, ...) is returned as is.

    Lists stored under the keys ``full_sequence_order`` or ``label_space``
    are handled shallowly: their string items are normalized, while
    non-string items are kept untouched and are not descended into.
    """
    if isinstance(value, str):
        return normalize_label(value)
    if isinstance(value, list):
        return [normalize_tree(child) for child in value]
    if not isinstance(value, dict):
        return value

    shallow_list_keys = {"full_sequence_order", "label_space"}
    normalized = {}
    for key, child in value.items():
        if key in shallow_list_keys and isinstance(child, list):
            normalized[key] = [
                normalize_label(entry) if isinstance(entry, str) else entry
                for entry in child
            ]
        else:
            # String values (including those under "current_behavior",
            # "next_possible_behavior" and "label") are normalized by the
            # string case of the recursive call.
            normalized[key] = normalize_tree(child)
    return normalized
def read_chat_jsonl(path):
    """Lazily yield the parsed JSON value of each non-blank line of ``path``.

    ``path`` must provide ``open(encoding=...)`` (e.g. a ``pathlib.Path``).
    Lines consisting only of whitespace are skipped, but still count towards
    the 1-based line number used in error messages.

    Raises:
        ValueError: if a non-blank line is not valid JSON; the message names
            the path and the line number, and the original
            ``json.JSONDecodeError`` is chained as the cause.
    """
    with path.open(encoding="utf-8") as handle:
        line_no = 0
        for line in handle:
            line_no += 1
            if line.strip():
                try:
                    record = json.loads(line)
                except json.JSONDecodeError as exc:
                    raise ValueError(f"{path}:{line_no} is not valid JSONL: {exc}") from exc
                yield record
def get_assistant_json(example):
    """Return the content of the last message of ``example``.

    String content is decoded with ``json.loads`` (which raises
    ``json.JSONDecodeError`` on invalid JSON); any other content is returned
    unchanged.
    """
    payload = example["messages"][-1]["content"]
    if isinstance(payload, str):
        return json.loads(payload)
    return payload
def set_assistant_json(example, data):
    """Store ``data`` as compact JSON text in the last message of ``example``.

    The example is modified in place. The JSON uses ``,`` and ``:`` as
    separators without spaces and keeps non-ASCII characters unescaped.
    """
    last_message = example["messages"][-1]
    last_message["content"] = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
def clean_example(example):
    """Return a label-normalized copy of a chat example.

    Steps:
      1. The whole example is passed through ``normalize_tree``; the result
         of that call is what is modified below and returned.
      2. Every message whose content is a string holding valid JSON gets that
         JSON normalized and re-serialized compactly. Content that is not
         valid JSON (for instance a prose prompt) is left as it is.
      3. The last message's content is decoded via ``get_assistant_json``,
         normalized, and written back as compact JSON text, so it ends up as
         a string even if it was stored as a dict or list.

    Raises:
        KeyError / IndexError: if the example has no ``messages`` entry or
            the list is empty (raised by ``get_assistant_json``).
        json.JSONDecodeError: if the last message's string content is not
            valid JSON.
    """
    example = normalize_tree(example)
    for message in example.get("messages", []):
        content = message.get("content")
        if not isinstance(content, str):
            continue
        try:
            parsed = json.loads(content)
        except json.JSONDecodeError:
            # Non-JSON text is an expected case, not an error; only this
            # specific failure is skipped so unrelated problems still surface.
            continue
        message["content"] = json.dumps(normalize_tree(parsed), ensure_ascii=False, separators=(",", ":"))
    assistant = normalize_tree(get_assistant_json(example))
    set_assistant_json(example, assistant)
    return example
def write_jsonl(path, rows):
    """Write ``rows`` to ``path`` as JSON Lines, one compact JSON value per line.

    Missing parent directories are created first. The file is opened in
    write mode as UTF-8, so existing content is replaced; non-ASCII
    characters are written unescaped.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as sink:
        sink.writelines(
            json.dumps(record, separators=(",", ":"), ensure_ascii=False) + "\n"
            for record in rows
        )
def percentile(values, q):
    """Return the ``q`` quantile of ``values`` using linear interpolation.

    Args:
        values: any iterable of numbers; it is sorted internally, so the
            caller's data is not modified and one-shot iterators work.
        q: quantile as a fraction in ``[0, 1]`` (``0.95`` is the 95th
            percentile).

    Returns:
        The interpolated quantile as a ``float``, or ``0.0`` when ``values``
        is empty.

    Raises:
        ValueError: if ``values`` is non-empty and ``q`` lies outside
            ``[0, 1]``. Without this check a negative ``q`` would index from
            the end of the sorted list and quietly return a wrong number,
            and ``q > 1`` would fail with an unhelpful ``IndexError``.
    """
    ordered = sorted(values)
    if not ordered:
        return 0.0
    if not 0.0 <= q <= 1.0:
        raise ValueError(f"q must be within [0, 1], got {q!r}")
    # Fractional position of the quantile within the sorted data.
    position = (len(ordered) - 1) * q
    lower = math.floor(position)
    upper = math.ceil(position)
    if lower == upper:
        return float(ordered[lower])
    # Weight the two neighbours by their distance to the fractional position.
    return float(ordered[lower] * (upper - position) + ordered[upper] * (position - lower))
def build_thresholds(clean_train):
    """Derive per-label "elapsed time" thresholds from training examples.

    For each example the assistant JSON is read with ``get_assistant_json``;
    numeric ``elapsed_seconds_in_current_behavior`` values are grouped by
    ``current_behavior``.

    Returns:
        A dict mapping each label to ``max(5.0, p95 of its elapsed values)``,
        plus the key ``"__default__"`` holding ``max(30.0, p95 over all
        collected values)``.
    """
    grouped = defaultdict(list)
    for example in clean_train:
        target = get_assistant_json(example)
        seconds = target.get("elapsed_seconds_in_current_behavior")
        if not isinstance(seconds, (int, float)):
            continue
        grouped[target.get("current_behavior")].append(float(seconds))

    thresholds = {}
    for label, samples in grouped.items():
        thresholds[label] = max(5.0, percentile(samples, 0.95))

    # Fallback threshold computed over every label's samples pooled together.
    pooled = [seconds for samples in grouped.values() for seconds in samples]
    thresholds["__default__"] = max(30.0, percentile(pooled, 0.95))
    return thresholds
def used_areas(sequence):
    """Return the areas touched by the behaviors in ``sequence``.

    ``sequence`` is an iterable of dicts (or ``None``/empty); each item's
    ``"label"`` value is collected. An area is reported when at least one of
    its labels was seen. The result keeps the fixed order
    门, 马桶, 洗手池, 垃圾桶.
    """
    area_labels = (
        ("门", {"进入", "门锁", "靠近门", "离开"}),
        ("马桶", {"靠近马桶", "马桶盖", "马桶垫纸", "坐下", "坐用马桶", "站用马桶", "卷筒厕纸", "起身", "冲水"}),
        ("洗手池", {"靠近洗手池", "洗手", "刷牙"}),
        ("垃圾桶", {"垃圾桶"}),
    )
    seen = set()
    for step in sequence or []:
        seen.add(step.get("label"))

    touched = []
    for area, labels in area_labels:
        if not seen.isdisjoint(labels):
            touched.append(area)
    return touched
def qa_target(assistant, thresholds):
    """Build the QA answer dict for one behavior-analysis JSON.

    Args:
        assistant: dict holding the behavior analysis; the keys read are
            ``sequence_so_far``, ``current_behavior``,
            ``full_remaining_seconds`` and
            ``elapsed_seconds_in_current_behavior``.
        thresholds: mapping of label -> elapsed-seconds threshold, with an
            optional ``"__default__"`` entry; ``120.0`` is used when neither
            the label nor the default is present.

    Returns:
        A dict with:
          * ``occupied``: the sequence is non-empty, the current behavior is
            not ``"离开"``, and the remaining time is either non-numeric or
            greater than zero.
          * ``time_to_free_minutes``: remaining seconds clamped at zero,
            converted to minutes and rounded to two decimals.
          * ``used_areas``: result of ``used_areas(sequence)``.
          * ``is_abnormal``: the elapsed time exceeds the threshold for the
            current behavior, or the sequence has at least 12 entries.
    """
    sequence = assistant.get("sequence_so_far") or []
    current = assistant.get("current_behavior")
    full_remaining = assistant.get("full_remaining_seconds")
    elapsed = assistant.get("elapsed_seconds_in_current_behavior")

    remaining_is_numeric = isinstance(full_remaining, (int, float))
    occupied = bool(sequence) and current != "离开" and (not remaining_is_numeric or full_remaining > 0)

    # ``occupied`` already tolerates a non-numeric remaining time, so the
    # conversion must not crash on it either: values float() cannot convert
    # (e.g. free text, a list) count as zero seconds, while numeric strings
    # keep converting as before.
    try:
        remaining_seconds = float(full_remaining or 0.0)
    except (TypeError, ValueError):
        remaining_seconds = 0.0
    time_to_free_minutes = round(max(0.0, remaining_seconds) / 60.0, 2)

    threshold = thresholds.get(current, thresholds.get("__default__", 120.0))
    frequent_switching = len(sequence) >= 12
    long_current = isinstance(elapsed, (int, float)) and elapsed > threshold
    return {
        "occupied": occupied,
        "time_to_free_minutes": time_to_free_minutes,
        "used_areas": used_areas(sequence),
        "is_abnormal": bool(long_current or frequent_switching),
    }
def make_qa_example(source_example, thresholds):
    """Turn a behavior-analysis example into a QA chat example.

    The assistant JSON of ``source_example`` is reduced to the fields listed
    in ``OUTPUT_FIELDS`` (missing fields become ``None``) and embedded,
    together with four fixed questions, in the user message. The assistant
    message is the compact JSON of ``qa_target(assistant, thresholds)``.

    Returns:
        ``{"messages": [system, user, assistant]}`` where every content is a
        string.
    """

    def compact_json(payload):
        # Same serialization settings for the user and assistant messages.
        return json.dumps(payload, ensure_ascii=False, separators=(",", ":"))

    assistant = get_assistant_json(source_example)

    behavior_json = {}
    for field in OUTPUT_FIELDS:
        behavior_json[field] = assistant.get(field)

    questions = [
        "卫生间是否被占用?",
        "预计多长时间之后卫生间会空出?",
        "卫生间内哪些区域被使用过?",
        "卫生间是否存在异常?",
    ]
    system_message = {
        "role": "system",
        "content": "你是卫生间状态问答模型。请只根据输入的行为分析 JSON 回答,并只输出 JSON:occupied(bool), time_to_free_minutes(number), used_areas(array), is_abnormal(bool)。",
    }
    user_message = {
        "role": "user",
        "content": compact_json(
            {
                "task": "qa_from_behavior_json",
                "model_output_json": behavior_json,
                "questions": questions,
            }
        ),
    }
    assistant_message = {
        "role": "assistant",
        "content": compact_json(qa_target(assistant, thresholds)),
    }
    return {"messages": [system_message, user_message, assistant_message]}
def dataset_stats(rows):
    """Summarize a list of chat examples.

    Args:
        rows: a sized collection of examples (``len(rows)`` is used).

    Returns:
        A dict with:
          * ``num_examples``: number of rows.
          * ``num_sample_ids``: number of distinct ``sample_id`` values found
            in the JSON content of the second message (index 1) of each row.
            Rows whose second message is missing, is not JSON, or carries no
            ``sample_id`` do not contribute.
          * ``label_counts``: occurrences of each ``current_behavior`` in the
            assistant JSON, ordered by label.
    """
    label_counts = defaultdict(int)
    sample_ids = set()
    for ex in rows:
        assistant = get_assistant_json(ex)
        label_counts[assistant.get("current_behavior")] += 1
        try:
            user_payload = json.loads(ex["messages"][1]["content"])
            sample_id = user_payload.get("sample_id")
            # A missing id must not be counted as one more distinct id.
            if sample_id is not None:
                sample_ids.add(sample_id)
        except (AttributeError, IndexError, KeyError, TypeError, ValueError):
            # Best effort: rows without a usable JSON user message are
            # simply left out of the sample-id count.
            continue

    try:
        ordered = sorted(label_counts.items())
    except TypeError:
        # Mixed key types (e.g. a missing label stored as None next to
        # string labels) cannot be compared directly; order by text instead.
        ordered = sorted(label_counts.items(), key=lambda item: str(item[0]))
    return {"num_examples": len(rows), "num_sample_ids": len(sample_ids), "label_counts": dict(ordered)}
def main():
    """Command-line entry point: clean the inputs and write all datasets.

    Reads the ``--train`` and ``--val`` JSONL files, cleans every example,
    derives thresholds from the cleaned training set, builds the QA variants,
    and writes six JSONL files plus ``summary.json`` into ``--out-dir``. The
    summary is also printed to standard output.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--train", default="train.jsonl")
    parser.add_argument("--val", default="val.jsonl")
    parser.add_argument("--out-dir", default="data/processed")
    args = parser.parse_args()
    out_dir = Path(args.out_dir)

    train = [clean_example(ex) for ex in read_chat_jsonl(Path(args.train))]
    val = [clean_example(ex) for ex in read_chat_jsonl(Path(args.val))]

    # Thresholds come from the training split only and are reused for val.
    thresholds = build_thresholds(train)
    train_qa = [make_qa_example(ex, thresholds) for ex in train]
    val_qa = [make_qa_example(ex, thresholds) for ex in val]

    outputs = (
        ("train_struct.jsonl", train),
        ("val_struct.jsonl", val),
        ("train_qa.jsonl", train_qa),
        ("val_qa.jsonl", val_qa),
        ("train_mixed.jsonl", train + train_qa),
        ("val_mixed.jsonl", val + val_qa),
    )
    for filename, rows in outputs:
        write_jsonl(out_dir / filename, rows)

    summary = {
        "normalization": LABEL_MAP,
        "train_struct": dataset_stats(train),
        "val_struct": dataset_stats(val),
        "train_qa": {"num_examples": len(train_qa)},
        "val_qa": {"num_examples": len(val_qa)},
        "abnormal_elapsed_thresholds_p95": thresholds,
        "qa_schema": ["occupied", "time_to_free_minutes", "used_areas", "is_abnormal"],
    }
    rendered = json.dumps(summary, ensure_ascii=False, indent=2)
    (out_dir / "summary.json").write_text(rendered, encoding="utf-8")
    print(rendered)


if __name__ == "__main__":
    main()