PEFT
qlora
sft
trl
qwen3
tmf921
intent-based-networking
network-slicing
rtx-6000-ada
ml-intern
tmf921-intent-training / scripts /baseline_eval.py
nraptisss's picture
Add baseline evaluation script for Llama/GPT-4o-mini comparison
0e7b293 verified
#!/usr/bin/env python3
r"""
Baseline evaluation script for the TMF921 intent-to-config benchmark.

For every requested dataset split, the script generates a prediction for each
example and scores it with the helpers in ``tmf921_train.utils``. It supports:

* local Hugging Face causal LMs (Llama, Qwen, etc.), loaded with 4-bit NF4
  quantization via bitsandbytes, and
* API models served by OpenAI (e.g. GPT-4o-mini; reads ``OPENAI_API_KEY``) or
  Anthropic (reads ``ANTHROPIC_API_KEY``).

Outputs written under ``--output_dir``:
    baseline_config.json        the parsed command-line arguments
    <split>/predictions.json    one scored row per example; read back by --resume
                                (on by default) to skip finished examples
    <split>/metrics.json        aggregate metrics plus breakdowns by target_layer,
                                slice_type and lifecycle_operation
    all_metrics.json            per-split summaries

Usage (local):
    python scripts/baseline_eval.py \
        --model meta-llama/Llama-3.1-8B-Instruct \
        --output_dir outputs/baselines/llama-3.1-8b \
        --batch_size 4

Usage (API):
    export OPENAI_API_KEY=sk-...
    python scripts/baseline_eval.py \
        --model gpt-4o-mini \
        --api_provider openai \
        --output_dir outputs/baselines/gpt-4o-mini \
        --batch_size 1
"""
import argparse
import json
import os
import sys
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple
import torch
from datasets import load_dataset
from tqdm import tqdm
# Add project src to path for utils
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
from tmf921_train.utils import (
aggregate_metrics, field_f1, get_message, json_exact_match,
metadata_constraint_pass, parse_json, write_json
)
def parse_args():
    """Parse and validate the command-line options for the baseline evaluation.

    Options that default to on (``--resume``, ``--trust_remote_code``) each have a
    ``--no_*`` counterpart. A ``store_true`` flag whose default is already True
    cannot otherwise be switched off from the command line.

    Returns:
        argparse.Namespace with one attribute per option.

    Exits (via ``ArgumentParser.error``) when ``--batch_size`` is smaller than 1.
    """
    p = argparse.ArgumentParser(
        description="Baseline evaluation for the TMF921 intent-to-config benchmark."
    )
    p.add_argument("--model", required=True, help="Model ID or API model name")
    p.add_argument("--dataset", default="nraptisss/TMF921-intent-to-config-research-sota")
    p.add_argument("--splits", nargs="+", default=[
        "test_in_distribution", "test_template_ood",
        "test_use_case_ood", "test_sector_ood", "test_adversarial",
    ])
    p.add_argument("--output_dir", required=True)
    p.add_argument("--max_samples_per_split", type=int, default=None)
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--max_new_tokens", type=int, default=1536)
    p.add_argument("--gold_length_buffer", type=int, default=96)
    p.add_argument("--save_every", type=int, default=25)
    p.add_argument("--temperature", type=float, default=0.0)
    p.add_argument("--top_p", type=float, default=1.0)
    p.add_argument("--api_provider", choices=["openai", "anthropic", "none"], default="none")
    p.add_argument("--resume", action="store_true", default=True,
                   help="Skip examples already present in <split>/predictions.json (default).")
    p.add_argument("--no_resume", dest="resume", action="store_false")
    p.add_argument("--trust_remote_code", action="store_true", default=True,
                   help="Allow custom model code from the Hub when loading local models (default).")
    # Without this flag the default of True above could never be disabled.
    p.add_argument("--no_trust_remote_code", dest="trust_remote_code", action="store_false")
    args = p.parse_args()
    if args.batch_size < 1:
        # A zero/negative step would otherwise surface later as a cryptic range() error.
        p.error("--batch_size must be >= 1")
    return args
