#!/usr/bin/env python # -*- coding: utf-8 -*- """ src.model_core Shared model/runtime utilities for Track A. This module intentionally does not own the CLI. It provides the reusable pieces used by train.py and main.py: - JSON/env loading and result/debug/traces writers. - Server data hydration helpers and trace formatting. - Question/option parsing and answer normalization. - Telemetry feature extraction for scenarios and candidate options. - Prediction-time template scoring, option selection, and model bundle IO. The modelling approach uses two LightGBM components: 1. A template classifier: scenario-level telemetry features -> answer action template. 2. An option selector: scenario, option, target-cell stats, and template/action features -> whether a candidate option should be selected. The estimator fitting code belongs in train.py. The executable inference flow belongs in main.py. """ import csv import io import json import math import os import pickle import re import time import warnings from collections import Counter, defaultdict from typing import Any, Dict, List, Tuple, Optional warnings.filterwarnings( "ignore", message=r"`sklearn\.utils\.parallel\.delayed` should be used.*", category=UserWarning, ) import httpx import numpy as np import pandas as pd from openai import OpenAI from sklearn.pipeline import Pipeline warnings.filterwarnings("ignore") MODEL_BUNDLE_VERSION = 1 # ============================================================================= # General utilities # ============================================================================= def load_json(path: str) -> List[Dict[str, Any]]: with open(path, "r", encoding="utf-8") as f: obj = json.load(f) if isinstance(obj, list): return obj if isinstance(obj, dict): for k in ["data", "records", "scenarios"]: if isinstance(obj.get(k), list): return obj[k] raise ValueError(f"Unsupported JSON structure: {path}") def norm(x: Any) -> str: return re.sub(r"\s+", " ", str(x).strip().lower()) def 
def load_env_file(path: Optional[str] = None):
    """Load environment variables from the first ``.env`` file that exists.

    Candidates are tried in order: the explicit ``path`` (when given), then
    ``.env`` in the current working directory, then ``.env`` next to this
    module. Variables already present in ``os.environ`` are never overridden.

    ``python-dotenv`` is used when it can be imported. If that import or the
    load fails, a minimal built-in parser reads ``KEY=VALUE`` lines, skips blank
    lines and ``#`` comments, tolerates a leading ``export `` on the key (as
    python-dotenv does) and strips surrounding quotes from values.

    Returns:
        The path of the file that was loaded, or ``None`` if no candidate is a
        regular file.
    """
    candidates: List[str] = []
    if path:
        candidates.append(path)
    candidates.extend(
        [
            os.path.join(os.getcwd(), ".env"),
            os.path.join(os.path.dirname(os.path.abspath(__file__)), ".env"),
        ]
    )
    for candidate in candidates:
        # isfile (not exists): a directory called ".env" must be skipped rather
        # than crash the fallback parser's open().
        if not candidate or not os.path.isfile(candidate):
            continue
        try:
            from dotenv import load_dotenv

            load_dotenv(candidate, override=False)
        except Exception:
            # Best-effort fallback when python-dotenv is missing or fails.
            with open(candidate, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line or line.startswith("#") or "=" not in line:
                        continue
                    key, value = line.split("=", 1)
                    key = key.strip()
                    if key.startswith("export "):
                        key = key[len("export "):].strip()
                    os.environ.setdefault(key, value.strip().strip("'\""))
        return candidate
    return None


def is_placeholder_data(value: Any) -> bool:
    """Return True when ``value`` is a string that contains "use the api" (case/whitespace-insensitive)."""
    return isinstance(value, str) and "use the api" in norm(value)


def scenario_needs_server_data(s: Dict[str, Any]) -> bool:
    """Return True if the scenario's ``data`` still holds an API placeholder.

    A dict ``data`` is checked value by value; any other non-empty ``data``
    value (e.g. a single placeholder string) is checked directly instead of
    failing on ``.values()``.
    """
    data = s.get("data") or {}
    if isinstance(data, dict):
        return any(is_placeholder_data(v) for v in data.values())
    return is_placeholder_data(data)


def compact_result_for_trace(result: Any, max_value_chars: int = 220) -> Any:
    """Shrink a tool result so it fits on a single trace line.

    - Strings that contain a newline or are longer than ``max_value_chars``
      become ``{"chars": <length>, "preview": <first max_value_chars chars>}``.
    - Lists become ``{"items": <length>, "preview": [<first two items, compacted>]}``,
      both at top level and when nested inside dicts.
    - Dicts are compacted value by value, recursively.
    - Any other value is returned unchanged.
    """
    if isinstance(result, str):
        if "\n" in result or len(result) > max_value_chars:
            return {"chars": len(result), "preview": result[:max_value_chars]}
        return result
    if isinstance(result, dict):
        return {
            k: compact_result_for_trace(v, max_value_chars) for k, v in result.items()
        }
    if isinstance(result, list):
        return {
            "items": len(result),
            "preview": [
                compact_result_for_trace(item, max_value_chars) for item in result[:2]
            ],
        }
    return result


def _json_default(value: Any) -> Any:
    """``json.dumps`` fallback for numpy values the stdlib encoder rejects.

    numpy scalars such as ``np.int64``, ``np.float32`` and ``np.bool_`` are not
    subclasses of the corresponding Python types, so the default encoder raises
    ``TypeError`` on them. Scalars are unwrapped with ``.item()``, arrays become
    nested lists and sets become a deterministically sorted list. Any other
    type still raises ``TypeError``, like the default encoder.
    """
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, (set, frozenset)):
        return sorted(value, key=repr)
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def format_tool_call(name: str, args: Dict[str, Any], result: Any) -> str:
    """Render one tool invocation as a single ``Function/Arguments/Results`` trace line.

    The result is compacted with ``compact_result_for_trace`` first. numpy
    values are encoded via ``_json_default`` so traces of model output do not
    raise ``TypeError``.
    """
    args_text = json.dumps(args, ensure_ascii=False, default=_json_default)
    result_text = json.dumps(
        compact_result_for_trace(result), ensure_ascii=False, default=_json_default
    )
    return f"Function: {name}, Arguments: {args_text}, Results: {result_text}"


def safe_float(x: Any, default=np.nan) -> float:
    """Convert ``x`` to a finite float, or return ``default``.

    ``None``, ``""`` and ``"-"`` (missing-value markers), values ``float()``
    cannot parse, NaN and +/-inf all yield ``default``.
    """
    try:
        if x is None or x == "" or x == "-":
            return default
        y = float(x)
        if math.isnan(y) or math.isinf(y):
            return default
        return y
    except Exception:
        return default


def read_pipe_csv(text: Any) -> pd.DataFrame:
    """Parse ``|``-separated CSV text into a DataFrame.

    Returns an empty DataFrame for non-string or blank input, or when parsing
    fails both as-is and after normalising ``\\r\\n`` line endings.
    """
    if not isinstance(text, str) or not text.strip():
        return pd.DataFrame()
    try:
        return pd.read_csv(io.StringIO(text), sep="|")
    except Exception:
        try:
            return pd.read_csv(io.StringIO(text.replace("\r\n", "\n")), sep="|")
        except Exception:
            return pd.DataFrame()


def numeric_series(df: pd.DataFrame, col: str) -> pd.Series:
    """Return column ``col`` as numbers (``"-"`` and unparsable cells become NaN).

    An empty float Series is returned when ``df`` is None/empty or lacks ``col``.
    """
    if df is None or df.empty or col not in df.columns:
        return pd.Series(dtype=float)
    return pd.to_numeric(df[col].replace("-", np.nan), errors="coerce")


def stat_feats(prefix: str, s: pd.Series) -> Dict[str, float]:
    """Summary statistics of ``s`` keyed ``<prefix>_{mean,min,max,std,p25,p50,p75}``.

    Non-numeric values are dropped first; every statistic is NaN when nothing
    numeric remains. ``std`` is the population standard deviation (ddof=0).
    """
    s = pd.to_numeric(s, errors="coerce").dropna()
    if len(s) == 0:
        return {
            f"{prefix}_mean": np.nan,
            f"{prefix}_min": np.nan,
            f"{prefix}_max": np.nan,
            f"{prefix}_std": np.nan,
            f"{prefix}_p25": np.nan,
            f"{prefix}_p50": np.nan,
            f"{prefix}_p75": np.nan,
        }
    p25, p50, p75 = np.percentile(s, [25, 50, 75])
    return {
        f"{prefix}_mean": float(s.mean()),
        f"{prefix}_min": float(s.min()),
        f"{prefix}_max": float(s.max()),
        f"{prefix}_std": float(s.std(ddof=0)),
        f"{prefix}_p25": float(p25),
        f"{prefix}_p50": float(p50),
        f"{prefix}_p75": float(p75),
    }


def clean_features(d: Dict[str, Any]) -> Dict[str, float]:
    """Coerce feature values to float: NaN/inf become -999.0, unconvertible values 0.0."""
    out = {}
    for k, v in d.items():
        try:
            x = float(v)
            if math.isnan(x) or math.isinf(x):
                x = -999.0
            out[k] = x
        except Exception:
            out[k] = 0.0
    return out


def parse_answer(ans: Any) -> List[str]:
    """Extract ``C<n>`` labels from an answer string.

    Non-strings and the placeholders "", "to be determined", "none" and "nan"
    (case-insensitive) yield an empty list.
    """
    if not isinstance(ans, str):
        return []
    if ans.strip().lower() in {"", "to be determined", "none", "nan"}:
        return []
    return re.findall(r"C\d+", ans)


def iou_score(pred: List[str], truth: List[str]) -> float:
    """Jaccard overlap of two label sets (1.0 when both are empty, 0.0 when only one is)."""
    p, t = set(pred), set(truth)
    if not p and not t:
        return 1.0
    if not p or not t:
        return 0.0
    return len(p & t) / len(p | t)


def is_multi_task(s: Dict[str, Any]) -> bool:
    """Return True when the scenario asks for several answers.

    Detected from a "multiple" tag or a task description mentioning
    "two to four", "2-4" or "2 to 4". A ``task`` of ``None`` is tolerated.
    """
    tag = norm(s.get("tag", ""))
    desc = norm((s.get("task") or {}).get("description", ""))
    return (
        "multiple" in tag
        or "two to four" in desc
        or "2-4" in desc
        or "2 to 4" in desc
    )


def get_options(s: Dict[str, Any]) -> Dict[str, str]:
    """Return the scenario's answer options as ``{option_id: label}``.

    ``task.options`` may be a list of ``{"id", "label"}`` dicts (entries
    without an id are skipped) or a dict, in which case only ``C<n>`` keys are
    kept. A ``task`` of ``None`` yields ``{}``.
    """
    opts = (s.get("task") or {}).get("options", [])
    out = {}
    if isinstance(opts, list):
        for o in opts:
            if isinstance(o, dict) and o.get("id"):
                out[str(o["id"])] = str(o.get("label", ""))
    elif isinstance(opts, dict):
        for k, v in opts.items():
            if re.match(r"^C\d+$", str(k)):
                out[str(k)] = str(v)
    return out


def _option_sort_key(cid: str) -> Tuple[int, int]:
    """Sort key for option ids: numeric-suffix ids (``C5``) by number, others after them.

    Ids whose suffix is not an integer keep their original relative order
    (Python's sort is stable) instead of raising ``ValueError``.
    """
    try:
        return (0, int(cid[1:]))
    except ValueError:
        return (1, 0)


def build_question_text(s: Dict[str, Any]) -> str:
    """Render the task description followed by one ``<id>: <label>`` line per option."""
    task = s.get("task", {}) or {}
    parts = [str(task.get("description", "")).strip()]
    for cid, label in sorted(get_options(s).items(), key=lambda x: _option_sort_key(x[0])):
        parts.append(f"{cid}: {label}")
    return "\n".join([p for p in parts if p])


