#!/usr/bin/env python3
"""Evaluate a 4-bit quantized causal LM (optionally with a PEFT adapter) on two JSON-output tasks.

"struct" covers behavior-state tracking fields (current behavior, timing, stage, sequence);
"qa" covers occupancy question answering (occupied, time to free, used areas, abnormality).
The script generates predictions, parses the model output as JSON, writes per-example records,
and reports parse-rate, classification, regression (MAE), and sequence-match metrics.
"""

import argparse
import json
import math
from pathlib import Path

import numpy as np
import torch
from peft import PeftModel
from sklearn.metrics import accuracy_score, f1_score, mean_absolute_error, precision_recall_fscore_support
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

STRUCT_FIELDS = [
    "current_behavior",
    "is_transition",
    "elapsed_seconds_in_current_behavior",
    "estimated_remaining_seconds",
    "full_remaining_seconds",
    "expected_end_time",
    "next_possible_behavior",
    "stage_index",
    "total_stages",
    "sequence_so_far",
]
TIME_FIELDS = [
    "elapsed_seconds_in_current_behavior",
    "estimated_remaining_seconds",
    "full_remaining_seconds",
    "expected_end_time",
]
QA_FIELDS = ["occupied", "time_to_free_minutes", "used_areas", "is_abnormal"]


def read_jsonl(path, limit=None):
    rows = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            rows.append(json.loads(line))
            if limit and len(rows) >= limit:
                break
    return rows


def load_model(model_name, adapter_dir=None):
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Left padding keeps the prompt/completion boundary at the same index for every row in a batch.
    tokenizer.padding_side = "left"
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    if adapter_dir:
        model = PeftModel.from_pretrained(model, adapter_dir)
    model.eval()
    return tokenizer, model


def render_prompt(tokenizer, messages):
    # Some chat templates accept enable_thinking; fall back when the template does not.
    try:
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
        )
    except TypeError:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)


def json_candidates(text):
    # Yield every JSON object that can be decoded starting at any "{" in the text.
    decoder = json.JSONDecoder()
    for idx, char in enumerate(text):
        if char != "{":
            continue
        try:
            obj, _ = decoder.raw_decode(text[idx:])
        except Exception:
            continue
        if isinstance(obj, dict):
            yield obj


def parse_json_text(text, preferred_fields=None):
    text = text.strip()
    try:
        return json.loads(text), None
    except Exception:
        pass
    candidates = list(json_candidates(text))
    if not candidates:
        return None, "no_json_object"
    if preferred_fields:
        # Prefer the candidate object that covers the most expected fields.
        preferred = set(preferred_fields)
        candidates.sort(key=lambda obj: len(preferred & set(obj.keys())), reverse=True)
    return candidates[0], None


def generate_predictions(rows, tokenizer, model, max_new_tokens, batch_size, preferred_fields, max_input_tokens, pred_path=None):
    records = []
    pred_file = pred_path.open("w", encoding="utf-8") if pred_path else None
    for start in range(0, len(rows), batch_size):
        batch = rows[start : start + batch_size]
        # The last message is the reference answer; everything before it forms the prompt.
        prompts = [render_prompt(tokenizer, row["messages"][:-1]) for row in batch]
        inputs = tokenizer(
            prompts, return_tensors="pt", padding=True, truncation=True, max_length=max_input_tokens
        ).to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                temperature=None,
                top_p=None,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Greedy decoding; strip the prompt tokens before decoding the completion.
        prompt_len = inputs["input_ids"].shape[1]
        decoded = tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)
        for row, pred_text in zip(batch, decoded):
            target_content = row["messages"][-1]["content"]
            target = json.loads(target_content) if isinstance(target_content, str) else target_content
            pred, error = parse_json_text(pred_text, preferred_fields)
            record = {"target": target, "prediction": pred, "raw_prediction": pred_text, "parse_error": error}
            records.append(record)
            if pred_file:
                pred_file.write(json.dumps(record, ensure_ascii=False, separators=(",", ":")) + "\n")
                pred_file.flush()
        print(f"generated {min(start + batch_size, len(rows))}/{len(rows)}", flush=True)
    if pred_file:
        pred_file.close()
    return records


def safe_eq(a, b):
    return a == b


def numeric_pairs(records, field):
    y_true, y_pred = [], []
    for rec in records:
        pred = rec["prediction"]
        if not isinstance(pred, dict):
            continue
        t, p = rec["target"].get(field), pred.get(field)
        if isinstance(t, (int, float)) and isinstance(p, (int, float)) and math.isfinite(float(p)):
            y_true.append(float(t))
            y_pred.append(float(p))
    return y_true, y_pred


def classification_metrics(records, field):
    pairs = []
    for rec in records:
        pred = rec["prediction"]
        if isinstance(pred, dict) and field in pred:
            pairs.append((rec["target"].get(field), pred.get(field)))
    if not pairs:
        return {"accuracy": 0.0, "macro_f1": 0.0, "coverage": 0.0}
    y_true, y_pred = zip(*pairs)
    # sklearn cannot sort mixed labels such as None and str; normalize only for metric computation.
    y_true = ["" if value is None else str(value) for value in y_true]
    y_pred = ["" if value is None else str(value) for value in y_pred]
    return {
        "accuracy": float(accuracy_score(y_true, y_pred)),
        "macro_f1": float(f1_score(y_true, y_pred, average="macro", zero_division=0)),
        "coverage": len(pairs) / len(records),
    }


