PEFT
qlora
sft
trl
qwen3
tmf921
intent-based-networking
network-slicing
rtx-6000-ada
ml-intern
nraptisss committed on
Commit
d27f0bc
·
verified ·
1 Parent(s): 02caf44

Fix semantic evaluator to recover metadata by prediction id

Browse files
Files changed (1) hide show
  1. scripts/evaluate_semantic_o1_a1.py +44 -15
scripts/evaluate_semantic_o1_a1.py CHANGED
@@ -17,6 +17,8 @@ from collections import defaultdict
17
  from pathlib import Path
18
  from typing import Any, Dict, Iterable, List, Optional, Tuple
19
 
 
 
20
  from tmf921_train.utils import aggregate_metrics, parse_json, write_json
21
 
22
 
@@ -69,16 +71,31 @@ def any_key_contains(obj: Any, fragments: List[str]) -> bool:
69
  return False
70
 
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  def semantic_score_o1(example: Dict[str, Any], pred_obj: Any) -> Dict[str, Any]:
73
  checks = {}
74
- checks["sem_sst"] = contains_value(pred_obj, int(example.get("sst")))
75
- checks["sem_sd"] = contains_value(pred_obj, example.get("sd"))
76
  checks["sem_snssai"] = checks["sem_sst"] and checks["sem_sd"]
77
- checks["sem_slice_type"] = contains_value(pred_obj, example.get("slice_type"))
78
- checks["sem_latency"] = contains_value(pred_obj, example.get("latency_ms"), rel=0.05)
79
- checks["sem_dl"] = contains_value(pred_obj, example.get("dl_throughput_mbps"), rel=0.05)
80
- checks["sem_ul"] = contains_value(pred_obj, example.get("ul_throughput_mbps"), rel=0.05)
81
- checks["sem_max_ues"] = contains_value(pred_obj, example.get("max_ues"), rel=0.05)
82
  checks["sem_managed_element_structure"] = any_key_contains(pred_obj, ["ManagedElement", "GNBDU", "NRCell", "cell", "rrmPolicy"])
83
  checks["sem_rrm_policy_structure"] = any_key_contains(pred_obj, ["rrmPolicy", "prb", "ratio", "quota"])
84
  checks["sem_cell_parameter_structure"] = any_key_contains(pred_obj, ["arfcn", "pci", "tac", "bandwidth", "bSChannelBw", "cellLocalId"])
@@ -92,18 +109,18 @@ def semantic_score_o1(example: Dict[str, Any], pred_obj: Any) -> Dict[str, Any]:
92
 
93
  def semantic_score_a1(example: Dict[str, Any], pred_obj: Any) -> Dict[str, Any]:
94
  checks = {}
95
- checks["sem_sst"] = contains_value(pred_obj, int(example.get("sst")))
96
- checks["sem_sd"] = contains_value(pred_obj, example.get("sd"))
97
  checks["sem_snssai"] = checks["sem_sst"] and checks["sem_sd"]
98
- checks["sem_slice_type"] = contains_value(pred_obj, example.get("slice_type"))
99
  checks["sem_policy_structure"] = any_key_contains(pred_obj, ["policy", "a1", "scope", "objective", "qos", "qoe"])
100
  checks["sem_prb_or_quota_structure"] = any_key_contains(pred_obj, ["prb", "quota", "resource", "allocation", "ratio"])
101
  checks["sem_scheduler_structure"] = any_key_contains(pred_obj, ["scheduler", "weight", "priority"])
102
  checks["sem_5qi_or_qos_structure"] = any_key_contains(pred_obj, ["5qi", "fiveqi", "qci", "qos", "pdb", "per", "gfbr", "mfbr"])
103
- checks["sem_latency"] = contains_value(pred_obj, example.get("latency_ms"), rel=0.05)
104
- checks["sem_dl"] = contains_value(pred_obj, example.get("dl_throughput_mbps"), rel=0.05)
105
- checks["sem_ul"] = contains_value(pred_obj, example.get("ul_throughput_mbps"), rel=0.05)
106
- checks["sem_max_ues"] = contains_value(pred_obj, example.get("max_ues"), rel=0.05)
107
  core_keys = ["sem_sst", "sem_sd", "sem_snssai", "sem_policy_structure", "sem_prb_or_quota_structure", "sem_5qi_or_qos_structure"]
108
  kpi_keys = ["sem_latency", "sem_dl", "sem_ul", "sem_max_ues", "sem_scheduler_structure"]