def make_completion(labels: List[str], dbg: Dict[str, Any], s: Dict[str, Any]) -> str:
    """Build the human-readable answer text: model signals, chosen options, ``\\boxed{...}``.

    ``template_prob`` and ``top_templates`` may be missing or ``None`` in
    ``dbg``; they then default to 0.0 and an empty list.
    """
    opts = get_options(s)
    scores = {
        "template": dbg.get("template"),
        "template_prob": round(float(dbg.get("template_prob") or 0.0), 4),
        "top_templates": [
            [tpl, round(float(prob), 4)]
            for tpl, prob in (dbg.get("top_templates") or [])[:5]
        ],
    }
    recs = "\n".join(f"- {cid}: {opts.get(cid, '')}" for cid in labels)
    if not recs:
        recs = "- No valid option selected"
    boxed = "|".join(labels)
    return (
        "**Evidence:**\n"
        f"- Derived model signals: {json.dumps(scores, ensure_ascii=False)}\n\n"
        "**Recommendations:**\n"
        f"{recs}\n\n"
        f"\\boxed{{{boxed}}}"
    )


def normalize_prediction_labels(labels: List[str]) -> List[str]:
    """Keep well-formed ``C<n>`` labels (whitespace-stripped), de-duplicated and sorted by number."""
    clean = [
        label.strip()
        for label in (labels or [])
        if isinstance(label, str) and re.match(r"^C\d+$", label.strip())
    ]
    return sorted(dict.fromkeys(clean), key=lambda x: int(x[1:]))


MODEL_DELEGATION_SYSTEM_PROMPT = """
You are a wireless network optimization assistant.
You must answer by calling the provided trained-model tool. The telemetry and
server data have already been gathered by deterministic code; do not invent or
request telemetry yourself.

Use the trained model prediction as a strong prior, not as an unchangeable
answer. The tool also returns compact option evidence: parsed action, target
cell, change magnitude, current target-cell indicators, selector scores, and
ambiguity groups where options differ only by target, subtype, or magnitude.
Rerank the options when the evidence shows that another candidate better fits
the same troubleshooting action, target cell, threshold subtype, or numerical
change. Be conservative when the evidence is weak.

Final answer rules:
- Return C labels only, such as C5 or C5|C9.
- For multiple labels, sort numerically by C number.
- For a single-answer task, return exactly one label.
- For a multiple-answer task, return two to four labels.
- Do not include prose, markdown, or boxed notation in final_answer.
""".strip()


def get_env_value(names: str) -> str:
    """Return the first non-empty environment value among comma/semicolon-separated ``names``."""
    for name in re.split(r"[,;]", names):
        value = os.environ.get(name.strip(), "").strip()
        if value:
            return value
    return ""


def _time_left(deadline: float) -> float:
    """Seconds until ``deadline`` (a ``time.perf_counter`` value); 1e9 when no deadline is set."""
    if not deadline:
        return 10**9
    return max(0.0, deadline - time.perf_counter())


def _llm_timeout_left(deadline: float, llm_timeout: float) -> Optional[float]:
    """Per-request LLM timeout: ``llm_timeout`` capped by the remaining budget, never below 0.1s."""
    if not deadline:
        return llm_timeout
    remaining = _time_left(deadline)
    if remaining <= 0:
        return 0.1
    return max(0.1, min(llm_timeout, remaining))


def _clip_text(text: str, max_chars: int) -> str:
    """Keep the first and last halves of ``text`` around a ``[truncated N chars]`` marker.

    Text is returned unchanged when ``max_chars <= 0`` or it already fits.
    """
    if max_chars <= 0 or len(text) <= max_chars:
        return text
    head = max_chars // 2
    tail = max_chars - head
    return (
        text[:head]
        + f"\n... [truncated {len(text) - max_chars} chars] ...\n"
        + text[-tail:]
    )


def _json_tool_content(result: Dict[str, Any], max_chars: int) -> str:
    """Serialize a tool result for the LLM conversation, clipped to ``max_chars``."""
    text = json.dumps(result, ensure_ascii=False, default=_json_default)
    return _clip_text(text, max_chars)


def _parse_tool_arguments(raw: Any) -> Dict[str, Any]:
    """Decode a tool call's JSON ``arguments``; malformed or non-object input yields ``{}``."""
    try:
        parsed = json.loads(raw or "{}")
    except (TypeError, ValueError):
        return {}
    return parsed if isinstance(parsed, dict) else {}


def run_agent_with_trained_model(
    llm_client: Optional[OpenAI],
    model_name: str,
    template_model: Pipeline,
    selector_model: Pipeline,
    scenario: Dict[str, Any],
    timeout: float = 60.0,
    max_steps: int = 4,
    max_tool_calls: int = 8,
    temperature: float = 0.0,
    max_output_tokens: int = 900,
    history_chars: int = 24000,
    observation_chars: int = 20000,
    question_timeout: float = 180.0,
) -> Tuple[List[str], Dict[str, Any], str, List[str], bool]:
    """Run a bounded Track A agent loop around the trained ML model tool.

    The LLM is first forced to call ``run_trained_ml_model`` (``predict_labels``
    plus compact option evidence). After that it may answer with strict JSON
    ``{"final_answer": ...}``. The answer is accepted only if it meets both
    conditions:

    - its labels are valid option ids;
    - it has the right cardinality: exactly one label for single-answer
      tasks, two to four for multiple-answer tasks.

    The loop ends after ``max_steps`` model calls, after ``max_tool_calls``
    traced tool calls, or when at most 2 seconds of ``question_timeout``
    remain. In those cases, and on any exception, the ML prediction is
    returned together with a completion rendered from it.

    When ``llm_client`` is None the ML tool is run directly without an agent.

    Returns:
        ``(labels, debug_info, completion_text, tool_trace, agent_used)``, where
        ``agent_used`` is True only when an LLM final answer was accepted.
    """
    sid = scenario.get("scenario_id") or scenario.get("ID") or "unknown"
    started = time.perf_counter()
    deadline = (
        started + question_timeout if question_timeout and question_timeout > 0 else 0
    )
    model_calls = 0
    invalid_json_count = 0
    duplicate_skips = 0
    timeout_hit = False
    seen_tool_calls = set()
    tool_trace: List[str] = []
    ml_result: Optional[Dict[str, Any]] = None
    final_content = ""

    def run_ml_tool() -> Dict[str, Any]:
        # The actual "tool": ML prediction plus compact per-option evidence.
        labels, dbg = predict_labels(template_model, selector_model, scenario)
        labels = normalize_prediction_labels(labels)
        context = build_prediction_context(scenario)
        top_templates = dbg.get("top_templates", [])[:5]
        evidence = build_option_evidence(
            selector_model, scenario, context, top_templates, labels
        )
        return {
            "scenario_id": sid,
            "prediction": "|".join(labels),
            "labels": labels,
            "template": dbg.get("template"),
            "template_prob": dbg.get("template_prob"),
            "top_templates": top_templates,
            "options": get_options(scenario),
            "option_evidence": evidence,
        }

    def run_ml_fallback() -> Dict[str, Any]:
        # Run the ML tool outside the LLM conversation and record it in the trace.
        result = run_ml_tool()
        tool_trace.append(
            format_tool_call(
                "run_trained_ml_model",
                {"scenario_id": sid, "fallback": True},
                result,
            )
        )
        return result

    def build_dbg(
        ml_result: Dict[str, Any],
        agent_used: bool,
        final_raw: str = "",
        fallback_reason: str = "",
        changed: bool = False,
    ) -> Dict[str, Any]:
        # Counters are read from the enclosing scope at call time.
        dbg = {
            "template": ml_result.get("template"),
            "template_prob": ml_result.get("template_prob"),
            "top_templates": ml_result.get("top_templates", []),
            "agent_used": agent_used,
            "agent_changed_prediction": changed,
            "model_calls": model_calls,
            "invalid_json_count": invalid_json_count,
            "duplicate_skips": duplicate_skips,
            "timeout_hit": timeout_hit,
        }
        if final_raw:
            dbg["agent_final_raw"] = final_raw
        if fallback_reason:
            dbg["agent_fallback_reason"] = fallback_reason
        return dbg

    def labels_from_text(text: str, ml_result: Optional[Dict[str, Any]]) -> List[str]:
        # Parse the LLM's final answer; return [] if it is not acceptable.
        try:
            parsed = json.loads(text)
        except (TypeError, ValueError):
            parsed = None
        if not isinstance(parsed, dict):
            # Valid JSON that is not an object (e.g. "C5" or a list) is treated
            # as plain text instead of crashing on .get().
            parsed = {}
        final_text = str(parsed.get("final_answer") or text)
        labels = normalize_prediction_labels(re.findall(r"C\d+", final_text))
        task_type = None
        if ml_result:
            valid_ids = set((ml_result.get("options") or {}).keys())
            evidence = ml_result.get("option_evidence")
            if isinstance(evidence, dict):
                task_type = evidence.get("task_type")
        else:
            # No tool result yet: validate against the scenario itself so an
            # answer given before the tool call is still checked.
            valid_ids = set(get_options(scenario))
        task_type = task_type or (
            "multiple-answer" if is_multi_task(scenario) else "single-answer"
        )
        labels = [label for label in labels if label in valid_ids]
        bad_cardinality = (task_type == "single-answer" and len(labels) != 1) or (
            task_type == "multiple-answer" and not (2 <= len(labels) <= 4)
        )
        if bad_cardinality:
            return []
        return labels

    if llm_client is None:
        result = run_ml_tool()
        labels = result["labels"]
        dbg = build_dbg(result, False, fallback_reason="agent_disabled")
        return (
            labels,
            dbg,
            make_completion(labels, dbg, scenario),
            [format_tool_call("run_trained_ml_model", {"scenario_id": sid}, result)],
            False,
        )

    tools = [
        {
            "type": "function",
            "function": {
                "name": "run_trained_ml_model",
                "description": (
                    "Run the already-trained Track A v4 ML pipeline on the hydrated "
                    "scenario data. Returns the ML prior plus compact option evidence "
                    "for agentic reranking."
                ),
                "parameters": {
                    "type": "object",
                    "properties": {
                        "scenario_id": {
                            "type": "string",
                            "description": "Scenario ID to score.",
                        }
                    },
                    "required": ["scenario_id"],
                    "additionalProperties": False,
                },
            },
        }
    ]
    messages = [
        {"role": "system", "content": MODEL_DELEGATION_SYSTEM_PROMPT},
        {
            "role": "user",
            "content": (
                "Question and options:\n"
                f"{build_question_text(scenario)}\n\n"
                "Call run_trained_ml_model now. Then rerank only if the returned "
                "option evidence shows a better target cell, subtype, or numerical "
                "change than the ML prior. Return strict JSON with "
                '{"final_answer":"C...","used_tool":true,"reason":"short"}.'
            ),
        },
    ]

    try:
        for _step in range(max(1, max_steps)):
            if _time_left(deadline) <= 2:
                timeout_hit = True
                break
            # Force the tool until the model has actually seen the ML result.
            force_tool = ml_result is None
            model_calls += 1
            request = {
                "model": model_name,
                "messages": messages,
                "tools": tools,
                "tool_choice": (
                    {
                        "type": "function",
                        "function": {"name": "run_trained_ml_model"},
                    }
                    if force_tool
                    else "auto"
                ),
                "temperature": temperature,
                "max_tokens": max_output_tokens,
                "timeout": _llm_timeout_left(deadline, timeout),
            }
            if not force_tool:
                request["response_format"] = {"type": "json_object"}
            response = llm_client.chat.completions.create(**request)
            msg = response.choices[0].message
            tool_calls = getattr(msg, "tool_calls", None) or []

            if tool_calls:
                messages.append(msg.model_dump(exclude_none=True))
                for call in tool_calls:
                    name = call.function.name
                    # Malformed arguments no longer abort the agent: the tool
                    # ignores them apart from tracing and de-duplication.
                    args = _parse_tool_arguments(call.function.arguments)
                    call_key = (name, json.dumps(args, sort_keys=True))
                    if len(tool_trace) >= max_tool_calls:
                        result = {
                            "error": "tool_limit_reached",
                            "message": "Return final answer using existing evidence.",
                        }
                    elif call_key in seen_tool_calls:
                        duplicate_skips += 1
                        result = {
                            "duplicate_tool_call_skipped": True,
                            "message": "Use existing tool evidence and return final JSON.",
                        }
                    elif name != "run_trained_ml_model":
                        result = {"error": f"unsupported tool {name}"}
                    else:
                        seen_tool_calls.add(call_key)
                        result = run_ml_tool()
                        ml_result = result
                    tool_trace.append(format_tool_call(name, args, result))
                    messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": call.id,
                            "name": name,
                            "content": _json_tool_content(result, observation_chars),
                        }
                    )
                if len(tool_trace) >= max_tool_calls:
                    break
                continue

            final_content = msg.content or ""
            labels = labels_from_text(final_content, ml_result)
            if labels:
                if not ml_result:
                    # The model answered without the forced tool call; compute the
                    # ML prior now so debug info and the trace stay complete.
                    ml_result = run_ml_fallback()
                # Compared only after ml_result is guaranteed to exist.
                ml_labels = normalize_prediction_labels(ml_result.get("labels", []))
                dbg = build_dbg(
                    ml_result,
                    True,
                    final_raw=final_content,
                    changed=labels != ml_labels,
                )
                return labels, dbg, final_content, tool_trace, True

            invalid_json_count += 1
            messages.append(
                {
                    "role": "assistant",
                    "content": _clip_text(final_content, history_chars),
                }
            )
            messages.append(
                {
                    "role": "user",
                    "content": (
                        "The final answer was invalid. Return strict JSON only, "
                        'for example {"final_answer":"C5","used_tool":true,"reason":"short"}.'
                    ),
                }
            )

        if not ml_result:
            ml_result = run_ml_fallback()
        labels = normalize_prediction_labels(ml_result.get("labels", []))
        reason = (
            "question_timeout" if timeout_hit else "max_tool_calls_or_steps_reached"
        )
        dbg = build_dbg(ml_result, False, final_raw=final_content, fallback_reason=reason)
        # Always render the completion from the labels actually returned; any
        # rejected LLM text is kept in dbg["agent_final_raw"] for inspection.
        return labels, dbg, make_completion(labels, dbg, scenario), tool_trace, False
    except Exception as exc:
        # Reuse an ML result obtained before the failure instead of re-running it.
        result = ml_result or run_ml_tool()
        labels = normalize_prediction_labels(result["labels"])
        dbg = build_dbg(
            result,
            False,
            fallback_reason=f"{type(exc).__name__}: {exc}",
        )
        tool_trace.append(
            format_tool_call(
                "run_trained_ml_model",
                {"scenario_id": sid, "fallback_after_agent_error": True},
                result,
            )
        )
        return labels, dbg, make_completion(labels, dbg, scenario), tool_trace, False


