"""Evaluate a causal LM (optionally with a PEFT/LoRA adapter) on JSON-output
tasks: batched generation, tolerant JSON parsing, and task-specific metrics."""

import argparse
import json
import math
from pathlib import Path

import numpy as np
import torch
from peft import PeftModel
from sklearn.metrics import accuracy_score, f1_score, mean_absolute_error, precision_recall_fscore_support
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


STRUCT_FIELDS = [
    "current_behavior",
    "is_transition",
    "elapsed_seconds_in_current_behavior",
    "estimated_remaining_seconds",
    "full_remaining_seconds",
    "expected_end_time",
    "next_possible_behavior",
    "stage_index",
    "total_stages",
    "sequence_so_far",
]
TIME_FIELDS = [
    "elapsed_seconds_in_current_behavior",
    "estimated_remaining_seconds",
    "full_remaining_seconds",
    "expected_end_time",
]
QA_FIELDS = ["occupied", "time_to_free_minutes", "used_areas", "is_abnormal"]
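# TIME_FIELDS is the numeric subset of STRUCT_FIELDS scored with MAE in
# evaluate_struct(); the rest are scored as classification or sequence metrics.
# Illustrative QA target (hypothetical values, field names from QA_FIELDS):
#   {"occupied": true, "time_to_free_minutes": 3.5,
#    "used_areas": ["门", "洗手池"], "is_abnormal": false}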


def read_jsonl(path, limit=None):
    """Read a JSONL file, skipping blank lines; stop after `limit` rows if set."""
    rows = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            rows.append(json.loads(line))
            if limit and len(rows) >= limit:
                break
    return rows


def load_model(model_name, adapter_dir=None):
    """Load a 4-bit NF4-quantized causal LM and tokenizer, plus an optional PEFT adapter."""
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Left padding keeps prompts right-aligned so batched generations can all
    # be sliced at one shared prompt length (see generate_predictions).
    tokenizer.padding_side = "left"
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    if adapter_dir:
        model = PeftModel.from_pretrained(model, adapter_dir)
    model.eval()
    return tokenizer, model
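
# Usage sketch (adapter path is illustrative):
#   tokenizer, model = load_model("Qwen/Qwen3.5-9B", adapter_dir="outputs/lora")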


def render_prompt(tokenizer, messages):
    # Qwen3-style chat templates accept enable_thinking=False to suppress the
    # thinking block; fall back for templates without that keyword argument.
    try:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
    except TypeError:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)


def json_candidates(text):
    """Yield every dict that can be raw-decoded starting at a '{' in `text`."""
    decoder = json.JSONDecoder()
    for idx, char in enumerate(text):
        if char != "{":
            continue
        try:
            obj, _ = decoder.raw_decode(text[idx:])
        except Exception:
            continue
        if isinstance(obj, dict):
            yield obj
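
# Examples (from the scanning behavior above; note nested dicts are yielded too):
#   list(json_candidates('noise {"a": 1} noise'))  -> [{"a": 1}]
#   list(json_candidates('{"a": {"b": 1}}'))       -> [{"a": {"b": 1}}, {"b": 1}]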


def parse_json_text(text, preferred_fields=None):
    """Parse `text` as JSON, falling back to the embedded dict that shares the
    most keys with `preferred_fields`. Returns (obj, None) or (None, error)."""
    text = text.strip()
    try:
        return json.loads(text), None
    except Exception:
        pass
    candidates = list(json_candidates(text))
    if not candidates:
        return None, "no_json_object"
    if preferred_fields:
        preferred = set(preferred_fields)
        candidates.sort(key=lambda obj: len(preferred & set(obj.keys())), reverse=True)
    return candidates[0], None
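
# Example: parse_json_text('Sure! {"occupied": true} {"x": 1}', QA_FIELDS)
#   -> ({"occupied": True}, None), since the first dict shares a preferred key;
#   text with no decodable dict at all -> (None, "no_json_object")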


def generate_predictions(rows, tokenizer, model, max_new_tokens, batch_size, preferred_fields, max_input_tokens, pred_path=None):
    records = []
    pred_file = pred_path.open("w", encoding="utf-8") if pred_path else None
    for start in range(0, len(rows), batch_size):
        batch = rows[start : start + batch_size]
        # Prompt with everything but the final (assistant) message, which is the target.
        prompts = [render_prompt(tokenizer, row["messages"][:-1]) for row in batch]
        inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=max_input_tokens).to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                temperature=None,
                top_p=None,
                pad_token_id=tokenizer.eos_token_id,
            )
        # With left padding, every row in the batch shares the same prompt length.
        prompt_len = inputs["input_ids"].shape[1]
        decoded = tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)
        for row, pred_text in zip(batch, decoded):
            target_content = row["messages"][-1]["content"]
            target = json.loads(target_content) if isinstance(target_content, str) else target_content
            pred, error = parse_json_text(pred_text, preferred_fields)
            record = {"target": target, "prediction": pred, "raw_prediction": pred_text, "parse_error": error}
            records.append(record)
            if pred_file:
                pred_file.write(json.dumps(record, ensure_ascii=False, separators=(",", ":")) + "\n")
                pred_file.flush()
        print(f"generated {min(start + batch_size, len(rows))}/{len(rows)}", flush=True)
    if pred_file:
        pred_file.close()
    return records
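
# Each flushed JSONL line mirrors `record` above, e.g. (values illustrative):
#   {"target": {...}, "prediction": {...}, "raw_prediction": "...", "parse_error": null}
# so a partially completed run can still be inspected or re-scored.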


def safe_eq(a, b):
    # NOTE: currently unused by the metrics below.
    return a == b


def numeric_pairs(records, field):
    """Collect aligned (target, prediction) floats for `field`, skipping
    non-numeric or non-finite predictions."""
    y_true, y_pred = [], []
    for rec in records:
        pred = rec["prediction"]
        if not isinstance(pred, dict):
            continue
        t, p = rec["target"].get(field), pred.get(field)
        if isinstance(t, (int, float)) and isinstance(p, (int, float)) and math.isfinite(float(p)):
            y_true.append(float(t))
            y_pred.append(float(p))
    return y_true, y_pred
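
# Note: numeric strings such as "12" are not coerced, so a model that quotes
# its numbers will score zero coverage on that field.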


def classification_metrics(records, field):
    pairs = []
    for rec in records:
        pred = rec["prediction"]
        if isinstance(pred, dict) and field in pred:
            pairs.append((rec["target"].get(field), pred.get(field)))
    if not pairs:
        return {"accuracy": 0.0, "macro_f1": 0.0, "coverage": 0.0}
    y_true, y_pred = zip(*pairs)
    # Stringify so None and mixed types compare as plain labels.
    y_true = ["<NULL>" if value is None else str(value) for value in y_true]
    y_pred = ["<NULL>" if value is None else str(value) for value in y_pred]
    return {
        "accuracy": float(accuracy_score(y_true, y_pred)),
        "macro_f1": float(f1_score(y_true, y_pred, average="macro", zero_division=0)),
        "coverage": len(pairs) / len(records),
    }
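
# The "<NULL>" sentinel lets a null target or prediction count as its own
# class instead of tripping up sklearn's label handling; coverage is the
# fraction of records whose parsed prediction contains the field at all.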


def sequence_metrics(records):
    exact = []
    last = []
    prefix = []
    for rec in records:
        pred = rec["prediction"]
        if not isinstance(pred, dict):
            continue
        true_seq = [x.get("label") for x in rec["target"].get("sequence_so_far") or []]
        pred_seq = [x.get("label") for x in pred.get("sequence_so_far") or [] if isinstance(x, dict)]
        exact.append(true_seq == pred_seq)
        last.append(bool(true_seq and pred_seq and true_seq[-1] == pred_seq[-1]))
        prefix_len = min(len(true_seq), len(pred_seq))
        prefix.append(sum(1 for i in range(prefix_len) if true_seq[i] == pred_seq[i]) / max(1, len(true_seq)))
    return {
        "sequence_exact_match": float(np.mean(exact)) if exact else 0.0,
        "sequence_last_label_accuracy": float(np.mean(last)) if last else 0.0,
        "sequence_prefix_label_match": float(np.mean(prefix)) if prefix else 0.0,
    }
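
# Worked example: true labels [a, b, c] vs predicted [a, b] give
# exact_match 0, last_label 0 (c != b), and prefix match 2/3 (position-wise
# matches divided by the true length, so short predictions are penalized).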


def evaluate_struct(records):
    parsed = [r for r in records if isinstance(r["prediction"], dict)]
    metrics = {
        "num_examples": len(records),
        "json_parse_rate": len(parsed) / max(1, len(records)),
        "required_field_complete_rate": sum(all(f in r["prediction"] for f in STRUCT_FIELDS) for r in parsed) / max(1, len(records)),
    }
    for field in ["current_behavior", "next_possible_behavior", "is_transition", "stage_index", "total_stages"]:
        cm = classification_metrics(records, field)
        metrics[f"{field}_accuracy"] = cm["accuracy"]
        if "behavior" in field or field == "is_transition":
            metrics[f"{field}_macro_f1"] = cm["macro_f1"]
    for field in TIME_FIELDS:
        y_true, y_pred = numeric_pairs(records, field)
        metrics[f"{field}_mae"] = float(mean_absolute_error(y_true, y_pred)) if y_true else None
        metrics[f"{field}_coverage"] = len(y_true) / max(1, len(records))
    metrics.update(sequence_metrics(records))
    return metrics
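
# Caveat: each *_mae is computed only over records with a finite numeric
# prediction, so read it together with the matching *_coverage; a low-coverage
# model can post a deceptively small MAE.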


def normalize_areas(value):
    if not isinstance(value, list):
        return set()
    return {str(x) for x in value}


def evaluate_qa(records):
    parsed = [r for r in records if isinstance(r["prediction"], dict)]
    metrics = {
        "num_examples": len(records),
        "json_parse_rate": len(parsed) / max(1, len(records)),
        "required_field_complete_rate": sum(all(f in r["prediction"] for f in QA_FIELDS) for r in parsed) / max(1, len(records)),
    }
    for field in ["occupied", "is_abnormal"]:
        cm = classification_metrics(records, field)
        metrics[f"{field}_accuracy"] = cm["accuracy"]
        metrics[f"{field}_f1"] = cm["macro_f1"]
    y_true, y_pred = numeric_pairs(records, "time_to_free_minutes")
    metrics["time_to_free_minutes_mae"] = float(mean_absolute_error(y_true, y_pred)) if y_true else None
    true_flat, pred_flat = [], []
    labels = ["门", "马桶", "洗手池", "垃圾桶"]  # door, toilet, sink, trash bin
    for rec in records:
        pred = rec["prediction"]
        if not isinstance(pred, dict):
            continue
        t = normalize_areas(rec["target"].get("used_areas"))
        p = normalize_areas(pred.get("used_areas"))
        true_flat.extend([label in t for label in labels])
        pred_flat.extend([label in p for label in labels])
    if true_flat:
        pr, rc, f1, _ = precision_recall_fscore_support(true_flat, pred_flat, average="binary", zero_division=0)
        metrics["used_areas_micro_precision"] = float(pr)
        metrics["used_areas_micro_recall"] = float(rc)
        metrics["used_areas_micro_f1"] = float(f1)
    return metrics
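
# used_areas is scored as four per-record binary presence flags flattened into
# one vector, i.e. micro-averaged multilabel precision/recall/F1 over the
# fixed label set above.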


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", default="Qwen/Qwen3.5-9B")
    parser.add_argument("--adapter-dir", default=None)
    parser.add_argument("--input-file", default=None)
    parser.add_argument("--predictions-file", default=None)
    parser.add_argument("--task-type", choices=["struct", "qa"], required=True)
    parser.add_argument("--output-dir", default="outputs")
    parser.add_argument("--run-name", required=True)
    parser.add_argument("--max-samples", type=int, default=None)
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--max-new-tokens", type=int, default=1536)
    parser.add_argument("--max-input-tokens", type=int, default=6144)
    args = parser.parse_args()

    out_root = Path(args.output_dir)
    pred_dir = out_root / "predictions"
    metric_dir = out_root / "metrics"
    pred_dir.mkdir(parents=True, exist_ok=True)
    metric_dir.mkdir(parents=True, exist_ok=True)
    pred_path = pred_dir / f"{args.run_name}_{args.task_type}_predictions.jsonl"

    # Either re-score an existing predictions file or run generation first.
    if args.predictions_file:
        records = read_jsonl(args.predictions_file, args.max_samples)
    else:
        if not args.input_file:
            raise ValueError("--input-file is required unless --predictions-file is provided")
        rows = read_jsonl(args.input_file, args.max_samples)
        tokenizer, model = load_model(args.model_name, args.adapter_dir)
        preferred_fields = STRUCT_FIELDS if args.task_type == "struct" else QA_FIELDS
        records = generate_predictions(
            rows, tokenizer, model, args.max_new_tokens, args.batch_size, preferred_fields, args.max_input_tokens, pred_path
        )
    metrics = evaluate_struct(records) if args.task_type == "struct" else evaluate_qa(records)

    metric_payload = {
        "run_name": args.run_name,
        "task_type": args.task_type,
        "input_file": args.input_file,
        "predictions_file": args.predictions_file,
        "metrics": metrics,
    }
    metric_path = metric_dir / f"{args.run_name}_{args.task_type}_metrics.json"
    metric_path.write_text(json.dumps(metric_payload, ensure_ascii=False, indent=2), encoding="utf-8")
    print(json.dumps(metric_payload, ensure_ascii=False, indent=2))
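
# Example invocations (script name and paths are illustrative):
#   python evaluate.py --task-type struct --run-name qwen_lora \
#       --input-file data/val_struct.jsonl --adapter-dir outputs/lora
#   python evaluate.py --task-type qa --run-name rescore \
#       --predictions-file outputs/predictions/qwen_lora_qa_predictions.jsonl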


if __name__ == "__main__":
    main()