def make_prompt_messages(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Return the prompt turns of a conversation, without the gold answer.

    A trailing assistant message (the reference completion) is dropped. Every
    remaining turn is copied down to just its ``role`` and ``content`` keys;
    missing content becomes "".
    """
    prompt_turns = list(messages)
    if prompt_turns and prompt_turns[-1].get("role") == "assistant":
        prompt_turns = prompt_turns[:-1]
    return [
        {"role": turn.get("role"), "content": turn.get("content", "")}
        for turn in prompt_turns
    ]
def make_prompt_text(tokenizer, messages: List[Dict[str, str]]) -> str:
    """Render the prompt turns of ``messages`` as a string via the tokenizer's chat template.

    The gold assistant turn is removed first (see ``make_prompt_messages``), and the
    template's generation prompt is appended so the model continues as the assistant.
    """
    prompt_turns = make_prompt_messages(messages)
    rendered = tokenizer.apply_chat_template(
        prompt_turns,
        tokenize=False,
        add_generation_prompt=True,
    )
    return rendered
def gold_text(example: Dict[str, Any]) -> str:
    """Return the reference answer for an example.

    Uses the ``completion`` field when it is present and non-empty. Otherwise it
    falls back to the assistant message extracted from ``example["messages"]``.
    """
    completion = example.get("completion")
    if completion:
        return completion
    return get_message(example["messages"], "assistant")
def dynamic_max_new_tokens(tokenizer, examples: List[Dict[str, Any]], args) -> int:
    """Choose the generation budget (in tokens) for one batch.

    The budget is the token length of the longest gold completion in the batch
    plus ``args.gold_length_buffer``. It is capped at ``args.max_new_tokens`` and
    never goes below 16. An empty batch has no gold text to measure, so the cap
    itself (still floored at 16) is returned instead of failing on ``max([])``.

    NOTE: the budget is derived from the gold answer's length, so it uses target
    information. Use ``args.max_new_tokens`` directly for a gold-independent budget.
    """
    cap = int(args.max_new_tokens)
    if not examples:
        return max(16, cap)
    # Tokenize all gold texts in one call; without return_tensors this yields a list of id lists.
    encoded = tokenizer([gold_text(ex) for ex in examples], add_special_tokens=False)["input_ids"]
    longest = max(len(ids) for ids in encoded)
    return max(16, min(cap, longest + int(args.gold_length_buffer)))
# ─── Local model generation ─────────────────────────────────────────────────
def load_local_model(model_id: str, trust_remote_code: bool = True, load_in_4bit: bool = True):
    """Load a causal LM and its tokenizer for local generation.

    Args:
        model_id: Hugging Face Hub model ID or local path.
        trust_remote_code: forwarded to both ``from_pretrained`` calls.
        load_in_4bit: when True (the default), quantize weights to 4-bit NF4 with
            double quantization and bf16 compute via bitsandbytes. When False, load
            plain bf16 weights, e.g. for a full-precision baseline.

    Returns:
        ``(model, tokenizer)``. The model is placed with ``device_map="auto"`` and
        switched to eval mode. A tokenizer without a pad token gets its EOS token as
        the pad token, because batched generation needs padding.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust_remote_code)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model_kwargs: Dict[str, Any] = {
        "trust_remote_code": trust_remote_code,
        "device_map": "auto",
        "torch_dtype": torch.bfloat16,
    }
    if load_in_4bit:
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
    model.eval()
    return model, tokenizer
def generate_batch_local(model, tokenizer, examples: List[Dict[str, Any]], args) -> List[str]:
    """Generate one completion per example with a local model in a single batched call.

    Returns the decoded newly generated tokens for each example, in input order,
    with special tokens removed. Sampling is enabled only when
    ``args.temperature > 0``; otherwise decoding is greedy.
    """
    texts = [make_prompt_text(tokenizer, ex["messages"]) for ex in examples]

    # Decoder-only models need left padding for batched generation so every prompt
    # ends right where generation starts; restore the caller's setting afterwards.
    old_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        # The chat template already emits the model's special tokens (e.g. a BOS token);
        # letting the tokenizer add them again would duplicate them in the prompt.
        inputs = tokenizer(
            texts, return_tensors="pt", padding=True, add_special_tokens=False
        ).to(model.device)
    finally:
        tokenizer.padding_side = old_padding_side

    do_sample = args.temperature > 0
    # A pad id of 0 is valid, so test for None rather than truthiness.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    gen_kwargs: Dict[str, Any] = {
        "max_new_tokens": dynamic_max_new_tokens(tokenizer, examples, args),
        "do_sample": do_sample,
        "pad_token_id": pad_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        # temperature / top_p only affect sampling; greedy decoding ignores them.
        gen_kwargs["temperature"] = args.temperature
        gen_kwargs["top_p"] = args.top_p
    gen_kwargs = {k: v for k, v in gen_kwargs.items() if v is not None}

    with torch.inference_mode():
        out = model.generate(**inputs, **gen_kwargs)
    # With left padding all prompts share one length, so the continuation starts there.
    new_tokens = out[:, inputs["input_ids"].shape[1]:]
    return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