def scenario_raw_text(s: Dict[str, Any], max_chars: int = 12000) -> str:
    """Compact raw text feature for the selector, truncated for speed.

    Joins, with ``" \\n "``, these fields in order:

    - the task description;
    - the first ``max_chars`` characters of the user-plane data;
    - the first 5000 characters each of the configuration, signaling,
      traffic and MR data;
    - the notes.

    Missing, ``None`` or empty fields are skipped.
    """
    task = s.get("task", {}) or {}
    data = s.get("data", {}) or {}
    parts = [
        task.get("description", ""),
        (data.get("user_plane_data") or "")[:max_chars],
        (data.get("network_configuration_data") or "")[:5000],
        (data.get("signaling_plane_data") or "")[:5000],
        (data.get("traffic_data") or "")[:5000],
        (data.get("mr_data") or "")[:5000],
        data.get("notes", ""),
    ]
    return " \n ".join(str(x) for x in parts if x)


# =============================================================================
# Column constants
# =============================================================================

COL_SERV = "5G KPI PCell RF Serving PCI"
COL_RSRP = "5G KPI PCell RF Serving SS-RSRP [dBm]"
COL_SINR = "5G KPI PCell RF Serving SS-SINR [dB]"
COL_THR = "5G KPI PCell Layer2 MAC DL Throughput [Mbps]"
COL_RB = "5G KPI PCell Layer1 DL RB Num (Including 0)"
COL_MCS = "Avg MCS"
COL_RANK = "Avg Rank"
COL_GRANT = "Grant"
COL_CCE_FAIL = "CCE Fail Rate"
COL_IBLER = "Initial BLER(%)"
COL_RBLER = "Residual BLER(%)"


