from __future__ import annotations import argparse import json import sys from pathlib import Path import torch _METRICS_DIR = Path(__file__).resolve().parent _SCRIPTS_DIR = _METRICS_DIR.parent for path in (_METRICS_DIR, _SCRIPTS_DIR): if str(path) not in sys.path: sys.path.insert(0, str(path)) from broken_code_generation import ( # noqa: E402 ADAPTER_DIR, DEFAULT_EVAL_LIMIT, EVAL_FILE, FILE_JSON_VALIDITY, GEN_MAX_NEW_TOKENS, GEN_SEED, GEN_TEMPERATURE, GEN_TOP_P, MODEL_ID, ) from evaluate_model import REQUIRED_FIELDS, generate_one, load_model_and_tokenizer # noqa: E402 from report_io import metrics_path, write_report # noqa: E402 def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description=f"JSON validity for {MODEL_ID} only (adapter at {ADAPTER_DIR})." ) parser.add_argument("--limit", type=int, default=DEFAULT_EVAL_LIMIT) parser.add_argument("--output", type=Path, default=None) return parser.parse_args() def main() -> None: args = parse_args() torch.manual_seed(GEN_SEED) if not ADAPTER_DIR.exists(): raise FileNotFoundError(f"Adapter not found: {ADAPTER_DIR}") records = json.loads(EVAL_FILE.read_text(encoding="utf-8"))[: args.limit] print(f"Model: {MODEL_ID}") print(f"Adapter: {ADAPTER_DIR}") print(f"Samples: {len(records)} from {EVAL_FILE}") model, tokenizer = load_model_and_tokenizer(ADAPTER_DIR) model.eval() valid_json = required = difficulty_ok = tags_ok = 0 results = [] for index, record in enumerate(records, start=1): row = {"index": index, "status": "error"} try: generated = generate_one( model=model, tokenizer=tokenizer, topic_tags=record["topic_tags"], difficulty=record["difficulty"], max_new_tokens=GEN_MAX_NEW_TOKENS, temperature=GEN_TEMPERATURE, top_p=GEN_TOP_P, ) valid_json += 1 row["status"] = "ok" row["generated"] = generated if REQUIRED_FIELDS.issubset(generated): required += 1 if generated.get("difficulty") == record["difficulty"]: difficulty_ok += 1 if set(generated.get("topic_tags", {})) == set(record["topic_tags"]): tags_ok += 1 except Exception as error: # noqa: BLE001 row["error"] = str(error) results.append(row) print(f"[{MODEL_ID}] {index}/{len(records)} valid_json={valid_json}", flush=True) n = max(len(records), 1) report = { "metric_group": "json_validity", "model": MODEL_ID, "adapter_dir": str(ADAPTER_DIR), "evaluation_file": str(EVAL_FILE), "samples_evaluated": len(records), "generation": { "temperature": GEN_TEMPERATURE, "top_p": GEN_TOP_P, "max_new_tokens": GEN_MAX_NEW_TOKENS, "seed": GEN_SEED, }, "metrics": { "valid_json_rate": round(valid_json / n, 4), "required_fields_rate": round(required / n, 4), "difficulty_match_rate": round(difficulty_ok / n, 4), "topic_tag_key_match_rate": round(tags_ok / n, 4), }, "metrics_counts": { "valid_json": valid_json, "required_fields_complete": required, "difficulty_match": difficulty_ok, "topic_tag_keys_match": tags_ok, }, "results": results, } write_report(args.output or metrics_path(FILE_JSON_VALIDITY), report) if __name__ == "__main__": main()