# ─── API generation ──────────────────────────────────────────────────────────
def generate_single_api(model: str, messages: List[Dict[str, str]], max_tokens: int, temperature: float, top_p: float, provider: str) -> str:
if provider == "openai":
import openai
client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
resp = client.chat.completions.create(
model=model,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
)
return resp.choices[0].message.content or ""
elif provider == "anthropic":
import anthropic
client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
system_msg = ""
user_msgs = []
for m in messages:
if m["role"] == "system":
system_msg = m["content"]
else:
user_msgs.append({"role": m["role"], "content": m["content"]})
resp = client.messages.create(
model=model,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
system=system_msg,
messages=user_msgs,
)
return resp.content[0].text if resp.content else ""
else:
raise ValueError(f"Unknown provider: {provider}")
def generate_batch_api(model: str, examples: List[Dict[str, Any]], max_tokens: int, temperature: float, top_p: float, provider: str) -> List[str]:
    """Query the API once per example, sequentially, and return the completions in input order.

    Each prompt is built from the example's messages with the gold assistant turn
    removed. The first failing request propagates its exception.
    """
    return [
        generate_single_api(
            model,
            make_prompt_messages(example["messages"]),
            max_tokens,
            temperature,
            top_p,
            provider,
        )
        for example in examples
    ]
# ─── Evaluation ─────────────────────────────────────────────────────────────
def row_metrics(example: Dict[str, Any], prediction: str) -> Dict[str, Any]:
    """Score one prediction against its example's gold completion.

    Returns a flat row with the example's metadata, JSON-parse flags, the raw
    prediction and gold text, and the prediction's parse error. When both sides
    parse as JSON, the row also gets exact-match, field-level precision/recall/F1
    and metadata-constraint results from the project helpers. Otherwise those
    metrics are recorded as zero / False.
    """
    gold = gold_text(example)
    pred_obj, pred_err = parse_json(prediction)
    gold_obj, _gold_err = parse_json(gold)
    pred_ok = pred_obj is not None
    gold_ok = gold_obj is not None

    row: Dict[str, Any] = {
        "id": example.get("id"),
        "target_layer": example.get("target_layer"),
        "slice_type": example.get("slice_type"),
        "lifecycle_operation": example.get("lifecycle_operation"),
        "parse_json": pred_ok,
        "gold_parse_json": gold_ok,
        "exact_match": False,
        "prediction": prediction,
        "gold": gold,
        "parse_error": pred_err,
    }

    if not (pred_ok and gold_ok):
        # Nothing comparable: record every metric as a miss.
        row.update(
            field_precision=0.0, field_recall=0.0, field_f1=0.0,
            field_tp=0, field_fp=0, field_fn=0,
        )
        row.update(slice_sst_pass=False, kpi_text_presence_pass=False, adversarial_status_pass=False)
        return row

    row["exact_match"] = json_exact_match(pred_obj, gold_obj)
    row.update(field_f1(pred_obj, gold_obj))
    row.update(metadata_constraint_pass(example, prediction, pred_obj))
    return row
def load_existing_predictions(path: Path) -> Tuple[List[Dict[str, Any]], set]:
    """Load previously saved prediction rows for resuming.

    Returns ``(rows, done_ids)``. ``done_ids`` holds every row's ``id`` converted to
    ``str``. A missing file yields ``([], set())``.
    """
    if not path.exists():
        return [], set()
    saved_rows = json.loads(path.read_text())
    finished = {str(row.get("id")) for row in saved_rows}
    return saved_rows, finished