def get_dataframes(
    s: Dict[str, Any],
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Parse the scenario's data tables: (user, config, signaling, traffic, mr).

    Each table is parsed with ``read_pipe_csv``; missing or unparsable tables
    come back as empty DataFrames.
    """
    data = s.get("data", {}) or {}
    user = read_pipe_csv(data.get("user_plane_data"))
    config = read_pipe_csv(data.get("network_configuration_data"))
    signaling = read_pipe_csv(data.get("signaling_plane_data"))
    traffic = read_pipe_csv(data.get("traffic_data"))
    mr = read_pipe_csv(data.get("mr_data"))
    return user, config, signaling, traffic, mr


def normalize_server_scenario(
    local_s: Dict[str, Any], remote_s: Dict[str, Any]
) -> Dict[str, Any]:
    """Overlay a server scenario onto the local one.

    The local scenario is copied, then these remote fields replace the local
    ones when present: ``context``, ``tools``, ``expected_output``, ``steps``
    and ``answer`` (when not None), and ``data`` (when truthy).
    """
    merged = dict(local_s)
    for key in ["context", "tools", "expected_output", "steps"]:
        if remote_s.get(key) is not None:
            merged[key] = remote_s[key]
    if remote_s.get("data"):
        merged["data"] = remote_s["data"]
    if remote_s.get("answer") is not None:
        merged["answer"] = remote_s["answer"]
    return merged


def server_get(
    client: httpx.Client,
    server_url: str,
    endpoint: str,
    scenario_id: str,
    params: Optional[Dict[str, Any]] = None,
) -> Any:
    """GET ``server_url/endpoint`` with an ``X-Scenario-Id`` header and return the decoded JSON.

    Raises whatever ``raise_for_status`` / ``json()`` raise on HTTP or decode errors.
    """
    url = f"{server_url.rstrip('/')}/{endpoint.lstrip('/')}"
    headers = {"X-Scenario-Id": scenario_id}
    resp = client.get(url, params=params or {}, headers=headers)
    resp.raise_for_status()
    return resp.json()


def hydrate_scenario_from_server(
    s: Dict[str, Any],
    client: httpx.Client,
    server_url: str,
    try_scenario_endpoint: bool = False,
) -> Tuple[Dict[str, Any], List[str]]:
    """Fill placeholder scenario data from the evaluation server.

    Steps:

    1. Optionally try ``/scenario`` first. If its merged result no longer
       needs server data, it is returned directly. Otherwise the merged
       scenario (remote context/answer plus any real data) is used as the
       base for the remaining calls.
    2. Query the per-table endpoints, one per data table.
    3. If signaling data is still a placeholder while user-plane data is
       real, fetch the signaling event log at the last user-plane
       ``Timestamp``.

    Each call is recorded in the returned trace, including failures as
    ``{"error": ...}`` results.

    Raises:
        ValueError: if the scenario has neither ``scenario_id`` nor ``ID``.
    """
    sid = s.get("scenario_id") or s.get("ID")
    if not sid:
        raise ValueError("Cannot hydrate scenario without scenario_id/ID")
    tool_calls: List[str] = []
    base = s

    if try_scenario_endpoint:
        try:
            remote = server_get(client, server_url, "/scenario", sid)
            tool_calls.append(
                format_tool_call("get_scenario", {"scenario_id": sid}, remote)
            )
            hydrated = normalize_server_scenario(s, remote)
            if not scenario_needs_server_data(hydrated):
                return hydrated, tool_calls
            # Keep what the scenario endpoint provided; only the remaining
            # placeholder tables still need the per-endpoint calls below.
            base = hydrated
        except Exception as exc:
            tool_calls.append(
                format_tool_call(
                    "get_scenario",
                    {"scenario_id": sid},
                    {"error": f"{type(exc).__name__}: {exc}"},
                )
            )

    raw_data = base.get("data")
    data = dict(raw_data) if isinstance(raw_data, dict) else {}
    endpoint_map = [
        (
            "get_user_plane_data",
            "/user-plane-data",
            "User Plane Data",
            "user_plane_data",
        ),
        (
            "get_config_data",
            "/config-data",
            "Network Configuration Data",
            "network_configuration_data",
        ),
        ("get_kpi_data", "/get_kpi_data", "Traffic Data", "traffic_data"),
        ("get_mr_data", "/get_mr_data", "MR Data", "mr_data"),
    ]
    for func_name, endpoint, response_key, data_key in endpoint_map:
        try:
            result = server_get(client, server_url, endpoint, sid)
            tool_calls.append(format_tool_call(func_name, {}, result))
            # Only dict responses can carry the table; a string response would
            # otherwise pass a substring test and then fail on indexing.
            if isinstance(result, dict) and response_key in result:
                data[data_key] = result[response_key]
        except Exception as exc:
            tool_calls.append(
                format_tool_call(
                    func_name,
                    {},
                    {"error": f"{type(exc).__name__}: {exc}"},
                )
            )

    if is_placeholder_data(
        data.get("signaling_plane_data")
    ) and not is_placeholder_data(data.get("user_plane_data")):
        user_df = read_pipe_csv(data.get("user_plane_data"))
        if not user_df.empty and "Timestamp" in user_df.columns:
            time_arg = str(user_df["Timestamp"].iloc[-1])
            try:
                result = server_get(
                    client,
                    server_url,
                    "/signaling-plane-event-log",
                    sid,
                    {"time": time_arg},
                )
                tool_calls.append(
                    format_tool_call(
                        "get_signaling_plane_event_log",
                        {"time": time_arg},
                        result,
                    )
                )
                if isinstance(result, str):
                    data["signaling_plane_data"] = result
            except Exception as exc:
                tool_calls.append(
                    format_tool_call(
                        "get_signaling_plane_event_log",
                        {"time": time_arg},
                        {"error": f"{type(exc).__name__}: {exc}"},
                    )
                )

    hydrated = dict(base)
    hydrated["data"] = data
    return hydrated, tool_calls


def neighbor_col_pairs(user: pd.DataFrame) -> List[Tuple[str, str]]:
    """Return (PCI column, filtered BRSRP column) pairs for neighbor Top 1..5 present in ``user``."""
    pairs = []
    if user is None or user.empty:
        return pairs
    for i in range(1, 6):
        pci = f"Measurement PCell Neighbor Cell Top Set(Cell Level) Top {i} PCI"
        rsrp = f"Measurement PCell Neighbor Cell Top Set(Cell Level) Top {i} Filtered Tx BRSRP [dBm]"
        if pci in user.columns and rsrp in user.columns:
            pairs.append((pci, rsrp))
    return pairs


def bad_rows(user: pd.DataFrame) -> pd.DataFrame:
    """Select low-throughput rows of the user-plane table.

    The rules are applied in order:

    - Rows are "bad" when MAC DL throughput is below 100 or at/below the
      35th percentile of the valid throughput values. The result carries a
      numeric ``_thr`` helper column.
    - Falls back to the three lowest-throughput rows if that selection is
      empty.
    - Without a throughput column the whole table is returned; when that
      column has no numeric values, the whole table plus ``_thr`` is
      returned.
    - None or empty input gives an empty DataFrame.
    """
    if user is None or user.empty:
        return pd.DataFrame()
    if COL_THR not in user.columns:
        return user.copy()
    u = user.copy()
    u["_thr"] = numeric_series(u, COL_THR)
    thr = u["_thr"]
    if not thr.notna().any():
        return u.copy()
    q35 = float(np.nanpercentile(thr.dropna(), 35))
    b = u[(thr < 100) | (thr <= q35)].copy()
    if b.empty:
        b = u.nsmallest(min(3, len(u)), "_thr").copy()
    return b


def build_cell_maps(config: pd.DataFrame) -> Tuple[Dict[str, int], Dict[int, str]]:
    """Map ``"<gNodeB ID>_<Cell ID>"`` keys to PCI and back, from the configuration table.

    Rows whose IDs or PCI cannot be converted to integers are skipped.
    """
    g2p, p2g = {}, {}
    if config is None or config.empty:
        return g2p, p2g
    for _, r in config.iterrows():
        try:
            key = f"{int(float(r.get('gNodeB ID')))}_{int(float(r.get('Cell ID')))}"
            pci = int(float(r.get("PCI")))
            g2p[key] = pci
            p2g[pci] = key
        except Exception:
            continue
    return g2p, p2g


def haversine_m(lon1: float, lat1: float, lon2: float, lat2: float) -> float:
    """Great-circle distance in metres between two (longitude, latitude) points in degrees."""
    radius = 6371000.0
    dlat = math.radians(lat2 - lat1)
    dlon = math.radians(lon2 - lon1)
    a = (
        math.sin(dlat / 2) ** 2
        + math.cos(math.radians(lat1))
        * math.cos(math.radians(lat2))
        * math.sin(dlon / 2) ** 2
    )
    # Rounding can push ``a`` slightly above 1 for near-antipodal points, which
    # would make math.asin raise ValueError. Written as a comparison so NaN
    # inputs still propagate as NaN.
    if a > 1.0:
        a = 1.0
    return 2 * radius * math.asin(math.sqrt(a))
def bearing_deg(lon1: float, lat1: float, lon2: float, lat2: float) -> float:
    """Initial great-circle bearing from point 1 to point 2.

    Coordinates are longitude/latitude in degrees. The result is in degrees
    measured clockwise from north (0 = north, 90 = east), normalised to
    ``[0, 360)``.
    """
    phi_from = math.radians(lat1)
    phi_to = math.radians(lat2)
    delta_lon = math.radians(lon2 - lon1)
    east = math.sin(delta_lon) * math.cos(phi_to)
    north = math.cos(phi_from) * math.sin(phi_to) - math.sin(phi_from) * math.cos(
        phi_to
    ) * math.cos(delta_lon)
    heading = math.degrees(math.atan2(east, north))
    return (heading + 360) % 360


def angle_delta_deg(a: float, b: float) -> float:
    """Smallest absolute difference between two angles in degrees, in ``[0, 180]``."""
    gap = abs(a - b) % 360
    if gap > 180:
        return 360 - gap
    return gap


# =============================================================================
# Action parsing and templates
# =============================================================================

# Canonical action order; answer templates list their actions in this order.
ACTION_ORDER = [
    "server",
    "insufficient",
    "pdcch",
    "inc_power",
    "dec_power",
    "tilt_down",
    "tilt_up",
    "azimuth",
    "inc_a3",
    "dec_a3",
    "a2a5",
    "neighbor",
]

# Ordered rules for label_to_action: the first rule with a matching alternative
# wins. Each alternative is a tuple of substrings that must ALL occur in the
# normalised label.
_LABEL_ACTION_RULES = (
    ("insufficient", (("insufficient data",), ("more data",))),
    ("server", (("server",), ("transmission issue",), ("transmission issues",))),
    ("pdcch", (("pdcch",), ("2sym",))),
    ("inc_power", (("increase transmission power",),)),
    ("dec_power", (("decrease transmission power",),)),
    ("tilt_down", (("press down",), ("down the tilt",))),
    ("tilt_up", (("lift", "tilt"),)),
    ("azimuth", (("azimuth",),)),
    ("inc_a3", (("increase a3",),)),
    ("dec_a3", (("decrease a3",),)),
    (
        "a2a5",
        (
            ("covinterfreqa2",),
            ("covinterfreqa5",),
            ("a2rsrpthld",),
            ("a5rsrpthld",),
        ),
    ),
    ("neighbor", (("neighbor relationship",),)),
)


def label_to_action(label: str) -> str:
    """Classify an option label into one action keyword, or ``"other"``.

    The label is normalised with ``norm`` and checked against
    ``_LABEL_ACTION_RULES`` in order; rule order matters because some labels
    could match more than one rule.
    """
    text = norm(label)
    for action, alternatives in _LABEL_ACTION_RULES:
        if any(all(piece in text for piece in required) for required in alternatives):
            return action
    return "other"


def answer_to_template(s: Dict[str, Any]) -> str:
    """Turn the scenario's ``answer`` into an action template string.

    Each answered option id is mapped to its action via ``label_to_action``.
    ``"other"`` and empty actions are dropped. The remaining actions are
    repeated once per occurrence, ordered by ``ACTION_ORDER`` and joined
    with ``"|"``. Returns ``"other"`` when nothing remains.
    """
    options = get_options(s)
    chosen_ids = parse_answer(s.get("answer", ""))
    tally = Counter(
        action
        for action in (label_to_action(options.get(cid, "")) for cid in chosen_ids)
        if action and action != "other"
    )
    ordered = [
        action for action in ACTION_ORDER for _ in range(tally.get(action, 0))
    ]
    if not ordered:
        return "other"
    return "|".join(ordered)
template_to_actions(template: str) -> List[str]: if not template or template == "other": return [] return [x for x in template.split("|") if x] def extract_gnb_cell(label: str) -> Optional[str]: m = re.search(r"\b(\d{5,8}_\d{1,3})\b", str(label)) return m.group(1) if m else None def extract_all_gnb_cells(label: str) -> List[str]: return re.findall(r"\b\d{5,8}_\d{1,3}\b", str(label)) def parse_pdcch_symbols(value: Any) -> float: m = re.search(r"(\d+)\s*sym", norm(value)) if not m: return np.nan return safe_float(m.group(1), np.nan) def parse_option_semantics(label: str) -> Dict[str, Any]: """Extract option facts without deciding whether the option is correct.""" text = str(label or "") t = norm(text) masked = re.sub(r"\b\d{5,8}_\d{1,3}\b", "", t) action = label_to_action(text) cells = extract_all_gnb_cells(text) amount = np.nan unit = "" threshold_types: List[str] = [] pdcch_symbols = np.nan if action == "pdcch": pdcch_symbols = parse_pdcch_symbols(text) amount = pdcch_symbols unit = "sym" elif action in {"inc_power", "dec_power"}: m = re.search(r"\bby\s*([-+]?\d+(?:\.\d+)?)\s*dbm\b", masked) if m: amount = safe_float(m.group(1), np.nan) unit = "dbm" elif action in {"tilt_down", "tilt_up", "azimuth"}: m = re.search(r"\bby\s*([-+]?\d+(?:\.\d+)?)\s*degrees?\b", masked) if m: amount = safe_float(m.group(1), np.nan) unit = "degrees" elif action in {"inc_a3", "dec_a3"}: m = re.search(r"\bby\s*([-+]?\d+(?:\.\d+)?)\s*db\b", masked) if m: amount = safe_float(m.group(1), np.nan) unit = "db" elif action == "a2a5": if "covinterfreqa2rsrpthld" in t: threshold_types.append("a2") if "covinterfreqa5rsrpthld1" in t: threshold_types.append("a5_1") if "covinterfreqa5rsrpthld2" in t: threshold_types.append("a5_2") m = re.search(r"\bby\s*([-+]?\d+(?:\.\d+)?)\s*db\b", masked) if m: amount = safe_float(m.group(1), np.nan) unit = "db" return { "action": action, "target_cell": cells[0] if cells else None, "cells": cells, "amount": amount, "unit": unit, "threshold_types": threshold_types, 
"threshold_type": "|".join(threshold_types), "pdcch_symbols": pdcch_symbols, "towards_ue": 1.0 if "towards the ue" in t else 0.0, } def extract_cell_geometry_stats(s: Dict[str, Any]) -> Dict[str, Dict[str, float]]: user, config, _, _, _ = get_dataframes(s) if user is None or user.empty or config is None or config.empty: return {} need_user = {"Longitude", "Latitude", COL_SERV} need_cfg = { "gNodeB ID", "Cell ID", "PCI", "Longitude", "Latitude", "Mechanical Azimuth", "Mechanical Downtilt", "Height", } if not need_user.issubset(user.columns) or not need_cfg.issubset(config.columns): return {} _, p2g = build_cell_maps(config) cfg_by_key: Dict[str, Dict[str, float]] = {} for _, r in config.iterrows(): try: key = f"{int(float(r.get('gNodeB ID')))}_{int(float(r.get('Cell ID')))}" cfg_by_key[key] = { "lon": safe_float(r.get("Longitude"), np.nan), "lat": safe_float(r.get("Latitude"), np.nan), "azimuth": safe_float(r.get("Mechanical Azimuth"), np.nan), "downtilt": safe_float(r.get("Mechanical Downtilt"), np.nan), "height": safe_float(r.get("Height"), np.nan), } except Exception: continue related_rows: Dict[str, List[pd.Series]] = defaultdict(list) b = bad_rows(user) for _, r in b.iterrows(): spci = safe_float(r.get(COL_SERV), np.nan) if not np.isnan(spci): key = p2g.get(int(spci)) if key: related_rows[key].append(r) for pci_c, _ in neighbor_col_pairs(user): npci = safe_float(r.get(pci_c), np.nan) if np.isnan(npci): continue key = p2g.get(int(npci)) if key: related_rows[key].append(r) out: Dict[str, Dict[str, float]] = {} for key, cfg in cfg_by_key.items(): rows = related_rows.get(key, []) if not rows: continue h_errors, desired_tilts, tilt_errors, distances = [], [], [], [] for r in rows: try: ulon = safe_float(r.get("Longitude"), np.nan) ulat = safe_float(r.get("Latitude"), np.nan) if np.isnan(ulon) or np.isnan(ulat): continue dist = haversine_m(cfg["lon"], cfg["lat"], ulon, ulat) brg = bearing_deg(cfg["lon"], cfg["lat"], ulon, ulat) hdiff = angle_delta_deg(brg, 
cfg["azimuth"]) desired_tilt = math.degrees(math.atan2(cfg["height"], max(dist, 0.1))) h_errors.append(hdiff) desired_tilts.append(desired_tilt) tilt_errors.append(abs(cfg["downtilt"] - desired_tilt)) distances.append(dist) except Exception: continue if h_errors: out[key] = clean_features( { "geom_related_bad_count": float(len(h_errors)), "geom_distance_mean": float(np.mean(distances)), "geom_azimuth_error_mean": float(np.mean(h_errors)), "geom_azimuth_error_p50": float(np.percentile(h_errors, 50)), "geom_azimuth_error_max": float(np.max(h_errors)), "geom_desired_downtilt_mean": float(np.mean(desired_tilts)), "geom_desired_downtilt_p50": float(np.percentile(desired_tilts, 50)), "geom_current_tilt_error_mean": float(np.mean(tilt_errors)), "geom_current_tilt_error_p50": float(np.percentile(tilt_errors, 50)), } ) return out # ============================================================================= # Scenario-level features for template classifier # ============================================================================= def estimate_mainlobe_flags(user: pd.DataFrame, config: pd.DataFrame) -> pd.Series: if user is None or user.empty or config is None or config.empty: return pd.Series(dtype=float) need_user = {"Longitude", "Latitude", COL_SERV} need_cfg = { "PCI", "Longitude", "Latitude", "Mechanical Azimuth", "Mechanical Downtilt", "Height", } if not need_user.issubset(user.columns) or not need_cfg.issubset(config.columns): return pd.Series([np.nan] * len(user), index=user.index) cfg = config.copy() cfg["PCI"] = pd.to_numeric(cfg["PCI"], errors="coerce") cfg = cfg.set_index("PCI", drop=False) flags = [] for _, r in user.iterrows(): try: pci = int(float(r[COL_SERV])) if pci not in cfg.index: flags.append(np.nan) continue c = cfg.loc[pci] ulon, ulat = float(r["Longitude"]), float(r["Latitude"]) clon, clat = float(c["Longitude"]), float(c["Latitude"]) az = float(c["Mechanical Azimuth"]) downtilt = float(c["Mechanical Downtilt"]) height = float(c["Height"]) lat1 
= math.radians(clat) lat2 = math.radians(ulat) dlon = math.radians(ulon - clon) x = math.sin(dlon) * math.cos(lat2) y = math.cos(lat1) * math.sin(lat2) - math.sin(lat1) * math.cos( lat2 ) * math.cos(dlon) bearing = (math.degrees(math.atan2(x, y)) + 360) % 360 hdiff = abs(bearing - az) % 360 if hdiff > 180: hdiff = 360 - hdiff R = 6371000.0 dlat = math.radians(ulat - clat) dlon2 = math.radians(ulon - clon) a = ( math.sin(dlat / 2) ** 2 + math.cos(math.radians(clat)) * math.cos(math.radians(ulat)) * math.sin(dlon2 / 2) ** 2 ) dist = 2 * R * math.asin(math.sqrt(a)) tilt_angle = math.degrees(math.atan2(-height, max(dist, 0.1))) final_tilt = tilt_angle - downtilt flags.append(1.0 if (abs(hdiff) > 50 or abs(final_tilt) > 6) else 0.0) except Exception: flags.append(np.nan) return pd.Series(flags, index=user.index) def extract_scenario_features(s: Dict[str, Any]) -> Dict[str, float]: user, config, signaling, traffic, mr = get_dataframes(s) feats: Dict[str, float] = {} feats["is_multi"] = float(is_multi_task(s)) ctx = s.get("context", {}).get("wireless_network_information", {}) or {} feats["num_base_stations"] = safe_float(ctx.get("num_base_stations"), np.nan) if user is None or user.empty: feats["missing_user"] = 1.0 return clean_features(feats) feats["missing_user"] = 0.0 feats["n_user_rows"] = float(len(user)) for prefix, ser in { "thr": numeric_series(user, COL_THR), "rsrp": numeric_series(user, COL_RSRP), "sinr": numeric_series(user, COL_SINR), "rb": numeric_series(user, COL_RB), "mcs": numeric_series(user, COL_MCS), "rank": numeric_series(user, COL_RANK), "grant": numeric_series(user, COL_GRANT), "ccefail": numeric_series(user, COL_CCE_FAIL), "ibler": numeric_series(user, COL_IBLER), "rbler": numeric_series(user, COL_RBLER), }.items(): feats.update(stat_feats(prefix, ser)) b = bad_rows(user) b_rsrp = numeric_series(b, COL_RSRP) b_sinr = numeric_series(b, COL_SINR) b_mcs = numeric_series(b, COL_MCS) b_ibler = numeric_series(b, COL_IBLER) feats["bad_frac"] = 
float(len(b) / max(1, len(user))) for prefix, ser in { "bad_thr": numeric_series(b, COL_THR), "bad_rsrp": b_rsrp, "bad_sinr": b_sinr, "bad_rb": numeric_series(b, COL_RB), "bad_mcs": b_mcs, "bad_rank": numeric_series(b, COL_RANK), "bad_ccefail": numeric_series(b, COL_CCE_FAIL), "bad_ibler": b_ibler, "bad_rbler": numeric_series(b, COL_RBLER), }.items(): feats.update(stat_feats(prefix, ser)) feats["bad_rsrp_lt_m95_frac"] = ( float((b_rsrp < -95).mean()) if b_rsrp.notna().any() else np.nan ) feats["bad_rsrp_lt_m90_frac"] = ( float((b_rsrp < -90).mean()) if b_rsrp.notna().any() else np.nan ) feats["bad_sinr_lt_0_frac"] = ( float((b_sinr < 0).mean()) if b_sinr.notna().any() else np.nan ) feats["bad_sinr_lt_5_frac"] = ( float((b_sinr < 5).mean()) if b_sinr.notna().any() else np.nan ) feats["good_rsrp_bad_sinr_frac"] = ( float(((b_rsrp > -90) & (b_sinr < 5)).mean()) if len(b) else np.nan ) feats["weak_rsrp_good_sinr_frac"] = ( float(((b_rsrp < -95) & (b_sinr > 8)).mean()) if len(b) else np.nan ) feats["low_mcs_frac"] = ( float((b_mcs < 10).mean()) if b_mcs.notna().any() else np.nan ) feats["high_bler_frac"] = ( float((b_ibler > 15).mean()) if b_ibler.notna().any() else np.nan ) if COL_SERV in user.columns: serv_all = pd.to_numeric(user[COL_SERV], errors="coerce") serv_bad = pd.to_numeric( b.get(COL_SERV, pd.Series(dtype=float)), errors="coerce" ) feats["n_serving_pci_all"] = float(serv_all.nunique(dropna=True)) feats["n_serving_pci_bad"] = float(serv_bad.nunique(dropna=True)) feats["dominant_bad_serv_frac"] = ( float(serv_bad.value_counts(normalize=True).iloc[0]) if serv_bad.notna().any() else np.nan ) stronger_counts, margins, neighbor_counts = [], [], [] for _, r in b.iterrows(): srv = safe_float(r.get(COL_RSRP)) row_stronger, row_count, best = 0, 0, -999.0 for pci_c, rsrp_c in neighbor_col_pairs(user): nrsrp = safe_float(r.get(rsrp_c)) if np.isnan(nrsrp): continue row_count += 1 best = max(best, nrsrp) if not np.isnan(srv) and nrsrp > srv + 2: row_stronger += 1 
stronger_counts.append(row_stronger) neighbor_counts.append(row_count) margins.append(best - srv if row_count and not np.isnan(srv) else np.nan) feats["bad_neighbor_stronger_mean"] = ( float(np.nanmean(stronger_counts)) if stronger_counts else np.nan ) feats["bad_neighbor_count_mean"] = ( float(np.nanmean(neighbor_counts)) if neighbor_counts else np.nan ) feats["bad_best_neighbor_margin_mean"] = ( float(np.nanmean(margins)) if margins else np.nan ) feats["bad_best_neighbor_margin_max"] = ( float(np.nanmax(margins)) if margins else np.nan ) outside = estimate_mainlobe_flags(user, config) if len(outside): feats["outside_mainlobe_frac_all"] = float( pd.to_numeric(outside, errors="coerce").mean() ) feats["outside_mainlobe_frac_bad"] = ( float(pd.to_numeric(outside.loc[b.index], errors="coerce").mean()) if len(b) else np.nan ) if config is not None and not config.empty: feats["n_config_cells"] = float(len(config)) for col in [ "Transmission Power", "Mechanical Downtilt", "Digital Tilt", "Height", "IntraFreqHoA3Offset [0.5dB]", ]: if col in config.columns: key = "cfg_" + re.sub(r"[^a-zA-Z0-9]+", "_", col).strip("_").lower() feats.update(stat_feats(key, numeric_series(config, col))) if "PdcchOccupiedSymbolNum" in config.columns: feats["cfg_pdcch_1sym_frac"] = float( config["PdcchOccupiedSymbolNum"] .astype(str) .str.lower() .str.contains("1sym") .mean() ) if signaling is not None and not signaling.empty: event_names = " ".join( "" if pd.isna(x) else str(x) for x in signaling.get("Event Name", pd.Series(dtype=object)).tolist() ).lower() event_content = " ".join( "" if pd.isna(x) else str(x) for x in signaling.get("Event Content", pd.Series(dtype=object)).tolist() ).lower() text = event_names + " " + event_content feats["sig_rows"] = float(len(signaling)) feats["sig_a3_count"] = float(text.count("eventa3")) feats["sig_a2_count"] = float(text.count("eventa2")) feats["sig_ho_attempt_count"] = float(text.count("handoverattempt")) feats["sig_reest_count"] = 
float(text.count("reestablish")) feats["sig_ra_attempt_count"] = float(text.count("randomaccessattempt")) feats["sig_ra_success_count"] = float(text.count("randomaccesssuc")) sig_margins = [] for content in signaling.get("Event Content", pd.Series(dtype=object)).tolist(): txt = "" if pd.isna(content) else str(content) sm = re.search(r"ServCellRSRP:?\s*(-?\d+)", txt) nm = re.search(r"NCellRSRP:?\s*(-?\d+)", txt) if sm and nm: sig_margins.append(float(nm.group(1)) - float(sm.group(1))) feats["sig_a3_margin_mean"] = ( float(np.mean(sig_margins)) if sig_margins else np.nan ) feats["sig_a3_margin_max"] = ( float(np.max(sig_margins)) if sig_margins else np.nan ) else: feats["sig_rows"] = 0.0 if traffic is not None and not traffic.empty: feats["traffic_rows"] = float(len(traffic)) for col in [ "Uplink PRB utilization(%)", "Downlink PRB utilization(%)", "Uplink PRB Interference(dBm)", "User Uplink Throughput(Mbps)", "User Downlink Throughput(Mbps)", "Downlink Weak Coversge Ratio", "TA>1KM Ratio", "Uplink CCE utilization(%)", "Downlink CCE utilization(%)", "Uplink CCE Allocation Success Rate(%)", "Downlink CCE Allocation Success Rate(%)", ]: if col in traffic.columns: key = "traffic_" + re.sub(r"[^a-zA-Z0-9]+", "_", col).strip("_").lower() feats.update(stat_feats(key, numeric_series(traffic, col))) else: feats["traffic_rows"] = 0.0 if mr is not None and not mr.empty: feats["mr_rows"] = float(len(mr)) for col in [ "Serving RSRP(dBm)", "Throughput(Mbps)", "Neighbor 1 RSRP(dBm)", "Neighbor 2 RSRP(dBm)", "Neighbor 3 RSRP(dBm)", ]: if col in mr.columns: key = "mr_" + re.sub(r"[^a-zA-Z0-9]+", "_", col).strip("_").lower() feats.update(stat_feats(key, numeric_series(mr, col))) neigh_margins = [] for _, r in mr.iterrows(): sr = safe_float(r.get("Serving RSRP(dBm)")) best, found = -999.0, False for i in range(1, 4): nr = safe_float(r.get(f"Neighbor {i} RSRP(dBm)")) if not np.isnan(nr): best = max(best, nr) found = True if found and not np.isnan(sr): neigh_margins.append(best - sr) 
feats["mr_best_neighbor_margin_mean"] = ( float(np.mean(neigh_margins)) if neigh_margins else np.nan ) feats["mr_neighbor_stronger_frac"] = ( float(np.mean([m > 2 for m in neigh_margins])) if neigh_margins else np.nan ) else: feats["mr_rows"] = 0.0 return clean_features(feats) # ============================================================================= # Generic target-cell features for selector model # ============================================================================= def extract_cell_stats(s: Dict[str, Any]) -> Dict[str, Dict[str, float]]: user, config, signaling, traffic, mr = get_dataframes(s) g2p, p2g = build_cell_maps(config) stats = defaultdict(lambda: defaultdict(float)) if user is not None and not user.empty: b = bad_rows(user) for _, r in b.iterrows(): spci = safe_float(r.get(COL_SERV)) srv_rsrp = safe_float(r.get(COL_RSRP)) if not np.isnan(spci): key = p2g.get(int(spci)) if key: stats[key]["bad_serv_count"] += 1.0 stats[key]["bad_serv_rsrp_sum"] += safe_float(r.get(COL_RSRP), 0.0) stats[key]["bad_serv_sinr_sum"] += safe_float(r.get(COL_SINR), 0.0) stats[key]["bad_serv_thr_sum"] += safe_float(r.get(COL_THR), 0.0) stats[key]["bad_serv_mcs_sum"] += safe_float(r.get(COL_MCS), 0.0) stats[key]["bad_serv_bler_sum"] += safe_float(r.get(COL_IBLER), 0.0) for pci_c, rsrp_c in neighbor_col_pairs(user): npci = safe_float(r.get(pci_c)) nrsrp = safe_float(r.get(rsrp_c)) if np.isnan(npci): continue key = p2g.get(int(npci)) if key: stats[key]["bad_nei_count"] += 1.0 if not np.isnan(nrsrp): stats[key]["bad_nei_rsrp_sum"] += nrsrp if not np.isnan(srv_rsrp): stats[key]["bad_nei_margin_sum"] += nrsrp - srv_rsrp if nrsrp > srv_rsrp + 2: stats[key]["bad_nei_stronger_count"] += 1.0 n_bad = max(1, len(b)) for key in list(stats.keys()): stats[key]["bad_serv_frac"] = stats[key].get("bad_serv_count", 0.0) / n_bad stats[key]["bad_nei_frac"] = stats[key].get("bad_nei_count", 0.0) / n_bad if config is not None and not config.empty: for _, r in config.iterrows(): try: key 
= f"{int(float(r.get('gNodeB ID')))}_{int(float(r.get('Cell ID')))}" stats[key]["cfg_power"] = safe_float( r.get("Transmission Power"), np.nan ) stats[key]["cfg_downtilt"] = safe_float( r.get("Mechanical Downtilt"), np.nan ) stats[key]["cfg_digital_tilt"] = safe_float( r.get("Digital Tilt"), np.nan ) stats[key]["cfg_height"] = safe_float(r.get("Height"), np.nan) stats[key]["cfg_a3_offset"] = safe_float( r.get("IntraFreqHoA3Offset [0.5dB]"), np.nan ) stats[key]["cfg_a2_threshold"] = safe_float( r.get("CovInterFreqA2RsrpThld [dBm]"), np.nan ) stats[key]["cfg_a5_threshold_1"] = safe_float( r.get("CovInterFreqA5RsrpThld1 [dBm]"), np.nan ) stats[key]["cfg_a5_threshold_2"] = safe_float( r.get("CovInterFreqA5RsrpThld2 [dBm]"), np.nan ) stats[key]["cfg_max_power"] = safe_float( r.get("Max Transmit Power"), np.nan ) stats[key]["cfg_pdcch_symbols"] = parse_pdcch_symbols( r.get("PdcchOccupiedSymbolNum", "") ) stats[key]["cfg_pdcch_1sym"] = ( 1.0 if "1sym" in norm(r.get("PdcchOccupiedSymbolNum", "")) else 0.0 ) except Exception: continue if traffic is not None and not traffic.empty: for _, r in traffic.iterrows(): try: key = f"{int(float(r.get('gNodeB_ID')))}_{int(float(r.get('Cell_ID')))}" stats[key]["traffic_dl_prb"] = safe_float( r.get("Downlink PRB utilization(%)"), np.nan ) stats[key]["traffic_ul_prb"] = safe_float( r.get("Uplink PRB utilization(%)"), np.nan ) stats[key]["traffic_ul_interf"] = safe_float( r.get("Uplink PRB Interference(dBm)"), np.nan ) stats[key]["traffic_dl_tp"] = safe_float( r.get("User Downlink Throughput(Mbps)"), np.nan ) stats[key]["traffic_weak_cov"] = safe_float( r.get("Downlink Weak Coversge Ratio"), np.nan ) stats[key]["traffic_ta_gt_1km"] = safe_float( r.get("TA>1KM Ratio"), np.nan ) stats[key]["traffic_cce_dl"] = safe_float( r.get("Downlink CCE utilization(%)"), np.nan ) stats[key]["traffic_cce_succ_dl"] = safe_float( r.get("Downlink CCE Allocation Success Rate(%)"), np.nan ) except Exception: continue if mr is not None and not mr.empty: for _, 
r in mr.iterrows(): spci = safe_float(r.get("Serving PCI")) if not np.isnan(spci): key = p2g.get(int(spci)) if key: stats[key]["mr_serv_count"] += 1.0 stats[key]["mr_serv_rsrp_sum"] += safe_float( r.get("Serving RSRP(dBm)"), 0.0 ) stats[key]["mr_thr_sum"] += safe_float( r.get("Throughput(Mbps)"), 0.0 ) for i in range(1, 4): npci = safe_float(r.get(f"Neighbor {i} PCI")) if np.isnan(npci): continue key = p2g.get(int(npci)) if key: stats[key]["mr_nei_count"] += 1.0 stats[key]["mr_nei_rsrp_sum"] += safe_float( r.get(f"Neighbor {i} RSRP(dBm)"), 0.0 ) out = {} for key, d in stats.items(): dd = dict(d) n = dd.get("bad_serv_count", 0.0) if n: dd["bad_serv_rsrp_mean"] = dd["bad_serv_rsrp_sum"] / n dd["bad_serv_sinr_mean"] = dd["bad_serv_sinr_sum"] / n dd["bad_serv_thr_mean"] = dd["bad_serv_thr_sum"] / n dd["bad_serv_mcs_mean"] = dd["bad_serv_mcs_sum"] / n dd["bad_serv_bler_mean"] = dd["bad_serv_bler_sum"] / n n = dd.get("bad_nei_count", 0.0) if n: dd["bad_nei_rsrp_mean"] = dd["bad_nei_rsrp_sum"] / n dd["bad_nei_margin_mean"] = dd["bad_nei_margin_sum"] / n n = dd.get("mr_serv_count", 0.0) if n: dd["mr_serv_rsrp_mean"] = dd["mr_serv_rsrp_sum"] / n dd["mr_thr_mean"] = dd["mr_thr_sum"] / n n = dd.get("mr_nei_count", 0.0) if n: dd["mr_nei_rsrp_mean"] = dd["mr_nei_rsrp_sum"] / n out[key] = clean_features(dd) return out def option_feature_dict( s: Dict[str, Any], cid: str, label: str, template: str, action: str, context: Optional[Dict[str, Any]] = None, ) -> Dict[str, float]: feats = {} context = context or {} scen = context.get("scenario_features") if scen is None: scen = extract_scenario_features(s) # Prefix scenario features to share with selector. 
for k, v in scen.items(): feats[f"scen_{k}"] = v is_multi = context.get("is_multi") if is_multi is None: is_multi = is_multi_task(s) feats["is_multi"] = float(is_multi) feats[f"action_{action}"] = 1.0 feats[f"template_{template}"] = 1.0 template_actions = template_to_actions(template) template_action_counts = Counter(template_actions) feats["action_in_template"] = float(action in template_action_counts) feats["template_action_count"] = float(len(template_actions)) feats["template_this_action_count"] = float(template_action_counts.get(action, 0)) semantics = (context.get("option_semantics") or {}).get(cid) if semantics is None: semantics = parse_option_semantics(label) amount = safe_float(semantics.get("amount"), np.nan) feats["option_has_amount"] = float(not np.isnan(amount)) feats["option_amount"] = amount feats["option_amount_abs"] = abs(amount) if not np.isnan(amount) else np.nan feats["option_towards_ue"] = safe_float(semantics.get("towards_ue"), 0.0) pdcch_symbols = safe_float(semantics.get("pdcch_symbols"), np.nan) feats["option_pdcch_symbols"] = pdcch_symbols threshold_types = set(semantics.get("threshold_types") or []) feats["option_threshold_a2"] = float("a2" in threshold_types) feats["option_threshold_a5_1"] = float("a5_1" in threshold_types) feats["option_threshold_a5_2"] = float("a5_2" in threshold_types) feats["option_threshold_combined"] = float(len(threshold_types) > 1) target = extract_gnb_cell(label) targets = extract_all_gnb_cells(label) feats["has_target_cell"] = float(target is not None) feats["num_cells_in_label"] = float(len(targets)) roles = context.get("cell_stats") if roles is None: roles = extract_cell_stats(s) geom_roles = context.get("cell_geometry_stats") if geom_roles is None: geom_roles = extract_cell_geometry_stats(s) if target and target in roles: for k, v in roles[target].items(): feats[f"target_{k}"] = v for k, v in geom_roles.get(target, {}).items(): feats[f"target_{k}"] = v target_stats = roles[target] geom_stats = 
geom_roles.get(target, {}) serv_rsrp = safe_float(target_stats.get("bad_serv_rsrp_mean"), np.nan) nei_margin = safe_float(target_stats.get("bad_nei_margin_mean"), np.nan) ul_interf = safe_float(target_stats.get("traffic_ul_interf"), np.nan) feats["target_rsrp_deficit_to_m95"] = ( -95.0 - serv_rsrp if not np.isnan(serv_rsrp) else np.nan ) feats["target_neighbor_excess_margin"] = ( nei_margin - 2.0 if not np.isnan(nei_margin) else np.nan ) feats["target_ul_interference_over_m115"] = ( ul_interf + 115.0 if not np.isnan(ul_interf) else np.nan ) if not np.isnan(amount): if action in {"inc_power", "dec_power"}: current = safe_float(target_stats.get("cfg_power"), np.nan) delta = amount if action == "inc_power" else -amount proposed = current + delta if not np.isnan(current) else np.nan feats["target_proposed_power"] = proposed max_power = safe_float(target_stats.get("cfg_max_power"), np.nan) feats["target_power_margin_after"] = ( max_power - proposed if not np.isnan(max_power) and not np.isnan(proposed) else np.nan ) elif action in {"tilt_down", "tilt_up"}: current = safe_float(target_stats.get("cfg_downtilt"), np.nan) delta = amount if action == "tilt_down" else -amount proposed = current + delta if not np.isnan(current) else np.nan feats["target_proposed_downtilt"] = proposed desired = safe_float(geom_stats.get("geom_desired_downtilt_p50"), np.nan) if not np.isnan(proposed) and not np.isnan(desired): current_error = abs(current - desired) proposed_error = abs(proposed - desired) feats["target_proposed_tilt_error"] = proposed_error feats["target_tilt_error_improvement"] = ( current_error - proposed_error ) elif action == "azimuth": required = safe_float(geom_stats.get("geom_azimuth_error_p50"), np.nan) if not np.isnan(required): feats["target_required_azimuth_change"] = required feats["target_azimuth_change_abs_error"] = abs(amount - required) elif action in {"inc_a3", "dec_a3"}: current_half_db = safe_float(target_stats.get("cfg_a3_offset"), np.nan) current_db = ( 
current_half_db / 2.0 if not np.isnan(current_half_db) else np.nan ) delta = amount if action == "inc_a3" else -amount feats["target_current_a3_offset_db"] = current_db feats["target_proposed_a3_offset_db"] = ( current_db + delta if not np.isnan(current_db) else np.nan ) elif action == "a2a5": if "a2" in threshold_types: current = safe_float(target_stats.get("cfg_a2_threshold"), np.nan) feats["target_proposed_a2_threshold"] = ( current - amount if not np.isnan(current) else np.nan ) if "a5_1" in threshold_types: current = safe_float(target_stats.get("cfg_a5_threshold_1"), np.nan) feats["target_proposed_a5_threshold_1"] = ( current - amount if not np.isnan(current) else np.nan ) if "a5_2" in threshold_types: current = safe_float(target_stats.get("cfg_a5_threshold_2"), np.nan) feats["target_proposed_a5_threshold_2"] = ( current - amount if not np.isnan(current) else np.nan ) elif action == "pdcch": current = safe_float(target_stats.get("cfg_pdcch_symbols"), np.nan) feats["target_current_pdcch_symbols"] = current feats["target_pdcch_symbol_delta"] = ( pdcch_symbols - current if not np.isnan(pdcch_symbols) and not np.isnan(current) else np.nan ) else: feats["target_missing"] = 1.0 # Neighbor pair options: aggregate both cells. if len(targets) >= 2: vals = [] for key in targets: r = roles.get(key, {}) vals.append(r.get("bad_serv_frac", 0.0) + r.get("bad_nei_frac", 0.0)) feats["pair_role_sum"] = float(np.sum(vals)) feats["pair_role_max"] = float(np.max(vals)) if vals else 0.0 # Option id as weak positional feature. 
try: feats["cid_num"] = float(int(cid[1:])) except Exception: feats["cid_num"] = 0.0 return clean_features(feats) def build_prediction_context(s: Dict[str, Any]) -> Dict[str, Any]: opts = get_options(s) return { "scenario_features": extract_scenario_features(s), "cell_stats": extract_cell_stats(s), "cell_geometry_stats": extract_cell_geometry_stats(s), "is_multi": is_multi_task(s), "options": opts, "option_actions": {cid: label_to_action(label) for cid, label in opts.items()}, "option_semantics": { cid: parse_option_semantics(label) for cid, label in opts.items() }, } def predict_template_probs( model: Pipeline, s: Dict[str, Any], context: Optional[Dict[str, Any]] = None ) -> List[Tuple[str, float]]: context = context or {} x = context.get("scenario_features") if x is None: x = extract_scenario_features(s) clf = model.named_steps["clf"] proba = model.predict_proba([x])[0] pairs = [(str(cls), float(p)) for cls, p in zip(clf.classes_, proba)] pairs.sort(key=lambda z: z[1], reverse=True) return pairs def selector_score(model: Pipeline, feats: Dict[str, float]) -> float: try: proba = model.predict_proba([feats])[0] classes = list(model.named_steps["clf"].classes_) if 1 in classes: return float(proba[classes.index(1)]) return 0.0 except Exception: return 0.0 def selector_scores(model: Pipeline, features: List[Dict[str, float]]) -> List[float]: if not features: return [] try: proba = model.predict_proba(features) classes = list(model.named_steps["clf"].classes_) if 1 not in classes: return [0.0] * len(features) idx = classes.index(1) return [float(row[idx]) for row in proba] except Exception: return [0.0] * len(features) def jsonable_number(value: Any, digits: int = 3) -> Optional[float]: x = safe_float(value, np.nan) if np.isnan(x): return None return round(float(x), digits) def compact_cell_evidence(stats: Dict[str, float]) -> Dict[str, Optional[float]]: keys = [ "bad_serv_frac", "bad_nei_frac", "bad_serv_rsrp_mean", "bad_serv_sinr_mean", "bad_serv_thr_mean", 
"bad_nei_margin_mean", "mr_serv_count", "mr_nei_count", "mr_thr_mean", "cfg_power", "cfg_max_power", "cfg_downtilt", "cfg_a3_offset", "cfg_a2_threshold", "cfg_a5_threshold_1", "cfg_a5_threshold_2", "cfg_pdcch_symbols", "traffic_dl_prb", "traffic_weak_cov", "traffic_cce_dl", "traffic_cce_succ_dl", "geom_related_bad_count", "geom_distance_mean", "geom_azimuth_error_p50", "geom_desired_downtilt_p50", "geom_current_tilt_error_p50", ] out: Dict[str, Optional[float]] = {} for key in keys: if key in stats: value = jsonable_number(stats.get(key)) if value is not None: out[key] = value return out def proposed_change_summary( action: str, semantics: Dict[str, Any], target_stats: Dict[str, float] ) -> Dict[str, Optional[float]]: amount = safe_float(semantics.get("amount"), np.nan) if np.isnan(amount): return {} out: Dict[str, Optional[float]] = {} if action in {"inc_power", "dec_power"}: current = safe_float(target_stats.get("cfg_power"), np.nan) max_power = safe_float(target_stats.get("cfg_max_power"), np.nan) delta = amount if action == "inc_power" else -amount proposed = current + delta if not np.isnan(current) else np.nan out["current_power"] = jsonable_number(current) out["proposed_power"] = jsonable_number(proposed) out["power_margin_after"] = ( jsonable_number(max_power - proposed) if not np.isnan(max_power) and not np.isnan(proposed) else None ) elif action in {"tilt_down", "tilt_up"}: current = safe_float(target_stats.get("cfg_downtilt"), np.nan) delta = amount if action == "tilt_down" else -amount desired = safe_float(target_stats.get("geom_desired_downtilt_p50"), np.nan) proposed = current + delta if not np.isnan(current) else np.nan out["current_downtilt"] = jsonable_number(current) out["proposed_downtilt"] = jsonable_number(proposed) out["desired_downtilt_p50"] = jsonable_number(desired) if not np.isnan(current) and not np.isnan(proposed) and not np.isnan(desired): out["tilt_error_improvement"] = jsonable_number( abs(current - desired) - abs(proposed - desired) ) 
elif action == "azimuth": required = safe_float(target_stats.get("geom_azimuth_error_p50"), np.nan) out["required_azimuth_change_p50"] = jsonable_number(required) out["azimuth_change_abs_error"] = ( jsonable_number(abs(amount - required)) if not np.isnan(required) else None ) elif action in {"inc_a3", "dec_a3"}: current_half_db = safe_float(target_stats.get("cfg_a3_offset"), np.nan) current_db = current_half_db / 2.0 if not np.isnan(current_half_db) else np.nan delta = amount if action == "inc_a3" else -amount out["current_a3_offset_db"] = jsonable_number(current_db) out["proposed_a3_offset_db"] = ( jsonable_number(current_db + delta) if not np.isnan(current_db) else None ) elif action == "a2a5": threshold_types = set(semantics.get("threshold_types") or []) if "a2" in threshold_types: current = safe_float(target_stats.get("cfg_a2_threshold"), np.nan) out["current_a2_threshold"] = jsonable_number(current) out["proposed_a2_threshold"] = ( jsonable_number(current - amount) if not np.isnan(current) else None ) if "a5_1" in threshold_types: current = safe_float(target_stats.get("cfg_a5_threshold_1"), np.nan) out["current_a5_threshold_1"] = jsonable_number(current) out["proposed_a5_threshold_1"] = ( jsonable_number(current - amount) if not np.isnan(current) else None ) if "a5_2" in threshold_types: current = safe_float(target_stats.get("cfg_a5_threshold_2"), np.nan) out["current_a5_threshold_2"] = jsonable_number(current) out["proposed_a5_threshold_2"] = ( jsonable_number(current - amount) if not np.isnan(current) else None ) elif action == "pdcch": current = safe_float(target_stats.get("cfg_pdcch_symbols"), np.nan) proposed = safe_float(semantics.get("pdcch_symbols"), np.nan) out["current_pdcch_symbols"] = jsonable_number(current) out["proposed_pdcch_symbols"] = jsonable_number(proposed) out["pdcch_symbol_delta"] = ( jsonable_number(proposed - current) if not np.isnan(proposed) and not np.isnan(current) else None ) return {k: v for k, v in out.items() if v is not None} 
def _option_sort_key(cid: str) -> Tuple[Any, ...]:
    """Order option ids by their numeric suffix (``"C10"`` after ``"C2"``).

    Ids whose suffix is not an integer sort after all numeric ids, by text,
    instead of raising ``ValueError``.
    """
    try:
        return (0, int(cid[1:]))
    except (TypeError, ValueError):
        return (1, str(cid))


def build_option_evidence(
    selector_model: Pipeline,
    s: Dict[str, Any],
    context: Dict[str, Any],
    top_templates: List[Tuple[str, float]],
    ml_labels: List[str],
) -> Dict[str, Any]:
    """Assemble per-option evidence for downstream review or reasoning.

    Every non-``other`` option is scored by the selector under each template
    in ``top_templates``.

    Returns a dict with three entries:

    - ``task_type``.
    - ``candidate_options``: one row per option, ordered by id. Each row holds
      the best selector score and the top-3 per-template scores, whether the
      option is in ``ml_labels``, compact target-cell evidence and a summary
      of the proposed change.
    - ``ambiguity_groups``: options that share an action and cell list,
      together with the fields that tell them apart.

    ``context`` is expected to come from ``build_prediction_context``; the
    ``options`` and ``option_actions`` keys are required.
    """
    opts = context["options"]
    option_actions = context["option_actions"]
    option_semantics = context.get("option_semantics") or {}
    roles = context.get("cell_stats") or {}
    geom_roles = context.get("cell_geometry_stats") or {}
    ml_label_set = set(ml_labels)

    # Score every (template, action, option) triple in one batch.
    feature_rows: List[Dict[str, float]] = []
    feature_keys: List[Tuple[str, str, str]] = []
    seen_keys = set()
    for tpl, _ in top_templates:
        for cid, label in opts.items():
            action = option_actions.get(cid, "other")
            if action == "other":
                continue
            key = (tpl, action, cid)
            if key in seen_keys:
                continue
            seen_keys.add(key)
            feature_keys.append(key)
            feature_rows.append(option_feature_dict(s, cid, label, tpl, action, context))

    score_by_option: Dict[str, float] = defaultdict(float)
    score_detail: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
    for (tpl, action, cid), score in zip(
        feature_keys, selector_scores(selector_model, feature_rows)
    ):
        score_by_option[cid] = max(score_by_option.get(cid, 0.0), float(score))
        score_detail[cid].append(
            {
                "template": tpl,
                "action": action,
                "selector_score": round(float(score), 5),
            }
        )

    rows = []
    for cid, label in sorted(opts.items(), key=lambda x: _option_sort_key(x[0])):
        action = option_actions.get(cid, "other")
        semantics = option_semantics.get(cid) or parse_option_semantics(label)
        target = semantics.get("target_cell")
        target_stats = dict(roles.get(target, {}) if target else {})
        if target:
            target_stats.update(geom_roles.get(target, {}))
        row = {
            "id": cid,
            "label": label,
            "action": action,
            "target_cell": target,
            "cells": semantics.get("cells") or [],
            "amount": jsonable_number(semantics.get("amount")),
            "unit": semantics.get("unit") or "",
            "threshold_type": semantics.get("threshold_type") or "",
            "pdcch_symbols": jsonable_number(semantics.get("pdcch_symbols")),
            "towards_ue": bool(semantics.get("towards_ue")),
            "selector_score": round(float(score_by_option.get(cid, 0.0)), 5),
            "in_ml_prediction": cid in ml_label_set,
            "target_evidence": compact_cell_evidence(target_stats),
            "proposed_change": proposed_change_summary(action, semantics, target_stats),
        }
        details = score_detail.get(cid, [])
        if details:
            row["score_detail"] = sorted(
                details, key=lambda d: d["selector_score"], reverse=True
            )[:3]
        rows.append(row)

    # Options that target the same cells with the same action differ only in
    # their parameters; report what distinguishes them.
    ambiguity_groups = []
    grouped: Dict[Tuple[str, Tuple[str, ...]], List[Dict[str, Any]]] = defaultdict(list)
    for row in rows:
        cells = tuple(row.get("cells") or [])
        if not cells or row["action"] in {"other", "server", "insufficient"}:
            continue
        grouped[(row["action"], cells)].append(row)
    for (action, cells), group_rows in grouped.items():
        if len(group_rows) < 2:
            continue
        amounts = {r.get("amount") for r in group_rows}
        threshold_types = {r.get("threshold_type") for r in group_rows}
        pdcch_symbols = {r.get("pdcch_symbols") for r in group_rows}
        distinguishes = []
        if len(amounts) > 1:
            distinguishes.append("amount")
        if len(threshold_types) > 1:
            distinguishes.append("threshold_type")
        if len(pdcch_symbols) > 1:
            distinguishes.append("pdcch_symbols")
        ambiguity_groups.append(
            {
                "action": action,
                "cells": list(cells),
                "candidate_ids": [r["id"] for r in group_rows],
                "distinguishes": distinguishes,
                "contains_ml_prediction": any(
                    r["id"] in ml_label_set for r in group_rows
                ),
            }
        )
    return {
        "task_type": "multiple-answer" if context.get("is_multi") else "single-answer",
        "candidate_options": rows,
        "ambiguity_groups": ambiguity_groups,
    }


def option_diversity_key(row: Dict[str, Any]) -> Tuple[Any, ...]:
    """Key used to avoid picking near-duplicate options.

    The key is the action, the first cell, and the threshold type.
    """
    semantics = row.get("semantics") or {}
    cells = tuple(semantics.get("cells") or [])
    return (
        row.get("action"),
        cells[:1],
        semantics.get("threshold_type") or "",
    )


def select_ranked_options(
    ranked: List[Dict[str, Any]], template: str, is_multi: bool
) -> List[str]:
    """Pick option ids from scored rows according to a template.

    For single-answer tasks, returns the single top-scored id.

    For multi-answer tasks:

    1. Take the best options for each action the template requires, as many
       per action as the template lists it.
    2. Top up to the template's action count (2 when that count is outside
       2..4) with the best remaining options. The first pass prefers new
       diversity keys; the second allows duplicates. ``server``,
       ``insufficient`` and ``other`` options are skipped while enough
       alternatives exist.

    At most 4 ids are returned.
    """
    # NOTE(review): with reverse=True, score ties are broken by the *larger*
    # cid string; confirm this tie order is intended.
    ranked = sorted(ranked, key=lambda r: (float(r.get("score", 0.0)), r["cid"]), reverse=True)
    if not ranked:
        return []
    if not is_multi:
        return [ranked[0]["cid"]]

    actions = template_to_actions(template)
    target_count = len(actions) if 2 <= len(actions) <= 4 else 2
    action_counts = Counter(actions)
    selected: List[str] = []
    used = set()
    used_keys = set()
    for action, n_needed in action_counts.items():
        action_rows = [r for r in ranked if r["action"] == action and r["cid"] not in used]
        for row in action_rows[:n_needed]:
            selected.append(row["cid"])
            used.add(row["cid"])
            used_keys.add(option_diversity_key(row))

    for allow_duplicate in (False, True):
        for row in ranked:
            if len(selected) >= target_count:
                break
            if row["cid"] in used:
                continue
            if row["action"] in {"server", "insufficient", "other"} and len(ranked) > target_count:
                continue
            key = option_diversity_key(row)
            if not allow_duplicate and key in used_keys:
                continue
            selected.append(row["cid"])
            used.add(row["cid"])
            used_keys.add(key)
        if len(selected) >= target_count:
            break
    return selected[:4]


def _selector_scores_or_rowwise(
    selector: Pipeline, feature_rows: List[Dict[str, float]]
) -> List[float]:
    """Score rows with one ``predict_proba`` call, falling back to row-wise scoring.

    If the batch call fails, each row is scored with ``selector_score``. A row
    the model cannot score then gets ``0.0`` without zeroing its neighbours,
    exactly as per-row scoring behaves.
    """
    if not feature_rows:
        return []
    try:
        proba = selector.predict_proba(feature_rows)
        classes = list(selector.named_steps["clf"].classes_)
        if 1 not in classes:
            return [0.0] * len(feature_rows)
        idx = classes.index(1)
        return [float(row[idx]) for row in proba]
    except Exception:
        return [selector_score(selector, feats) for feats in feature_rows]


def ranked_options_with_selector(
    selector: Pipeline,
    s: Dict[str, Any],
    template: str,
    context: Dict[str, Any],
) -> List[Dict[str, Any]]:
    """Score every non-``other`` option of ``s`` under ``template``.

    Returns rows with ``cid``, ``action``, ``score`` (selector probability of
    class 1) and ``semantics``, in option order. Options are scored in a single
    batch rather than one ``predict_proba`` call each.
    """
    opts = context.get("options") or get_options(s)
    option_actions = context.get("option_actions") or {
        cid: label_to_action(label) for cid, label in opts.items()
    }
    option_semantics = context.get("option_semantics") or {}

    candidates: List[Tuple[str, str, str]] = []
    feature_rows: List[Dict[str, float]] = []
    for cid, label in opts.items():
        action = option_actions.get(cid, "other")
        if action == "other":
            continue
        candidates.append((cid, label, action))
        feature_rows.append(option_feature_dict(s, cid, label, template, action, context))

    scores = _selector_scores_or_rowwise(selector, feature_rows)
    return [
        {
            "cid": cid,
            "action": action,
            "score": score,
            "semantics": option_semantics.get(cid) or parse_option_semantics(label),
        }
        for (cid, label, action), score in zip(candidates, scores)
    ]


def ranked_options_from_scores(
    context: Dict[str, Any],
    template: str,
    score_lookup: Dict[Tuple[str, str, str], float],
) -> List[Dict[str, Any]]:
    """Like ``ranked_options_with_selector`` but reads precomputed scores.

    Scores come from ``score_lookup[(template, action, cid)]``; missing keys
    score ``0.0``.
    """
    opts = context.get("options") or {}
    option_actions = context.get("option_actions") or {}
    option_semantics = context.get("option_semantics") or {}
    ranked = []
    for cid, label in opts.items():
        action = option_actions.get(cid, "other")
        if action == "other":
            continue
        ranked.append(
            {
                "cid": cid,
                "action": action,
                "score": score_lookup.get((template, action, cid), 0.0),
                "semantics": option_semantics.get(cid) or parse_option_semantics(label),
            }
        )
    return ranked


def selected_rank_score(labels: List[str], ranked: List[Dict[str, Any]]) -> float:
    """Quality of a selection: mean score plus 0.25 x the weakest score.

    Returns ``-1.0`` for an empty selection. Ids absent from ``ranked`` count
    as ``0.0``.
    """
    if not labels:
        return -1.0
    scores = {row["cid"]: float(row.get("score", 0.0)) for row in ranked}
    vals = [scores.get(cid, 0.0) for cid in labels]
    return float(np.mean(vals) + 0.25 * np.min(vals))


def choose_options_with_selector(
    selector: Pipeline,
    s: Dict[str, Any],
    template: str,
    context: Optional[Dict[str, Any]] = None,
) -> List[str]:
    """Rank options with the selector and pick ids for ``template``."""
    context = context or {}
    is_multi = context.get("is_multi")
    if is_multi is None:
        is_multi = is_multi_task(s)
    ranked = ranked_options_with_selector(selector, s, template, context)
    return select_ranked_options(ranked, template, bool(is_multi))


def choose_options_from_scores(
    s: Dict[str, Any],
    template: str,
    context: Dict[str, Any],
    score_lookup: Dict[Tuple[str, str, str], float],
) -> List[str]:
    """Pick ids for ``template`` using precomputed selector scores.

    ``is_multi`` is taken from ``context``. When it is missing or ``None``, it
    is derived from ``s``, matching ``choose_options_with_selector``.
    """
    ranked = ranked_options_from_scores(context, template, score_lookup)
    is_multi = context.get("is_multi")
    if is_multi is None:
        is_multi = is_multi_task(s)
    return select_ranked_options(ranked, template, bool(is_multi))


def _best_template_candidate(
    probs: List[Tuple[str, float]], rank_fn, is_multi: Any
) -> Optional[Tuple[float, float, str, List[str]]]:
    """Pick the best template/selection pair.

    ``rank_fn(template)`` must return ranked option rows. Selections whose size
    is invalid for the task type are discarded: multi-answer tasks need 2-4
    ids, single-answer tasks exactly one. Survivors are ordered by
    ``rank_score + 0.05 * template_prob``; ties keep the order of ``probs``.

    Returns ``(rank_score, template_prob, template, labels)`` for the best
    survivor, or ``None`` when there is none.
    """
    candidates = []
    for tpl, p in probs:
        ranked = rank_fn(tpl)
        labs = select_ranked_options(ranked, tpl, bool(is_multi))
        if not labs:
            continue
        if is_multi and not (2 <= len(labs) <= 4):
            continue
        if (not is_multi) and len(labs) != 1:
            continue
        rank_score = selected_rank_score(labs, ranked)
        candidates.append((rank_score + 0.05 * p, rank_score, p, tpl, labs))
    if not candidates:
        return None
    candidates.sort(reverse=True, key=lambda x: x[0])
    _, rank_score, p, tpl, labs = candidates[0]
    return rank_score, p, tpl, labs


def predict_labels(
    template_model: Pipeline, selector_model: Pipeline, s: Dict[str, Any]
) -> Tuple[List[str], Dict[str, Any]]:
    """Predict option ids for one scenario.

    The top 5 templates are each tried with selector ranking, and the best
    valid selection is kept. If none is valid, the top template's selection is
    used with ``selector_rank_score = -1.0``.

    Returns ``(labels, info)``, where ``info`` holds the chosen template, its
    probability, the rank score and the top-5 templates.
    """
    context = build_prediction_context(s)
    probs = predict_template_probs(template_model, s, context)
    best = _best_template_candidate(
        probs[:5],
        lambda tpl: ranked_options_with_selector(selector_model, s, tpl, context),
        context["is_multi"],
    )
    if best is not None:
        rank_score, p, tpl, labs = best
    elif probs:
        tpl, p = probs[0]
        labs = choose_options_with_selector(selector_model, s, tpl, context)
        rank_score = -1.0
    else:
        # No template classes at all: nothing to select.
        tpl, p, labs, rank_score = "", 0.0, [], -1.0
    return labs, {
        "template": tpl,
        "template_prob": p,
        "selector_rank_score": rank_score,
        "top_templates": probs[:5],
    }


def predict_labels_batch(
    template_model: Pipeline,
    selector_model: Pipeline,
    scenarios: List[Dict[str, Any]],
    top_k_templates: int = 5,
) -> List[Tuple[List[str], Dict[str, Any]]]:
    """Batched ``predict_labels``: one template and one selector ``predict_proba`` call.

    All (scenario, template, action, option) rows for the top
    ``top_k_templates`` templates are scored together. Each scenario then
    chooses its best template exactly as ``predict_labels`` does. If
    ``top_k_templates`` leaves no templates, the scenario gets an empty
    prediction instead of raising.
    """
    if not scenarios:
        return []
    contexts = [build_prediction_context(s) for s in scenarios]
    scenario_features = [ctx["scenario_features"] for ctx in contexts]
    template_proba = template_model.predict_proba(scenario_features)
    template_classes = [str(cls) for cls in template_model.named_steps["clf"].classes_]

    all_probs: List[List[Tuple[str, float]]] = []
    feature_rows: List[Dict[str, float]] = []
    feature_keys: List[Tuple[int, str, str, str]] = []
    seen_keys = set()
    for i, (s, ctx, row) in enumerate(zip(scenarios, contexts, template_proba)):
        probs = [(cls, float(p)) for cls, p in zip(template_classes, row)]
        probs.sort(key=lambda z: z[1], reverse=True)
        top_probs = probs[:top_k_templates]
        all_probs.append(top_probs)
        opts = ctx["options"]
        option_actions = ctx["option_actions"]
        for tpl, _ in top_probs:
            for cid, label in opts.items():
                action = option_actions.get(cid, "other")
                if action == "other":
                    continue
                key = (i, tpl, action, cid)
                if key in seen_keys:
                    continue
                seen_keys.add(key)
                feature_keys.append(key)
                feature_rows.append(option_feature_dict(s, cid, label, tpl, action, ctx))

    scores = selector_scores(selector_model, feature_rows)
    score_lookups: List[Dict[Tuple[str, str, str], float]] = [
        {} for _ in scenarios
    ]
    for (i, tpl, action, cid), score in zip(feature_keys, scores):
        score_lookups[i][(tpl, action, cid)] = score

    out: List[Tuple[List[str], Dict[str, Any]]] = []
    for s, ctx, probs, lookup in zip(scenarios, contexts, all_probs, score_lookups):
        best = _best_template_candidate(
            probs,
            lambda tpl, _ctx=ctx, _lookup=lookup: ranked_options_from_scores(
                _ctx, tpl, _lookup
            ),
            ctx["is_multi"],
        )
        if best is not None:
            rank_score, p, tpl, labs = best
        elif probs:
            tpl, p = probs[0]
            labs = choose_options_from_scores(s, tpl, ctx, lookup)
            rank_score = -1.0
        else:
            tpl, p, labs, rank_score = "", 0.0, [], -1.0
        out.append(
            (
                labs,
                {
                    "template": tpl,
                    "template_prob": p,
                    "selector_rank_score": rank_score,
                    "top_templates": probs,
                },
            )
        )
    return out


def write_submission(path: str, rows: List[Dict[str, str]]):
    """Write submission rows to CSV with columns ``ID``, ``Track A``, ``Track B``.

    Parent directories are created as needed.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w", encoding="utf-8", newline="") as f:
        w = csv.DictWriter(f, fieldnames=["ID", "Track A", "Track B"])
        w.writeheader()
        for r in rows:
            w.writerow(r)


