#!/usr/bin/env python3
"""
Baseline evaluation script for TMF921 intent-to-config benchmark.

Supports local models (Llama, Qwen, etc.) and API models (GPT-4o-mini).

Usage (local):
    python scripts/baseline_eval.py \
        --model meta-llama/Llama-3.1-8B-Instruct \
        --output_dir outputs/baselines/llama-3.1-8b \
        --batch_size 4

Usage (API):
    export OPENAI_API_KEY=sk-...
    python scripts/baseline_eval.py \
        --model gpt-4o-mini \
        --api_provider openai \
        --output_dir outputs/baselines/gpt-4o-mini \
        --batch_size 1
"""

import argparse
import json
import os
import sys
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple

import torch
from datasets import load_dataset
from tqdm import tqdm

# Add project src to path for utils
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
from tmf921_train.utils import (
    aggregate_metrics, field_f1, get_message, json_exact_match,
    metadata_constraint_pass, parse_json, write_json
)


def _positive_int(value: str) -> int:
    """argparse ``type=`` callback that accepts only integers >= 1."""
    number = int(value)
    if number < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {value!r}")
    return number


def parse_args():
    """Build the command-line parser and return the parsed arguments.

    ``--resume`` and ``--trust_remote_code`` both default to on; each has a
    ``--no_*`` counterpart that switches it off.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--model", required=True, help="Model ID or API model name")
    p.add_argument("--dataset", default="nraptisss/TMF921-intent-to-config-research-sota")
    p.add_argument("--splits", nargs="+", default=[
        "test_in_distribution", "test_template_ood", "test_use_case_ood",
        "test_sector_ood", "test_adversarial"
    ])
    p.add_argument("--output_dir", required=True)
    p.add_argument("--max_samples_per_split", type=int, default=None)
    # A batch size of 0 would make range() raise and a negative one would
    # silently evaluate nothing, so reject both at parse time.
    p.add_argument("--batch_size", type=_positive_int, default=4)
    p.add_argument("--max_new_tokens", type=int, default=1536)
    p.add_argument("--gold_length_buffer", type=int, default=96)
    p.add_argument("--save_every", type=int, default=25)
    p.add_argument("--temperature", type=float, default=0.0)
    p.add_argument("--top_p", type=float, default=1.0)
    p.add_argument("--api_provider", choices=["openai", "anthropic", "none"], default="none")
    p.add_argument("--resume", action="store_true", default=True)
    p.add_argument("--no_resume", dest="resume", action="store_false")
    p.add_argument("--trust_remote_code", action="store_true", default=True)
    # Without this flag the default-True store_true option above could never
    # be disabled from the command line.
    p.add_argument("--no_trust_remote_code", dest="trust_remote_code", action="store_false")
    return p.parse_args()


def make_prompt_messages(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Return the chat messages to send as the prompt.

    A trailing assistant message (the reference answer) is dropped; every
    other message is copied down to its ``role`` and ``content`` keys, with a
    missing ``content`` replaced by an empty string.
    """
    last_index = len(messages) - 1
    out = [
        {"role": m.get("role"), "content": m.get("content", "")}
        for i, m in enumerate(messages)
        if not (i == last_index and m.get("role") == "assistant")
    ]
    if not out:
        # Kept from the original: fall back to all non-assistant messages.
        out = [m for m in messages if m.get("role") != "assistant"]
    return out


def make_prompt_text(tokenizer, messages: List[Dict[str, str]]) -> str:
    """Render the prompt messages to text with the tokenizer's chat template.

    The generation prompt is appended so the model continues as the assistant.
    """
    return tokenizer.apply_chat_template(
        make_prompt_messages(messages),
        tokenize=False,
        add_generation_prompt=True,
    )


def gold_text(example: Dict[str, Any]) -> str:
    """Return the reference answer for ``example``.

    Uses the ``completion`` field when it is truthy, otherwise whatever
    ``get_message`` returns for the ``"assistant"`` role of ``messages``.
    """
    return example.get("completion") or get_message(example["messages"], "assistant")


def dynamic_max_new_tokens(tokenizer, examples: List[Dict[str, Any]], args) -> int:
    """Compute the generation budget for a batch from its gold answers.

    The budget is the longest tokenized gold answer in the batch plus
    ``args.gold_length_buffer``, capped at ``args.max_new_tokens`` and never
    below 16. An empty batch is treated as having a longest answer of 0.
    """
    longest = max(
        (
            len(tokenizer(gold_text(ex), add_special_tokens=False)["input_ids"])
            for ex in examples
        ),
        default=0,  # avoid ValueError from max() on an empty batch
    )
    return max(16, min(int(args.max_new_tokens), longest + int(args.gold_length_buffer)))


# ─── Local model generation ─────────────────────────────────────────────────