def sequence_metrics(records):
    exact = []
    last = []
    prefix = []
    for rec in records:
        pred = rec["prediction"]
        if not isinstance(pred, dict):
            continue
        true_seq = [x.get("label") for x in rec["target"].get("sequence_so_far") or []]
        pred_seq = [x.get("label") for x in pred.get("sequence_so_far") or [] if isinstance(x, dict)]
        exact.append(true_seq == pred_seq)
        last.append(bool(true_seq and pred_seq and true_seq[-1] == pred_seq[-1]))
        prefix_len = min(len(true_seq), len(pred_seq))
        prefix.append(sum(1 for i in range(prefix_len) if true_seq[i] == pred_seq[i]) / max(1, len(true_seq)))
    return {
        "sequence_exact_match": float(np.mean(exact)) if exact else 0.0,
        "sequence_last_label_accuracy": float(np.mean(last)) if last else 0.0,
        "sequence_prefix_label_match": float(np.mean(prefix)) if prefix else 0.0,
    }


def evaluate_struct(records):
    parsed = [r for r in records if isinstance(r["prediction"], dict)]
    metrics = {
        "num_examples": len(records),
        "json_parse_rate": len(parsed) / max(1, len(records)),
        "required_field_complete_rate": sum(
            all(f in r["prediction"] for f in STRUCT_FIELDS) for r in parsed
        ) / max(1, len(records)),
    }
    for field in ["current_behavior", "next_possible_behavior", "is_transition", "stage_index", "total_stages"]:
        cm = classification_metrics(records, field)
        metrics[f"{field}_accuracy"] = cm["accuracy"]
        if "behavior" in field or field == "is_transition":
            metrics[f"{field}_macro_f1"] = cm["macro_f1"]
    for field in TIME_FIELDS:
        y_true, y_pred = numeric_pairs(records, field)
        metrics[f"{field}_mae"] = float(mean_absolute_error(y_true, y_pred)) if y_true else None
        metrics[f"{field}_coverage"] = len(y_true) / max(1, len(records))
    metrics.update(sequence_metrics(records))
    return metrics


def normalize_areas(value):
    if not isinstance(value, list):
        return set()
    return {str(x) for x in value}


def evaluate_qa(records):
    parsed = [r for r in records if isinstance(r["prediction"], dict)]
    metrics = {
        "num_examples": len(records),
        "json_parse_rate": len(parsed) / max(1, len(records)),
        "required_field_complete_rate": sum(
            all(f in r["prediction"] for f in QA_FIELDS) for r in parsed
        ) / max(1, len(records)),
    }
    for field in ["occupied", "is_abnormal"]:
        cm = classification_metrics(records, field)
        metrics[f"{field}_accuracy"] = cm["accuracy"]
        metrics[f"{field}_f1"] = cm["macro_f1"]
    y_true, y_pred = numeric_pairs(records, "time_to_free_minutes")
    metrics["time_to_free_minutes_mae"] = float(mean_absolute_error(y_true, y_pred)) if y_true else None
    # used_areas is a multi-label set; flatten to per-label binary decisions and micro-average.
    true_flat, pred_flat = [], []
    labels = ["门", "马桶", "洗手池", "垃圾桶"]  # area labels as stored in the data: door, toilet, washbasin, trash bin
    for rec in records:
        pred = rec["prediction"]
        if not isinstance(pred, dict):
            continue
        t = normalize_areas(rec["target"].get("used_areas"))
        p = normalize_areas(pred.get("used_areas"))
        true_flat.extend([label in t for label in labels])
        pred_flat.extend([label in p for label in labels])
    if true_flat:
        pr, rc, f1, _ = precision_recall_fscore_support(true_flat, pred_flat, average="binary", zero_division=0)
        metrics["used_areas_micro_precision"] = float(pr)
        metrics["used_areas_micro_recall"] = float(rc)
        metrics["used_areas_micro_f1"] = float(f1)
    return metrics


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", default="Qwen/Qwen3.5-9B")
    parser.add_argument("--adapter-dir", default=None)
    parser.add_argument("--input-file", default=None)
    parser.add_argument("--predictions-file", default=None)
    parser.add_argument("--task-type", choices=["struct", "qa"], required=True)
    parser.add_argument("--output-dir", default="outputs")
    parser.add_argument("--run-name", required=True)
    parser.add_argument("--max-samples", type=int, default=None)
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--max-new-tokens", type=int, default=1536)
    parser.add_argument("--max-input-tokens", type=int, default=6144)
    args = parser.parse_args()

    out_root = Path(args.output_dir)
    pred_dir = out_root / "predictions"
    metric_dir = out_root / "metrics"
    pred_dir.mkdir(parents=True, exist_ok=True)
    metric_dir.mkdir(parents=True, exist_ok=True)
    pred_path = pred_dir / f"{args.run_name}_{args.task_type}_predictions.jsonl"

    if args.predictions_file:
        # Re-score previously generated prediction records without loading the model.
        records = read_jsonl(args.predictions_file, args.max_samples)
    else:
        if not args.input_file:
            raise ValueError("--input-file is required unless --predictions-file is provided")
        rows = read_jsonl(args.input_file, args.max_samples)
        tokenizer, model = load_model(args.model_name, args.adapter_dir)
        preferred_fields = STRUCT_FIELDS if args.task_type == "struct" else QA_FIELDS
        records = generate_predictions(
            rows, tokenizer, model, args.max_new_tokens, args.batch_size, preferred_fields, args.max_input_tokens, pred_path
        )

    metrics = evaluate_struct(records) if args.task_type == "struct" else evaluate_qa(records)
    metric_payload = {
        "run_name": args.run_name,
        "task_type": args.task_type,
        "input_file": args.input_file,
        "predictions_file": args.predictions_file,
        "metrics": metrics,
    }
    metric_path = metric_dir / f"{args.run_name}_{args.task_type}_metrics.json"
    metric_path.write_text(json.dumps(metric_payload, ensure_ascii=False, indent=2), encoding="utf-8")
    print(json.dumps(metric_payload, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()