def result_csv_path(path: str) -> str:
    """Force the output file name to ``result.csv``, keeping the directory.

    An empty ``path`` maps to ``"result.csv"``. A rename is announced on
    stdout.
    """
    if not path:
        return "result.csv"
    directory = os.path.dirname(path)
    basename = os.path.basename(path)
    if basename.lower() == "result.csv":
        return path
    fixed = os.path.join(directory or ".", "result.csv")
    print(f"Output filename must be result.csv; writing to {fixed} instead of {path}")
    return fixed


def write_debug(path: str, rows: List[Dict[str, Any]]):
    """Dump debug rows as indented JSON (ASCII-escaped); a no-op for an empty path."""
    if not path:
        return
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(rows, f, indent=2)


def write_json(path: str, rows: List[Dict[str, Any]]):
    """Dump rows as indented UTF-8 JSON (non-ASCII kept); a no-op for an empty path."""
    if not path:
        return
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(rows, f, indent=2, ensure_ascii=False)


def save_model_bundle(
    path: str,
    template_model: Pipeline,
    selector_model: Pipeline,
    metadata: Optional[Dict[str, Any]] = None,
):
    """Pickle both pipelines and metadata, tagged with ``MODEL_BUNDLE_VERSION``."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    bundle = {
        "version": MODEL_BUNDLE_VERSION,
        "template_model": template_model,
        "selector_model": selector_model,
        "metadata": metadata or {},
    }
    with open(path, "wb") as f:
        pickle.dump(bundle, f, protocol=pickle.HIGHEST_PROTOCOL)


def load_model_bundle(path: str) -> Tuple[Pipeline, Pipeline, Dict[str, Any]]:
    """Load a bundle written by ``save_model_bundle``.

    Returns ``(template_model, selector_model, metadata)``.

    Raises:
        ValueError: if the file does not contain a dict bundle or its version
            differs from ``MODEL_BUNDLE_VERSION``.

    SECURITY: ``pickle.load`` can execute arbitrary code. Only load bundles
    from trusted sources, such as files produced by this project's training.
    """
    with open(path, "rb") as f:
        bundle = pickle.load(f)
    if not isinstance(bundle, dict):
        raise ValueError(f"Unsupported model bundle format: {path}")
    if bundle.get("version") != MODEL_BUNDLE_VERSION:
        raise ValueError(
            f"Unsupported model bundle version {bundle.get('version')} in {path}"
        )
    return (
        bundle["template_model"],
        bundle["selector_model"],
        dict(bundle.get("metadata") or {}),
    )