#!/usr/bin/env python3
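"""Evaluate a causal LM (optionally with a LoRA adapter) on the CabinLavatoryPrediction tasks.

Two task types are supported:
  * struct -- structured behavior-state prediction (see STRUCT_FIELDS)
  * qa     -- lavatory occupancy question answering (see QA_FIELDS)

The base model is loaded with 4-bit NF4 quantization; if --adapter-dir is
given, the PEFT adapter is applied on top. Predictions are written as JSONL
and metrics as JSON under --output-dir.

Example invocation (file names are illustrative):
    python evaluate_model.py --task-type struct --run-name qwen_lora \
        --input-file data/val_struct.jsonl --adapter-dir outputs/adapter
"""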
import argparse
import json
import math
from pathlib import Path
import numpy as np
import torch
from peft import PeftModel
from sklearn.metrics import accuracy_score, f1_score, mean_absolute_error, precision_recall_fscore_support
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
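# Fields required in the structured behavior-state prediction output.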
STRUCT_FIELDS = [
"current_behavior",
"is_transition",
"elapsed_seconds_in_current_behavior",
"estimated_remaining_seconds",
"full_remaining_seconds",
"expected_end_time",
"next_possible_behavior",
"stage_index",
"total_stages",
"sequence_so_far",
]
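# Numeric time fields scored with mean absolute error.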
TIME_FIELDS = [
"elapsed_seconds_in_current_behavior",
"estimated_remaining_seconds",
"full_remaining_seconds",
"expected_end_time",
]
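# Fields required in the QA response.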
QA_FIELDS = ["occupied", "time_to_free_minutes", "used_areas", "is_abnormal"]
def read_jsonl(path, limit=None):
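    """Read up to `limit` JSON objects from a JSONL file, skipping blank lines."""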
rows = []
with open(path, encoding="utf-8") as f:
for line in f:
if not line.strip():
continue
rows.append(json.loads(line))
if limit and len(rows) >= limit:
break
return rows
def load_model(model_name, adapter_dir=None):
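    """Load the tokenizer and the base model in 4-bit NF4 quantization, then attach a PEFT adapter if given."""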
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
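    # Left padding keeps generated continuations contiguous at the end of each batched sequence.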
tokenizer.padding_side = "left"
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
model_name,
trust_remote_code=True,
quantization_config=bnb_config,
device_map="auto",
torch_dtype=torch.bfloat16,
)
if adapter_dir:
model = PeftModel.from_pretrained(model, adapter_dir)
model.eval()
return tokenizer, model
def render_prompt(tokenizer, messages):
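    """Render a prompt with the chat template; fall back for tokenizers whose template does not accept enable_thinking."""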
try:
return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
except TypeError:
return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
def json_candidates(text):
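    """Yield every dict that can be decoded starting from a '{' anywhere in the text."""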
decoder = json.JSONDecoder()
for idx, char in enumerate(text):
if char != "{":
continue
try:
obj, _ = decoder.raw_decode(text[idx:])
except Exception:
continue
if isinstance(obj, dict):
yield obj
def parse_json_text(text, preferred_fields=None):
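    """Parse model output into a dict, preferring the candidate that covers the most expected fields."""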
text = text.strip()
try:
return json.loads(text), None
except Exception:
pass
candidates = list(json_candidates(text))
if not candidates:
return None, "no_json_object"
if preferred_fields:
preferred = set(preferred_fields)
candidates.sort(key=lambda obj: len(preferred & set(obj.keys())), reverse=True)
return candidates[0], None
def generate_predictions(rows, tokenizer, model, max_new_tokens, batch_size, preferred_fields, max_input_tokens, pred_path=None):
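    """Greedily decode each batch, parse the JSON output, and pair it with its target; optionally stream records to pred_path."""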
records = []
pred_file = pred_path.open("w", encoding="utf-8") if pred_path else None
for start in range(0, len(rows), batch_size):
batch = rows[start : start + batch_size]
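        # Drop the final assistant message; the model has to regenerate it as its prediction.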
prompts = [render_prompt(tokenizer, row["messages"][:-1]) for row in batch]
inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=max_input_tokens).to(model.device)
with torch.no_grad():
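            # Greedy decoding; temperature/top_p are passed as None so transformers does not warn about unused sampling parameters.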
outputs = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=False,
temperature=None,
top_p=None,
pad_token_id=tokenizer.eos_token_id,
)
prompt_len = inputs["input_ids"].shape[1]
decoded = tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)
for row, pred_text in zip(batch, decoded):
target_content = row["messages"][-1]["content"]
target = json.loads(target_content) if isinstance(target_content, str) else target_content
pred, error = parse_json_text(pred_text, preferred_fields)
record = {"target": target, "prediction": pred, "raw_prediction": pred_text, "parse_error": error}
records.append(record)
if pred_file:
pred_file.write(json.dumps(record, ensure_ascii=False, separators=(",", ":")) + "\n")
pred_file.flush()
print(f"generated {min(start + batch_size, len(rows))}/{len(rows)}", flush=True)
if pred_file:
pred_file.close()
return records
def numeric_pairs(records, field):
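    """Collect numeric (target, prediction) value pairs for a field, keeping only finite predictions."""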
y_true, y_pred = [], []
for rec in records:
pred = rec["prediction"]
if not isinstance(pred, dict):
continue
t, p = rec["target"].get(field), pred.get(field)
if isinstance(t, (int, float)) and isinstance(p, (int, float)) and math.isfinite(float(p)):
y_true.append(float(t))
y_pred.append(float(p))
return y_true, y_pred
def classification_metrics(records, field):
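    """Accuracy, macro-F1, and coverage over records whose parsed prediction contains the field."""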
pairs = []
for rec in records:
pred = rec["prediction"]
if isinstance(pred, dict) and field in pred:
pairs.append((rec["target"].get(field), pred.get(field)))
if not pairs:
return {"accuracy": 0.0, "macro_f1": 0.0, "coverage": 0.0}
y_true, y_pred = zip(*pairs)
# sklearn cannot sort mixed labels such as None and str; normalize only for metric computation.
y_true = ["<NULL>" if value is None else str(value) for value in y_true]
y_pred = ["<NULL>" if value is None else str(value) for value in y_pred]
return {
"accuracy": float(accuracy_score(y_true, y_pred)),
"macro_f1": float(f1_score(y_true, y_pred, average="macro", zero_division=0)),
"coverage": len(pairs) / len(records),
}
def sequence_metrics(records):
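    """Exact-match, last-label, and prefix-match rates for the predicted sequence_so_far labels."""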
exact = []
last = []
prefix = []
for rec in records:
pred = rec["prediction"]
if not isinstance(pred, dict):
continue
true_seq = [x.get("label") for x in rec["target"].get("sequence_so_far") or []]
pred_seq = [x.get("label") for x in pred.get("sequence_so_far") or [] if isinstance(x, dict)]
exact.append(true_seq == pred_seq)
last.append(bool(true_seq and pred_seq and true_seq[-1] == pred_seq[-1]))
prefix_len = min(len(true_seq), len(pred_seq))
prefix.append(sum(1 for i in range(prefix_len) if true_seq[i] == pred_seq[i]) / max(1, len(true_seq)))
return {
"sequence_exact_match": float(np.mean(exact)) if exact else 0.0,
"sequence_last_label_accuracy": float(np.mean(last)) if last else 0.0,
"sequence_prefix_label_match": float(np.mean(prefix)) if prefix else 0.0,
}
def evaluate_struct(records):
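    """Aggregate metrics for the structured behavior-state prediction task."""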
parsed = [r for r in records if isinstance(r["prediction"], dict)]
metrics = {
"num_examples": len(records),
"json_parse_rate": len(parsed) / max(1, len(records)),
"required_field_complete_rate": sum(all(f in r["prediction"] for f in STRUCT_FIELDS) for r in parsed) / max(1, len(records)),
}
for field in ["current_behavior", "next_possible_behavior", "is_transition", "stage_index", "total_stages"]:
cm = classification_metrics(records, field)
metrics[f"{field}_accuracy"] = cm["accuracy"]
if "behavior" in field or field == "is_transition":
metrics[f"{field}_macro_f1"] = cm["macro_f1"]
for field in TIME_FIELDS:
y_true, y_pred = numeric_pairs(records, field)
metrics[f"{field}_mae"] = float(mean_absolute_error(y_true, y_pred)) if y_true else None
metrics[f"{field}_coverage"] = len(y_true) / max(1, len(records))
metrics.update(sequence_metrics(records))
return metrics
def normalize_areas(value):
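    """Coerce a used_areas value to a set of strings; anything that is not a list becomes the empty set."""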
if not isinstance(value, list):
return set()
return {str(x) for x in value}
def evaluate_qa(records):
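    """Aggregate metrics for the lavatory QA task (occupancy, time to free, used areas, abnormality)."""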
parsed = [r for r in records if isinstance(r["prediction"], dict)]
metrics = {
"num_examples": len(records),
"json_parse_rate": len(parsed) / max(1, len(records)),
"required_field_complete_rate": sum(all(f in r["prediction"] for f in QA_FIELDS) for r in parsed) / max(1, len(records)),
}
for field in ["occupied", "is_abnormal"]:
cm = classification_metrics(records, field)
metrics[f"{field}_accuracy"] = cm["accuracy"]
metrics[f"{field}_f1"] = cm["macro_f1"]
y_true, y_pred = numeric_pairs(records, "time_to_free_minutes")
metrics["time_to_free_minutes_mae"] = float(mean_absolute_error(y_true, y_pred)) if y_true else None
true_flat, pred_flat = [], []
labels = ["门", "马桶", "洗手池", "垃圾桶"]
for rec in records:
pred = rec["prediction"]
if not isinstance(pred, dict):
continue
t = normalize_areas(rec["target"].get("used_areas"))
p = normalize_areas(pred.get("used_areas"))
true_flat.extend([label in t for label in labels])
pred_flat.extend([label in p for label in labels])
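    # Flattening per-label booleans and using "binary" averaging is equivalent to micro-averaged multilabel precision/recall/F1.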
if true_flat:
pr, rc, f1, _ = precision_recall_fscore_support(true_flat, pred_flat, average="binary", zero_division=0)
metrics["used_areas_micro_precision"] = float(pr)
metrics["used_areas_micro_recall"] = float(rc)
metrics["used_areas_micro_f1"] = float(f1)
return metrics
def main():
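    """CLI entry point: generate predictions (or reuse an existing predictions file) and write metrics to --output-dir."""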
parser = argparse.ArgumentParser()
parser.add_argument("--model-name", default="Qwen/Qwen3.5-9B")
parser.add_argument("--adapter-dir", default=None)
parser.add_argument("--input-file", default=None)
parser.add_argument("--predictions-file", default=None)
parser.add_argument("--task-type", choices=["struct", "qa"], required=True)
parser.add_argument("--output-dir", default="outputs")
parser.add_argument("--run-name", required=True)
parser.add_argument("--max-samples", type=int, default=None)
parser.add_argument("--batch-size", type=int, default=1)
parser.add_argument("--max-new-tokens", type=int, default=1536)
parser.add_argument("--max-input-tokens", type=int, default=6144)
args = parser.parse_args()
out_root = Path(args.output_dir)
pred_dir = out_root / "predictions"
metric_dir = out_root / "metrics"
pred_dir.mkdir(parents=True, exist_ok=True)
metric_dir.mkdir(parents=True, exist_ok=True)
pred_path = pred_dir / f"{args.run_name}_{args.task_type}_predictions.jsonl"
if args.predictions_file:
records = read_jsonl(args.predictions_file, args.max_samples)
else:
if not args.input_file:
raise ValueError("--input-file is required unless --predictions-file is provided")
rows = read_jsonl(args.input_file, args.max_samples)
tokenizer, model = load_model(args.model_name, args.adapter_dir)
preferred_fields = STRUCT_FIELDS if args.task_type == "struct" else QA_FIELDS
records = generate_predictions(
rows, tokenizer, model, args.max_new_tokens, args.batch_size, preferred_fields, args.max_input_tokens, pred_path
)
metrics = evaluate_struct(records) if args.task_type == "struct" else evaluate_qa(records)
metric_payload = {
"run_name": args.run_name,
"task_type": args.task_type,
"input_file": args.input_file,
"predictions_file": args.predictions_file,
"metrics": metrics,
}
metric_path = metric_dir / f"{args.run_name}_{args.task_type}_metrics.json"
metric_path.write_text(json.dumps(metric_payload, ensure_ascii=False, indent=2), encoding="utf-8")
print(json.dumps(metric_payload, ensure_ascii=False, indent=2))
if __name__ == "__main__":
main()