def load_local_model(model_id: str, trust_remote_code: bool = True):
    """Load a 4-bit NF4-quantized causal LM and its tokenizer.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode. If the
        tokenizer has no pad token, the EOS token is used as the pad token.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust_remote_code)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=trust_remote_code,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model.eval()
    return model, tokenizer


def generate_batch_local(model, tokenizer, examples: List[Dict[str, Any]], args) -> List[str]:
    """Generate one completion per example with a local model.

    Prompts are left-padded for batching (the tokenizer's original padding
    side is restored afterwards), and only the newly generated tokens are
    decoded and returned.
    """
    texts = [make_prompt_text(tokenizer, ex["messages"]) for ex in examples]

    old_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        inputs = tokenizer(texts, return_tensors="pt", padding=True).to(model.device)
    finally:
        tokenizer.padding_side = old_padding_side

    max_new = dynamic_max_new_tokens(tokenizer, examples, args)

    # Explicit None check: a pad token id of 0 is valid and must not be
    # replaced by the EOS id the way a plain `or` would.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    gen_kwargs = dict(
        max_new_tokens=max_new,
        do_sample=args.temperature > 0,
        temperature=args.temperature if args.temperature > 0 else None,
        top_p=args.top_p,
        pad_token_id=pad_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Drop unset options (e.g. temperature under greedy decoding).
    gen_kwargs = {k: v for k, v in gen_kwargs.items() if v is not None}

    with torch.inference_mode():
        out = model.generate(**inputs, **gen_kwargs)

    # Strip the (padded) prompt so only the continuation is decoded.
    new_tokens = out[:, inputs["input_ids"].shape[1]:]
    return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)


# ─── API generation ──────────────────────────────────────────────────────────

def generate_single_api(model: str, messages: List[Dict[str, str]], max_tokens: int,
                        temperature: float, top_p: float, provider: str) -> str:
    """Request one completion from a hosted chat model.

    Args:
        provider: ``"openai"`` or ``"anthropic"``; the API key is read from
            ``OPENAI_API_KEY`` / ``ANTHROPIC_API_KEY`` respectively.

    Returns:
        The text of the first choice/content block, or ``""`` if there is none.

    Raises:
        ValueError: if ``provider`` is not one of the supported names.
    """
    if provider == "openai":
        import openai

        client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
        resp = client.chat.completions.create(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        return resp.choices[0].message.content or ""
    elif provider == "anthropic":
        import anthropic

        client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
        # The system prompt is passed separately from the message list; if
        # several system messages are present, the last one wins.
        system_msg = ""
        user_msgs = []
        for m in messages:
            if m["role"] == "system":
                system_msg = m["content"]
            else:
                user_msgs.append({"role": m["role"], "content": m["content"]})
        resp = client.messages.create(
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            system=system_msg,
            messages=user_msgs,
        )
        return resp.content[0].text if resp.content else ""
    else:
        raise ValueError(f"Unknown provider: {provider}")


def generate_batch_api(model: str, examples: List[Dict[str, Any]], max_tokens: int,
                       temperature: float, top_p: float, provider: str) -> List[str]:
    """Generate one completion per example via sequential API calls.

    Any exception raised by a single call propagates and discards the
    results collected so far for this batch.
    """
    return [
        generate_single_api(
            model, make_prompt_messages(ex["messages"]),
            max_tokens, temperature, top_p, provider,
        )
        for ex in examples
    ]


# ─── Evaluation ─────────────────────────────────────────────────────────────

def row_metrics(example: Dict[str, Any], prediction: str) -> Dict[str, Any]:
    """Score one prediction against its gold answer.

    Both texts are passed through ``parse_json``. When both parse, exact
    match, field-level F1 and the metadata constraint checks are computed;
    otherwise all of those metrics are recorded as zero / False.
    """
    gold = gold_text(example)
    pred_obj, pred_err = parse_json(prediction)
    gold_obj, _gold_err = parse_json(gold)

    out: Dict[str, Any] = {
        "id": example.get("id"),
        "target_layer": example.get("target_layer"),
        "slice_type": example.get("slice_type"),
        "lifecycle_operation": example.get("lifecycle_operation"),
        "parse_json": pred_obj is not None,
        "gold_parse_json": gold_obj is not None,
        "exact_match": False,
        "prediction": prediction,
        "gold": gold,
        "parse_error": pred_err,
    }

    if pred_obj is not None and gold_obj is not None:
        out["exact_match"] = json_exact_match(pred_obj, gold_obj)
        out.update(field_f1(pred_obj, gold_obj))
        out.update(metadata_constraint_pass(example, prediction, pred_obj))
    else:
        out.update({
            "field_precision": 0.0, "field_recall": 0.0, "field_f1": 0.0,
            "field_tp": 0, "field_fp": 0, "field_fn": 0,
        })
        out.update({
            "slice_sst_pass": False,
            "kpi_text_presence_pass": False,
            "adversarial_status_pass": False,
        })
    return out


def load_existing_predictions(path: Path) -> Tuple[List[Dict[str, Any]], set]:
    """Load previously saved prediction rows for resuming.

    Returns:
        ``(rows, done_ids)`` where ``done_ids`` holds ``str(row["id"])`` for
        every row, or ``([], set())`` when ``path`` does not exist.

    NOTE(review): rows without an ``id`` all map to the string ``"None"``, so
    resume cannot tell such examples apart — confirm the dataset always
    provides ids.
    """
    if not path.exists():
        return [], set()
    rows = json.loads(path.read_text())
    done = {str(r.get("id")) for r in rows}
    return rows, done


def write_split_outputs(split_dir: Path, rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Write ``predictions.json`` and ``metrics.json`` for one split.

    The returned summary is the overall aggregate plus a ``by_<key>``
    breakdown for target layer, slice type and lifecycle operation.
    """
    write_json(split_dir / "predictions.json", rows)
    summary = aggregate_metrics(rows)
    for key in ["target_layer", "slice_type", "lifecycle_operation"]:
        groups = defaultdict(list)
        for r in rows:
            groups[str(r.get(key))].append(r)
        summary[f"by_{key}"] = {g: aggregate_metrics(v) for g, v in sorted(groups.items())}
    write_json(split_dir / "metrics.json", summary)
    return summary


