Instructions to use nraptisss/tmf921-intent-training with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- PEFT
How to use nraptisss/tmf921-intent-training with PEFT:
Task type is invalid.
- Notebooks
- Google Colab
- Kaggle
| #!/usr/bin/env python3 | |
| """ | |
| Baseline evaluation script for TMF921 intent-to-config benchmark. | |
| Supports local models (Llama, Qwen, etc.) and API models (GPT-4o-mini). | |
| Usage (local): | |
| python scripts/baseline_eval.py \ | |
| --model meta-llama/Llama-3.1-8B-Instruct \ | |
| --output_dir outputs/baselines/llama-3.1-8b \ | |
| --batch_size 4 | |
| Usage (API): | |
| export OPENAI_API_KEY=sk-... | |
| python scripts/baseline_eval.py \ | |
| --model gpt-4o-mini \ | |
| --api_provider openai \ | |
| --output_dir outputs/baselines/gpt-4o-mini \ | |
| --batch_size 1 | |
| """ | |
| import argparse | |
| import json | |
| import os | |
| import sys | |
| from collections import defaultdict | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Tuple | |
| import torch | |
| from datasets import load_dataset | |
| from tqdm import tqdm | |
| # Add project src to path for utils | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) | |
| from tmf921_train.utils import ( | |
| aggregate_metrics, field_f1, get_message, json_exact_match, | |
| metadata_constraint_pass, parse_json, write_json | |
| ) | |
def parse_args():
    """Parse the command-line options for a baseline evaluation run.

    Returns:
        argparse.Namespace holding the model/dataset selection, generation
        settings, and resume/save options defined below.

    Exits through ``ArgumentParser.error`` (``SystemExit``) when
    ``--batch_size`` is not a positive integer.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--model", required=True, help="Model ID or API model name")
    p.add_argument("--dataset", default="nraptisss/TMF921-intent-to-config-research-sota")
    p.add_argument("--splits", nargs="+", default=[
        "test_in_distribution", "test_template_ood",
        "test_use_case_ood", "test_sector_ood", "test_adversarial",
    ])
    p.add_argument("--output_dir", required=True)
    p.add_argument("--max_samples_per_split", type=int, default=None)
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--max_new_tokens", type=int, default=1536)
    p.add_argument("--gold_length_buffer", type=int, default=96)
    p.add_argument("--save_every", type=int, default=25)
    p.add_argument("--temperature", type=float, default=0.0)
    p.add_argument("--top_p", type=float, default=1.0)
    p.add_argument("--api_provider", choices=["openai", "anthropic", "none"], default="none")
    p.add_argument("--resume", action="store_true", default=True)
    p.add_argument("--no_resume", dest="resume", action="store_false")
    p.add_argument("--trust_remote_code", action="store_true", default=True)
    # The flag above defaults to True, so on its own it could never be turned
    # off. Provide the negative form, mirroring --resume / --no_resume.
    p.add_argument("--no_trust_remote_code", dest="trust_remote_code", action="store_false")
    args = p.parse_args()
    # The batch size is used as a range() step: zero raises there, and a
    # negative value would yield no batches at all, so reject both up front.
    if args.batch_size < 1:
        p.error("--batch_size must be a positive integer")
    return args
def make_prompt_messages(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Build the prompt conversation by dropping a trailing assistant turn.

    Only an assistant message in the last position is removed; assistant
    messages earlier in the conversation are kept. Every kept message is
    copied into a fresh ``{"role", "content"}`` dict, with a missing
    ``content`` replaced by an empty string.
    """
    ends_with_answer = bool(messages) and messages[-1].get("role") == "assistant"
    kept = messages[:-1] if ends_with_answer else messages
    prompt = [
        {"role": msg.get("role"), "content": msg.get("content", "")}
        for msg in kept
    ]
    if prompt:
        return prompt
    # Nothing was kept: fall back to the non-assistant messages as they are.
    return [msg for msg in messages if msg.get("role") != "assistant"]
