Instructions to use nraptisss/tmf921-intent-training with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- PEFT
How to use nraptisss/tmf921-intent-training with PEFT:
Task type is invalid.
- Notebooks
- Google Colab
- Kaggle
Fix semantic evaluator to recover metadata by prediction id
Browse files
scripts/evaluate_semantic_o1_a1.py
CHANGED
|
@@ -17,6 +17,8 @@ from collections import defaultdict
|
|
| 17 |
from pathlib import Path
|
| 18 |
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
| 19 |
|
|
|
|
|
|
|
| 20 |
from tmf921_train.utils import aggregate_metrics, parse_json, write_json
|
| 21 |
|
| 22 |
|
|
@@ -69,16 +71,31 @@ def any_key_contains(obj: Any, fragments: List[str]) -> bool:
|
|
| 69 |
return False
|
| 70 |
|
| 71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
def semantic_score_o1(example: Dict[str, Any], pred_obj: Any) -> Dict[str, Any]:
|
| 73 |
checks = {}
|
| 74 |
-
checks["sem_sst"] =
|
| 75 |
-
checks["sem_sd"] =
|
| 76 |
checks["sem_snssai"] = checks["sem_sst"] and checks["sem_sd"]
|
| 77 |
-
checks["sem_slice_type"] =
|
| 78 |
-
checks["sem_latency"] =
|
| 79 |
-
checks["sem_dl"] =
|
| 80 |
-
checks["sem_ul"] =
|
| 81 |
-
checks["sem_max_ues"] =
|
| 82 |
checks["sem_managed_element_structure"] = any_key_contains(pred_obj, ["ManagedElement", "GNBDU", "NRCell", "cell", "rrmPolicy"])
|
| 83 |
checks["sem_rrm_policy_structure"] = any_key_contains(pred_obj, ["rrmPolicy", "prb", "ratio", "quota"])
|
| 84 |
checks["sem_cell_parameter_structure"] = any_key_contains(pred_obj, ["arfcn", "pci", "tac", "bandwidth", "bSChannelBw", "cellLocalId"])
|
|
@@ -92,18 +109,18 @@ def semantic_score_o1(example: Dict[str, Any], pred_obj: Any) -> Dict[str, Any]:
|
|
| 92 |
|
| 93 |
def semantic_score_a1(example: Dict[str, Any], pred_obj: Any) -> Dict[str, Any]:
|
| 94 |
checks = {}
|
| 95 |
-
checks["sem_sst"] =
|
| 96 |
-
checks["sem_sd"] =
|
| 97 |
checks["sem_snssai"] = checks["sem_sst"] and checks["sem_sd"]
|
| 98 |
-
checks["sem_slice_type"] =
|
| 99 |
checks["sem_policy_structure"] = any_key_contains(pred_obj, ["policy", "a1", "scope", "objective", "qos", "qoe"])
|
| 100 |
checks["sem_prb_or_quota_structure"] = any_key_contains(pred_obj, ["prb", "quota", "resource", "allocation", "ratio"])
|
| 101 |
checks["sem_scheduler_structure"] = any_key_contains(pred_obj, ["scheduler", "weight", "priority"])
|
| 102 |
checks["sem_5qi_or_qos_structure"] = any_key_contains(pred_obj, ["5qi", "fiveqi", "qci", "qos", "pdb", "per", "gfbr", "mfbr"])
|
| 103 |
-
checks["sem_latency"] =
|
| 104 |
-
checks["sem_dl"] =
|
| 105 |
-
checks["sem_ul"] =
|
| 106 |
-
checks["sem_max_ues"] =
|
| 107 |
core_keys = ["sem_sst", "sem_sd", "sem_snssai", "sem_policy_structure", "sem_prb_or_quota_structure", "sem_5qi_or_qos_structure"]
|
| 108 |
kpi_keys = ["sem_latency", "sem_dl", "sem_ul", "sem_max_ues", "sem_scheduler_structure"]
|
| 109 |
checks["sem_core_score"] = sum(bool(checks[k]) for k in core_keys) / len(core_keys)
|
|
@@ -154,6 +171,7 @@ def summarize(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
|
|
| 154 |
def main():
|
| 155 |
ap = argparse.ArgumentParser()
|
| 156 |
ap.add_argument("--eval_dir", required=True, help="Eval dir with split/predictions.json and optionally normalized_predictions_scored.json")
|
|
|
|
| 157 |
ap.add_argument("--splits", nargs="*", default=None)
|
| 158 |
args = ap.parse_args()
|
| 159 |
eval_dir = Path(args.eval_dir)
|
|
@@ -161,6 +179,15 @@ def main():
|
|
| 161 |
splits = args.splits
|
| 162 |
else:
|
| 163 |
splits = sorted([p.name for p in eval_dir.iterdir() if p.is_dir() and (p / "predictions.json").exists()])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
all_rows = []
|
| 165 |
for split in splits:
|
| 166 |
pred_path = eval_dir / split / "predictions.json"
|
|
@@ -172,6 +199,9 @@ def main():
|
|
| 172 |
scored = []
|
| 173 |
for r in rows:
|
| 174 |
merged = dict(r)
|
|
|
|
|
|
|
|
|
|
| 175 |
for k, v in norm.get(r.get("id"), {}).items():
|
| 176 |
if k not in merged:
|
| 177 |
merged[k] = v
|
|
@@ -191,5 +221,4 @@ def main():
|
|
| 191 |
|
| 192 |
|
| 193 |
if __name__ == "__main__":
|
| 194 |
-
import argparse
|
| 195 |
main()
|
|
|
|
| 17 |
from pathlib import Path
|
| 18 |
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
| 19 |
|
| 20 |
+
from datasets import load_dataset
|
| 21 |
+
|
| 22 |
from tmf921_train.utils import aggregate_metrics, parse_json, write_json
|
| 23 |
|
| 24 |
|
|
|
|
| 71 |
return False
|
| 72 |
|
| 73 |
|
| 74 |
+
def safe_int(x: Any) -> Optional[int]:
|
| 75 |
+
try:
|
| 76 |
+
if x is None:
|
| 77 |
+
return None
|
| 78 |
+
return int(x)
|
| 79 |
+
except Exception:
|
| 80 |
+
return None
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def check_contains_if_present(obj: Any, target: Any, *, rel: float = 0.02) -> bool:
|
| 84 |
+
if target is None:
|
| 85 |
+
return False
|
| 86 |
+
return contains_value(obj, target, rel=rel)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
def semantic_score_o1(example: Dict[str, Any], pred_obj: Any) -> Dict[str, Any]:
|
| 90 |
checks = {}
|
| 91 |
+
checks["sem_sst"] = check_contains_if_present(pred_obj, safe_int(example.get("sst")))
|
| 92 |
+
checks["sem_sd"] = check_contains_if_present(pred_obj, example.get("sd"))
|
| 93 |
checks["sem_snssai"] = checks["sem_sst"] and checks["sem_sd"]
|
| 94 |
+
checks["sem_slice_type"] = check_contains_if_present(pred_obj, example.get("slice_type"))
|
| 95 |
+
checks["sem_latency"] = check_contains_if_present(pred_obj, example.get("latency_ms"), rel=0.05)
|
| 96 |
+
checks["sem_dl"] = check_contains_if_present(pred_obj, example.get("dl_throughput_mbps"), rel=0.05)
|
| 97 |
+
checks["sem_ul"] = check_contains_if_present(pred_obj, example.get("ul_throughput_mbps"), rel=0.05)
|
| 98 |
+
checks["sem_max_ues"] = check_contains_if_present(pred_obj, example.get("max_ues"), rel=0.05)
|
| 99 |
checks["sem_managed_element_structure"] = any_key_contains(pred_obj, ["ManagedElement", "GNBDU", "NRCell", "cell", "rrmPolicy"])
|
| 100 |
checks["sem_rrm_policy_structure"] = any_key_contains(pred_obj, ["rrmPolicy", "prb", "ratio", "quota"])
|
| 101 |
checks["sem_cell_parameter_structure"] = any_key_contains(pred_obj, ["arfcn", "pci", "tac", "bandwidth", "bSChannelBw", "cellLocalId"])
|
|
|
|
| 109 |
|
| 110 |
def semantic_score_a1(example: Dict[str, Any], pred_obj: Any) -> Dict[str, Any]:
|
| 111 |
checks = {}
|
| 112 |
+
checks["sem_sst"] = check_contains_if_present(pred_obj, safe_int(example.get("sst")))
|
| 113 |
+
checks["sem_sd"] = check_contains_if_present(pred_obj, example.get("sd"))
|
| 114 |
checks["sem_snssai"] = checks["sem_sst"] and checks["sem_sd"]
|
| 115 |
+
checks["sem_slice_type"] = check_contains_if_present(pred_obj, example.get("slice_type"))
|
| 116 |
checks["sem_policy_structure"] = any_key_contains(pred_obj, ["policy", "a1", "scope", "objective", "qos", "qoe"])
|
| 117 |
checks["sem_prb_or_quota_structure"] = any_key_contains(pred_obj, ["prb", "quota", "resource", "allocation", "ratio"])
|
| 118 |
checks["sem_scheduler_structure"] = any_key_contains(pred_obj, ["scheduler", "weight", "priority"])
|
| 119 |
checks["sem_5qi_or_qos_structure"] = any_key_contains(pred_obj, ["5qi", "fiveqi", "qci", "qos", "pdb", "per", "gfbr", "mfbr"])
|
| 120 |
+
checks["sem_latency"] = check_contains_if_present(pred_obj, example.get("latency_ms"), rel=0.05)
|
| 121 |
+
checks["sem_dl"] = check_contains_if_present(pred_obj, example.get("dl_throughput_mbps"), rel=0.05)
|
| 122 |
+
checks["sem_ul"] = check_contains_if_present(pred_obj, example.get("ul_throughput_mbps"), rel=0.05)
|
| 123 |
+
checks["sem_max_ues"] = check_contains_if_present(pred_obj, example.get("max_ues"), rel=0.05)
|
| 124 |
core_keys = ["sem_sst", "sem_sd", "sem_snssai", "sem_policy_structure", "sem_prb_or_quota_structure", "sem_5qi_or_qos_structure"]
|
| 125 |
kpi_keys = ["sem_latency", "sem_dl", "sem_ul", "sem_max_ues", "sem_scheduler_structure"]
|
| 126 |
checks["sem_core_score"] = sum(bool(checks[k]) for k in core_keys) / len(core_keys)
|
|
|
|
| 171 |
def main():
|
| 172 |
ap = argparse.ArgumentParser()
|
| 173 |
ap.add_argument("--eval_dir", required=True, help="Eval dir with split/predictions.json and optionally normalized_predictions_scored.json")
|
| 174 |
+
ap.add_argument("--dataset", default="nraptisss/TMF921-intent-to-config-research-sota", help="Dataset used to recover metadata such as sst/sd/KPIs by row id")
|
| 175 |
ap.add_argument("--splits", nargs="*", default=None)
|
| 176 |
args = ap.parse_args()
|
| 177 |
eval_dir = Path(args.eval_dir)
|
|
|
|
| 179 |
splits = args.splits
|
| 180 |
else:
|
| 181 |
splits = sorted([p.name for p in eval_dir.iterdir() if p.is_dir() and (p / "predictions.json").exists()])
|
| 182 |
+
print(f"Loading metadata from {args.dataset} for splits: {splits}")
|
| 183 |
+
ds = load_dataset(args.dataset)
|
| 184 |
+
meta_by_split_id = {}
|
| 185 |
+
needed_cols = ["id", "target_layer", "slice_type", "sst", "sd", "latency_ms", "dl_throughput_mbps", "ul_throughput_mbps", "max_ues"]
|
| 186 |
+
for split in splits:
|
| 187 |
+
if split in ds:
|
| 188 |
+
meta_by_split_id[split] = {str(r["id"]): r for r in ds[split].select_columns([c for c in needed_cols if c in ds[split].column_names])}
|
| 189 |
+
else:
|
| 190 |
+
meta_by_split_id[split] = {}
|
| 191 |
all_rows = []
|
| 192 |
for split in splits:
|
| 193 |
pred_path = eval_dir / split / "predictions.json"
|
|
|
|
| 199 |
scored = []
|
| 200 |
for r in rows:
|
| 201 |
merged = dict(r)
|
| 202 |
+
for k, v in meta_by_split_id.get(split, {}).get(str(r.get("id")), {}).items():
|
| 203 |
+
if k not in merged or merged.get(k) is None:
|
| 204 |
+
merged[k] = v
|
| 205 |
for k, v in norm.get(r.get("id"), {}).items():
|
| 206 |
if k not in merged:
|
| 207 |
merged[k] = v
|
|
|
|
| 221 |
|
| 222 |
|
| 223 |
if __name__ == "__main__":
|
|
|
|
| 224 |
main()
|