def write_split_outputs(split_dir: Path, rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Persist a split's prediction rows and metrics; return the metrics summary.

    Writes ``predictions.json`` (the rows as given) and ``metrics.json``. The
    summary is ``aggregate_metrics(rows)`` plus ``by_target_layer``,
    ``by_slice_type`` and ``by_lifecycle_operation`` sub-dicts. Each sub-dict maps
    the stringified group value to that group's aggregate, with keys sorted.
    """
    write_json(split_dir / "predictions.json", rows)
    summary = aggregate_metrics(rows)
    for field in ("target_layer", "slice_type", "lifecycle_operation"):
        buckets = defaultdict(list)
        for row in rows:
            buckets[str(row.get(field))].append(row)
        summary["by_" + field] = {
            name: aggregate_metrics(members) for name, members in sorted(buckets.items())
        }
    write_json(split_dir / "metrics.json", summary)
    return summary
def _generate_predictions(batch: List[Dict[str, Any]], args, model, tokenizer, is_api: bool) -> List[str]:
    """Return one raw prediction string per example in ``batch``.

    Local models handle the whole batch in one ``generate`` call, and any error
    propagates. API models are queried one example at a time. An example whose
    request raises is logged and given an empty prediction, so one failure neither
    aborts the run nor forces re-querying examples that already succeeded.
    """
    if not is_api:
        return generate_batch_local(model, tokenizer, batch, args)
    preds: List[str] = []
    for ex in batch:
        try:
            preds.append(generate_single_api(
                args.model, make_prompt_messages(ex["messages"]), args.max_new_tokens,
                args.temperature, args.top_p, args.api_provider,
            ))
        except Exception as e:  # any API/network error is scoped to this one example
            print(f"  Failed on example {ex.get('id')}: {e}")
            preds.append("")
    return preds


def _evaluate_split(split: str, split_ds, split_dir: Path, args, model, tokenizer, is_api: bool) -> Dict[str, Any]:
    """Generate, score and persist predictions for one split; return its metrics summary.

    With ``--resume`` (the default), examples whose id is already in
    ``<split_dir>/predictions.json`` are skipped. Outputs are rewritten every
    ``--save_every`` scored examples, on an abnormal exit, and once at the end.
    """
    rows, done_ids = load_existing_predictions(split_dir / "predictions.json") if args.resume else ([], set())
    todo = [ex for ex in split_ds if str(ex.get("id")) not in done_ids]
    print(f"\nEvaluating {split}: total={len(split_ds)} already_done={len(done_ids)} remaining={len(todo)} batch_size={args.batch_size}")
    if not todo:
        return write_split_outputs(split_dir, rows)

    completed_since_save = 0
    try:
        with tqdm(total=len(todo), desc=split) as pbar:
            for start in range(0, len(todo), args.batch_size):
                batch = todo[start:start + args.batch_size]
                preds = _generate_predictions(batch, args, model, tokenizer, is_api)
                rows.extend(row_metrics(ex, pred.strip()) for ex, pred in zip(batch, preds))
                pbar.update(len(batch))
                completed_since_save += len(batch)
                if completed_since_save >= args.save_every:
                    write_split_outputs(split_dir, rows)
                    completed_since_save = 0
    except BaseException:
        # On any failure (e.g. CUDA OOM, Ctrl-C), save rows scored since the last
        # periodic save so a --resume run does not redo them; then re-raise.
        if completed_since_save:
            print(f"\nStopping {split} early; saving {len(rows)} scored rows for --resume.")
            write_split_outputs(split_dir, rows)
        raise
    return write_split_outputs(split_dir, rows)


def main():
    """Run the baseline evaluation over every requested split and print a summary table."""
    args = parse_args()
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    write_json(out_dir / "baseline_config.json", vars(args))

    is_api = args.api_provider != "none"
    if is_api:
        print(f"Using API provider: {args.api_provider}, model: {args.model}")
        model, tokenizer = None, None
    else:
        print(f"Loading local model: {args.model}")
        model, tokenizer = load_local_model(args.model, args.trust_remote_code)

    ds = load_dataset(args.dataset)
    all_summary: Dict[str, Any] = {}
    for split in args.splits:
        split_ds = ds[split]
        if args.max_samples_per_split:
            split_ds = split_ds.select(range(min(args.max_samples_per_split, len(split_ds))))
        split_dir = out_dir / split
        split_dir.mkdir(parents=True, exist_ok=True)

        summary = _evaluate_split(split, split_ds, split_dir, args, model, tokenizer, is_api)
        all_summary[split] = summary
        # Rewrite after every split, including already-complete ones, so
        # all_metrics.json never misses splits that needed no new generation.
        write_json(out_dir / "all_metrics.json", all_summary)
        print(f"  {split}: parse={summary.get('parse_json', 0):.4f} field_f1={summary.get('field_f1', 0):.4f} exact_match={summary.get('exact_match', 0):.4f}")

    print("\n" + "=" * 60)
    print("BASELINE EVALUATION COMPLETE")
    print("=" * 60)
    for split, s in all_summary.items():
        print(f"{split:30s}: parse={s.get('parse_json', 0):.4f} field_f1={s.get('field_f1', 0):.4f} exact={s.get('exact_match', 0):.4f}")


if __name__ == "__main__":
    main()