def make_prompt_text(tokenizer, messages: List[Dict[str, str]]) -> str:
    """Render the prompt messages as text via the tokenizer's chat template.

    The final gold assistant turn is removed first, and the template is asked
    to append the generation prompt and to return text rather than token ids.
    """
    prompt_messages = make_prompt_messages(messages)
    rendered = tokenizer.apply_chat_template(
        prompt_messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    return rendered
def gold_text(example: Dict[str, Any]) -> str:
    """Return the reference answer for an example.

    A non-empty ``completion`` field wins; otherwise the assistant message is
    looked up in ``example["messages"]``.
    """
    completion = example.get("completion")
    if completion:
        return completion
    return get_message(example["messages"], "assistant")
def dynamic_max_new_tokens(tokenizer, examples: List[Dict[str, Any]], args) -> int:
    """Choose the generation budget for a batch from its gold answer lengths.

    The budget is the longest tokenized gold answer in the batch plus
    ``args.gold_length_buffer``, capped at ``args.max_new_tokens`` and never
    below 16.

    Args:
        tokenizer: Callable tokenizer returning a mapping with ``input_ids``.
        examples: Batch of dataset rows; may be empty.
        args: Namespace providing ``max_new_tokens`` and ``gold_length_buffer``.

    Returns:
        Number of new tokens to allow for the batch.
    """
    # default=0 keeps an empty batch from raising ValueError inside max().
    longest_gold = max(
        (
            len(tokenizer(gold_text(ex), add_special_tokens=False)["input_ids"])
            for ex in examples
        ),
        default=0,
    )
    budget = min(int(args.max_new_tokens), longest_gold + int(args.gold_length_buffer))
    return max(16, budget)
| # ─── Local model generation ───────────────────────────────────────────────── | |
def load_local_model(model_id: str, trust_remote_code: bool = True):
    """Load a causal LM in 4-bit NF4 quantization together with its tokenizer.

    Args:
        model_id: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the model.
        trust_remote_code: Forwarded to both ``from_pretrained`` calls.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode. If the
        tokenizer has no pad token, the EOS token is used as pad token.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust_remote_code)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    quantization = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    load_kwargs = {
        "trust_remote_code": trust_remote_code,
        "quantization_config": quantization,
        "device_map": "auto",
        "torch_dtype": torch.bfloat16,
    }
    model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
    model.eval()
    return model, tokenizer
def generate_batch_local(model, tokenizer, examples: List[Dict[str, Any]], args) -> List[str]:
    """Generate one completion per example with a local model.

    Args:
        model: Model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer used for chat templating, encoding and decoding.
        examples: Batch of dataset rows, each with a ``messages`` list.
        args: Namespace providing ``temperature``, ``top_p`` and the fields
            read by ``dynamic_max_new_tokens``.

    Returns:
        Decoded continuations (prompt tokens removed, special tokens
        skipped), in the same order as ``examples``.
    """
    texts = [make_prompt_text(tokenizer, ex["messages"]) for ex in examples]

    # Pad on the left so every prompt ends at the same column and the
    # continuation can be sliced off by prompt width below. The tokenizer's
    # own padding side is restored afterwards.
    old_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        # The chat template has already rendered the special tokens the
        # prompt needs; letting the tokenizer add its own again would
        # duplicate them (e.g. a second BOS token).
        inputs = tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            add_special_tokens=False,
        ).to(model.device)
    finally:
        tokenizer.padding_side = old_padding_side

    max_new = dynamic_max_new_tokens(tokenizer, examples, args)

    # Compare against None explicitly: a pad token id of 0 is a valid id and
    # must not be replaced by the EOS id.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    sampling = args.temperature > 0
    gen_kwargs = dict(
        max_new_tokens=max_new,
        do_sample=sampling,
        temperature=args.temperature if sampling else None,
        top_p=args.top_p,
        pad_token_id=pad_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Drop unset options so generate() falls back to its own defaults.
    gen_kwargs = {k: v for k, v in gen_kwargs.items() if v is not None}

    with torch.inference_mode():
        out = model.generate(**inputs, **gen_kwargs)

    # Keep only the tokens produced after the (left-padded) prompt.
    prompt_width = inputs["input_ids"].shape[1]
    new_tokens = out[:, prompt_width:]
    return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
| # ─── API generation ────────────────────────────────────────────────────────── | |
def generate_single_api(
    model: str,
    messages: List[Dict[str, str]],
    max_tokens: int,
    temperature: float,
    top_p: float,
    provider: str,
) -> str:
    """Request a single completion from a hosted chat API.

    Args:
        model: Provider-side model name.
        messages: Chat messages with ``role`` and ``content`` keys.
        max_tokens: Completion token limit sent to the provider.
        temperature: Sampling temperature sent to the provider.
        top_p: Nucleus sampling value sent to the provider.
        provider: ``"openai"`` or ``"anthropic"``.

    Returns:
        The text of the first returned choice/content item, or ``""`` when
        the provider returned none.

    Raises:
        ValueError: If ``provider`` is neither ``"openai"`` nor ``"anthropic"``.
    """
    if provider not in ("openai", "anthropic"):
        raise ValueError(f"Unknown provider: {provider}")

    if provider == "openai":
        import openai

        openai_client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
        completion = openai_client.chat.completions.create(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        first_choice = completion.choices[0]
        return first_choice.message.content or ""

    import anthropic

    anthropic_client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
    # The system prompt is passed separately from the chat turns; when
    # several system messages exist, the last one is the one sent.
    system_prompts = [m["content"] for m in messages if m["role"] == "system"]
    system_msg = system_prompts[-1] if system_prompts else ""
    chat_turns = [
        {"role": m["role"], "content": m["content"]}
        for m in messages
        if m["role"] != "system"
    ]
    reply = anthropic_client.messages.create(
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        system=system_msg,
        messages=chat_turns,
    )
    if not reply.content:
        return ""
    return reply.content[0].text
def generate_batch_api(
    model: str,
    examples: List[Dict[str, Any]],
    max_tokens: int,
    temperature: float,
    top_p: float,
    provider: str,
) -> List[str]:
    """Query the API once per example, sequentially, and collect the replies.

    Each example's ``messages`` are reduced to the prompt turns before the
    request. Predictions are returned in the order of ``examples``; any
    exception from a request propagates to the caller.
    """
    return [
        generate_single_api(
            model,
            make_prompt_messages(example["messages"]),
            max_tokens,
            temperature,
            top_p,
            provider,
        )
        for example in examples
    ]
| # ─── Evaluation ───────────────────────────────────────────────────────────── | |
def row_metrics(example: Dict[str, Any], prediction: str) -> Dict[str, Any]:
    """Score one prediction against the example's gold answer.

    The returned row carries the example's identifying fields, the raw
    prediction and gold text, parse flags for both, and the metric fields.
    Exact match, field-level scores and the metadata constraint checks are
    computed only when both prediction and gold parse as JSON; otherwise
    they are recorded as zero / False.
    """
    gold = gold_text(example)
    pred_obj, pred_err = parse_json(prediction)
    gold_obj, _gold_err = parse_json(gold)

    pred_parsed = pred_obj is not None
    gold_parsed = gold_obj is not None

    row: Dict[str, Any] = {
        "id": example.get("id"),
        "target_layer": example.get("target_layer"),
        "slice_type": example.get("slice_type"),
        "lifecycle_operation": example.get("lifecycle_operation"),
        "parse_json": pred_parsed,
        "gold_parse_json": gold_parsed,
        "exact_match": False,
        "prediction": prediction,
        "gold": gold,
        "parse_error": pred_err,
    }

    if not (pred_parsed and gold_parsed):
        # Nothing to compare: record the worst score for every metric.
        row.update({
            "field_precision": 0.0,
            "field_recall": 0.0,
            "field_f1": 0.0,
            "field_tp": 0,
            "field_fp": 0,
            "field_fn": 0,
            "slice_sst_pass": False,
            "kpi_text_presence_pass": False,
            "adversarial_status_pass": False,
        })
        return row

    row["exact_match"] = json_exact_match(pred_obj, gold_obj)
    row.update(field_f1(pred_obj, gold_obj))
    row.update(metadata_constraint_pass(example, prediction, pred_obj))
    return row
def load_existing_predictions(path: Path) -> Tuple[List[Dict[str, Any]], set]:
    """Load previously saved prediction rows for resuming.

    Returns:
        ``(rows, done_ids)`` where ``done_ids`` holds the stringified ``id``
        of every row. Both are empty when ``path`` does not exist.
    """
    if not path.exists():
        return [], set()
    saved_rows = json.loads(path.read_text())
    # Ids are compared as strings; a row without an id contributes "None".
    finished = set()
    for row in saved_rows:
        finished.add(str(row.get("id")))
    return saved_rows, finished
def write_split_outputs(split_dir: Path, rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Write a split's predictions and metrics files and return the metrics.

    ``predictions.json`` receives the raw rows. ``metrics.json`` receives the
    aggregate over all rows plus ``by_target_layer``, ``by_slice_type`` and
    ``by_lifecycle_operation`` breakdowns, each keyed by the stringified
    field value in sorted order.
    """
    write_json(split_dir / "predictions.json", rows)
    summary = aggregate_metrics(rows)

    for field in ("target_layer", "slice_type", "lifecycle_operation"):
        buckets: Dict[str, List[Dict[str, Any]]] = {}
        for row in rows:
            buckets.setdefault(str(row.get(field)), []).append(row)
        breakdown = {}
        for label in sorted(buckets):
            breakdown[label] = aggregate_metrics(buckets[label])
        summary[f"by_{field}"] = breakdown

    write_json(split_dir / "metrics.json", summary)
    return summary
def _predict_batch(args, model, tokenizer, batch: List[Dict[str, Any]], is_api: bool, start: int) -> List[str]:
    """Return one raw prediction per example in ``batch``.

    Local generation errors are re-raised after being reported. When an API
    batch fails, each example is retried on its own and an example that
    fails again is recorded as an empty prediction.
    """
    try:
        if is_api:
            return generate_batch_api(
                args.model, batch, args.max_new_tokens,
                args.temperature, args.top_p, args.api_provider,
            )
        return generate_batch_local(model, tokenizer, batch, args)
    except Exception as e:
        print(f"\nERROR in batch starting at {start}: {e}")
        if not is_api:
            raise

    preds = []
    for ex in batch:
        try:
            pred = generate_single_api(
                args.model, make_prompt_messages(ex["messages"]),
                args.max_new_tokens, args.temperature, args.top_p, args.api_provider,
            )
            preds.append(pred)
        except Exception as e2:
            # NOTE(review): the empty prediction is scored and saved like any
            # other row, so a later resumed run treats this example as done.
            print(f" Failed on example {ex.get('id')}: {e2}")
            preds.append("")
    return preds


def main():
    """Run the baseline evaluation over every requested dataset split.

    For each split this resumes from ``<output_dir>/<split>/predictions.json``
    (unless resuming is disabled), generates predictions for the remaining
    examples in batches, scores them, saves intermediate results every
    ``--save_every`` examples, and finally writes the split's predictions and
    metrics. ``all_metrics.json`` is rewritten after each split and a summary
    table is printed at the end.
    """
    args = parse_args()
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    write_json(out_dir / "baseline_config.json", vars(args))

    is_api = args.api_provider != "none"
    if is_api:
        print(f"Using API provider: {args.api_provider}, model: {args.model}")
        model, tokenizer = None, None
    else:
        print(f"Loading local model: {args.model}")
        model, tokenizer = load_local_model(args.model, args.trust_remote_code)

    ds = load_dataset(args.dataset)
    all_summary = {}

    for split in args.splits:
        split_ds = ds[split]
        if args.max_samples_per_split:
            split_ds = split_ds.select(range(min(args.max_samples_per_split, len(split_ds))))

        split_dir = out_dir / split
        split_dir.mkdir(parents=True, exist_ok=True)
        pred_path = split_dir / "predictions.json"

        rows, done_ids = load_existing_predictions(pred_path) if args.resume else ([], set())
        todo = [ex for ex in split_ds if str(ex.get("id")) not in done_ids]
        print(f"\nEvaluating {split}: total={len(split_ds)} already_done={len(done_ids)} remaining={len(todo)} batch_size={args.batch_size}")

        if todo:
            completed_since_save = 0
            # The context manager closes the progress bar even when a batch
            # raises (local generation errors are re-raised).
            with tqdm(total=len(todo), desc=split) as pbar:
                for start in range(0, len(todo), args.batch_size):
                    batch = todo[start:start + args.batch_size]
                    preds = _predict_batch(args, model, tokenizer, batch, is_api, start)
                    for ex, pred in zip(batch, preds):
                        rows.append(row_metrics(ex, pred.strip()))
                    pbar.update(len(batch))
                    completed_since_save += len(batch)
                    if completed_since_save >= args.save_every:
                        # Periodic checkpoint so an interrupted run can resume.
                        write_split_outputs(split_dir, rows)
                        completed_since_save = 0

        # A fully resumed split (nothing left to do) takes the same path, so
        # its metrics also reach all_metrics.json and the per-split line.
        summary = write_split_outputs(split_dir, rows)
        all_summary[split] = summary
        write_json(out_dir / "all_metrics.json", all_summary)
        print(f" {split}: parse={summary.get('parse_json', 0):.4f} field_f1={summary.get('field_f1', 0):.4f} exact_match={summary.get('exact_match', 0):.4f}")

    print("\n" + "=" * 60)
    print("BASELINE EVALUATION COMPLETE")
    print("=" * 60)
    for split, s in all_summary.items():
        print(f"{split:30s}: parse={s.get('parse_json', 0):.4f} field_f1={s.get('field_f1', 0):.4f} exact={s.get('exact_match', 0):.4f}")
# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()