def _predict_batch(batch: List[Dict[str, Any]], start: int, model, tokenizer,
                   args, is_api: bool) -> List[str]:
    """Return predictions for one batch, retrying per example on API errors.

    Local-model errors are reported and re-raised. For API models a failed
    batch is retried one example at a time; examples that still fail get an
    empty prediction.
    """
    try:
        if is_api:
            return generate_batch_api(
                args.model, batch, args.max_new_tokens,
                args.temperature, args.top_p, args.api_provider,
            )
        return generate_batch_local(model, tokenizer, batch, args)
    except Exception as e:
        print(f"\nERROR in batch starting at {start}: {e}")
        if not is_api:
            raise

    preds = []
    for ex in batch:
        try:
            pred = generate_single_api(
                args.model, make_prompt_messages(ex["messages"]),
                args.max_new_tokens, args.temperature, args.top_p, args.api_provider,
            )
            preds.append(pred)
        except Exception as e2:
            print(f" Failed on example {ex.get('id')}: {e2}")
            preds.append("")
    return preds


def main():
    """Run the baseline evaluation over every requested split.

    Per split, predictions and metrics are checkpointed under
    ``<output_dir>/<split>/``; ``all_metrics.json`` in the output directory is
    rewritten after each split so it always reflects the splits seen so far.
    """
    args = parse_args()
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    write_json(out_dir / "baseline_config.json", vars(args))

    is_api = args.api_provider != "none"
    if not is_api:
        print(f"Loading local model: {args.model}")
        model, tokenizer = load_local_model(args.model, args.trust_remote_code)
    else:
        print(f"Using API provider: {args.api_provider}, model: {args.model}")
        model, tokenizer = None, None

    ds = load_dataset(args.dataset)
    all_summary = {}

    for split in args.splits:
        split_ds = ds[split]
        if args.max_samples_per_split:
            split_ds = split_ds.select(range(min(args.max_samples_per_split, len(split_ds))))

        split_dir = out_dir / split
        split_dir.mkdir(parents=True, exist_ok=True)
        pred_path = split_dir / "predictions.json"
        rows, done_ids = load_existing_predictions(pred_path) if args.resume else ([], set())
        todo = [ex for ex in split_ds if str(ex.get("id")) not in done_ids]
        print(f"\nEvaluating {split}: total={len(split_ds)} already_done={len(done_ids)} remaining={len(todo)} batch_size={args.batch_size}")

        if todo:
            completed_since_save = 0
            # Context manager guarantees the bar is closed even if a batch raises.
            with tqdm(total=len(todo), desc=split) as pbar:
                for start in range(0, len(todo), args.batch_size):
                    batch = todo[start:start + args.batch_size]
                    preds = _predict_batch(batch, start, model, tokenizer, args, is_api)

                    for ex, pred in zip(batch, preds):
                        rows.append(row_metrics(ex, pred.strip()))

                    pbar.update(len(batch))
                    completed_since_save += len(batch)
                    if completed_since_save >= args.save_every:
                        # Periodic checkpoint so an interrupted run can resume.
                        write_split_outputs(split_dir, rows)
                        completed_since_save = 0

        # Runs for every split, including ones that were already complete, so
        # all_metrics.json never ends up missing a resumed split.
        summary = write_split_outputs(split_dir, rows)
        all_summary[split] = summary
        write_json(out_dir / "all_metrics.json", all_summary)
        print(f" {split}: parse={summary.get('parse_json', 0):.4f} field_f1={summary.get('field_f1', 0):.4f} exact_match={summary.get('exact_match', 0):.4f}")

    print("\n" + "=" * 60)
    print("BASELINE EVALUATION COMPLETE")
    print("=" * 60)
    for split, s in all_summary.items():
        print(f"{split:30s}: parse={s.get('parse_json', 0):.4f} field_f1={s.get('field_f1', 0):.4f} exact={s.get('exact_match', 0):.4f}")


if __name__ == "__main__":
    main()