109
  checks["sem_core_score"] = sum(bool(checks[k]) for k in core_keys) / len(core_keys)
@@ -154,6 +171,7 @@ def summarize(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
154
  def main():
155
  ap = argparse.ArgumentParser()
156
  ap.add_argument("--eval_dir", required=True, help="Eval dir with split/predictions.json and optionally normalized_predictions_scored.json")
 
157
  ap.add_argument("--splits", nargs="*", default=None)
158
  args = ap.parse_args()
159
  eval_dir = Path(args.eval_dir)
@@ -161,6 +179,15 @@ def main():
161
  splits = args.splits
162
  else:
163
  splits = sorted([p.name for p in eval_dir.iterdir() if p.is_dir() and (p / "predictions.json").exists()])
 
 
 
 
 
 
 
 
 
164
  all_rows = []
165
  for split in splits:
166
  pred_path = eval_dir / split / "predictions.json"
@@ -172,6 +199,9 @@ def main():
172
  scored = []
173
  for r in rows:
174
  merged = dict(r)
 
 
 
175
  for k, v in norm.get(r.get("id"), {}).items():
176
  if k not in merged:
177
  merged[k] = v
@@ -191,5 +221,4 @@ def main():
191
 
192
 
193
  if __name__ == "__main__":
194
- import argparse
195
  main()
 
17
  from pathlib import Path
18
  from typing import Any, Dict, Iterable, List, Optional, Tuple
19
 
20
+ from datasets import load_dataset
21
+
22
  from tmf921_train.utils import aggregate_metrics, parse_json, write_json
23
 
24
 
 
71
  return False
72
 
73
 
74
def safe_int(x: Any) -> Optional[int]:
    """Convert *x* to ``int``, or return ``None`` when it cannot be converted.

    Args:
        x: Value to convert; ``None`` is passed through as ``None``.

    Returns:
        ``int(x)`` on success, otherwise ``None``.

    Only the failures ``int()`` raises for unusable input are suppressed:
    ``TypeError`` (unsupported type), ``ValueError`` (non-numeric string or
    NaN) and ``OverflowError`` (infinity). Any other exception propagates so
    that unrelated bugs are not silently hidden.
    """
    if x is None:
        return None
    try:
        return int(x)
    except (TypeError, ValueError, OverflowError):
        return None
81
+
82
+
83
def check_contains_if_present(obj: Any, target: Any, *, rel: float = 0.02) -> bool:
    """Report whether *obj* contains *target*, treating a missing target as a failure.

    Args:
        obj: Object to search; forwarded unchanged to ``contains_value``.
        target: Expected value. ``None`` means there is nothing to look for.
        rel: Forwarded to ``contains_value`` as its ``rel`` keyword
            (presumably a relative tolerance for numeric matches; confirm
            against ``contains_value``).

    Returns:
        ``False`` when *target* is ``None``; otherwise the result of
        ``contains_value(obj, target, rel=rel)``.
    """
    # Short-circuit: contains_value is never called for a missing target.
    return target is not None and contains_value(obj, target, rel=rel)
87
+
88
+
89
  def semantic_score_o1(example: Dict[str, Any], pred_obj: Any) -> Dict[str, Any]:
90
  checks = {}
91
+ checks["sem_sst"] = check_contains_if_present(pred_obj, safe_int(example.get("sst")))
92
+ checks["sem_sd"] = check_contains_if_present(pred_obj, example.get("sd"))
93
  checks["sem_snssai"] = checks["sem_sst"] and checks["sem_sd"]
94
+ checks["sem_slice_type"] = check_contains_if_present(pred_obj, example.get("slice_type"))
95
+ checks["sem_latency"] = check_contains_if_present(pred_obj, example.get("latency_ms"), rel=0.05)
96
+ checks["sem_dl"] = check_contains_if_present(pred_obj, example.get("dl_throughput_mbps"), rel=0.05)
97
+ checks["sem_ul"] = check_contains_if_present(pred_obj, example.get("ul_throughput_mbps"), rel=0.05)
98
+ checks["sem_max_ues"] = check_contains_if_present(pred_obj, example.get("max_ues"), rel=0.05)
99
  checks["sem_managed_element_structure"] = any_key_contains(pred_obj, ["ManagedElement", "GNBDU", "NRCell", "cell", "rrmPolicy"])
100
  checks["sem_rrm_policy_structure"] = any_key_contains(pred_obj, ["rrmPolicy", "prb", "ratio", "quota"])
101
  checks["sem_cell_parameter_structure"] = any_key_contains(pred_obj, ["arfcn", "pci", "tac", "bandwidth", "bSChannelBw", "cellLocalId"])
 
109
 
110
  def semantic_score_a1(example: Dict[str, Any], pred_obj: Any) -> Dict[str, Any]:
111
  checks = {}
112
+ checks["sem_sst"] = check_contains_if_present(pred_obj, safe_int(example.get("sst")))
113
+ checks["sem_sd"] = check_contains_if_present(pred_obj, example.get("sd"))
114
  checks["sem_snssai"] = checks["sem_sst"] and checks["sem_sd"]
115
+ checks["sem_slice_type"] = check_contains_if_present(pred_obj, example.get("slice_type"))
116
  checks["sem_policy_structure"] = any_key_contains(pred_obj, ["policy", "a1", "scope", "objective", "qos", "qoe"])
117
  checks["sem_prb_or_quota_structure"] = any_key_contains(pred_obj, ["prb", "quota", "resource", "allocation", "ratio"])
118
  checks["sem_scheduler_structure"] = any_key_contains(pred_obj, ["scheduler", "weight", "priority"])
119
  checks["sem_5qi_or_qos_structure"] = any_key_contains(pred_obj, ["5qi", "fiveqi", "qci", "qos", "pdb", "per", "gfbr", "mfbr"])
120
+ checks["sem_latency"] = check_contains_if_present(pred_obj, example.get("latency_ms"), rel=0.05)
121
+ checks["sem_dl"] = check_contains_if_present(pred_obj, example.get("dl_throughput_mbps"), rel=0.05)
122
+ checks["sem_ul"] = check_contains_if_present(pred_obj, example.get("ul_throughput_mbps"), rel=0.05)
123
+ checks["sem_max_ues"] = check_contains_if_present(pred_obj, example.get("max_ues"), rel=0.05)
124
  core_keys = ["sem_sst", "sem_sd", "sem_snssai", "sem_policy_structure", "sem_prb_or_quota_structure", "sem_5qi_or_qos_structure"]
125
  kpi_keys = ["sem_latency", "sem_dl", "sem_ul", "sem_max_ues", "sem_scheduler_structure"]
126
  checks["sem_core_score"] = sum(bool(checks[k]) for k in core_keys) / len(core_keys)
 
171
  def main():
172
  ap = argparse.ArgumentParser()
173
  ap.add_argument("--eval_dir", required=True, help="Eval dir with split/predictions.json and optionally normalized_predictions_scored.json")
174
+ ap.add_argument("--dataset", default="nraptisss/TMF921-intent-to-config-research-sota", help="Dataset used to recover metadata such as sst/sd/KPIs by row id")
175
  ap.add_argument("--splits", nargs="*", default=None)
176
  args = ap.parse_args()
177
  eval_dir = Path(args.eval_dir)
 
179
  splits = args.splits
180
  else:
181
  splits = sorted([p.name for p in eval_dir.iterdir() if p.is_dir() and (p / "predictions.json").exists()])
182
+ print(f"Loading metadata from {args.dataset} for splits: {splits}")
183
+ ds = load_dataset(args.dataset)
184
+ meta_by_split_id = {}
185
+ needed_cols = ["id", "target_layer", "slice_type", "sst", "sd", "latency_ms", "dl_throughput_mbps", "ul_throughput_mbps", "max_ues"]
186
+ for split in splits:
187
+ if split in ds:
188
+ meta_by_split_id[split] = {str(r["id"]): r for r in ds[split].select_columns([c for c in needed_cols if c in ds[split].column_names])}
189
+ else:
190
+ meta_by_split_id[split] = {}
191
  all_rows = []
192
  for split in splits:
193
  pred_path = eval_dir / split / "predictions.json"
 
199
  scored = []
200
  for r in rows:
201
  merged = dict(r)
202
+ for k, v in meta_by_split_id.get(split, {}).get(str(r.get("id")), {}).items():
203
+ if k not in merged or merged.get(k) is None:
204
+ merged[k] = v
205
  for k, v in norm.get(r.get("id"), {}).items():
206
  if k not in merged:
207
  merged[k] = v
 
221
 
222
 
223
  if __name__ == "__main__":
 
224
  main()