datadb's picture
upload track a submission package
80dcfe9 verified
Raw
History Blame Contribute Delete
90.7 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
src.model_core
Shared model/runtime utilities for Track A.
This module intentionally does not own the CLI. It provides the reusable pieces
used by train.py and main.py:
- JSON/env loading and result/debug/traces writers.
- Server data hydration helpers and trace formatting.
- Question/option parsing and answer normalization.
- Telemetry feature extraction for scenarios and candidate options.
- Prediction-time template scoring, option selection, and model bundle IO.
The modelling approach uses two LightGBM components:
1. A template classifier:
scenario-level telemetry features -> answer action template.
2. An option selector:
scenario, option, target-cell stats, and template/action features ->
whether a candidate option should be selected.
The estimator fitting code belongs in train.py. The executable inference flow
belongs in main.py.
"""
import csv
import io
import json
import math
import os
import pickle
import re
import time
import warnings
from collections import Counter, defaultdict
from typing import Any, Dict, List, Tuple, Optional
# Silence the specific sklearn UserWarning about
# `sklearn.utils.parallel.delayed`. It is installed before the third-party
# imports below, presumably so warnings raised at import time are covered too.
# NOTE(review): the blanket filter a few lines down makes this targeted
# filter redundant.
warnings.filterwarnings(
    "ignore",
    message=r"`sklearn\.utils\.parallel\.delayed` should be used.*",
    category=UserWarning,
)
import httpx
import numpy as np
import pandas as pd
from openai import OpenAI
from sklearn.pipeline import Pipeline

# NOTE(review): this suppresses *all* warnings for the whole process as a side
# effect of importing this module, including pandas/numpy warnings raised by
# the feature code below. Consider scoping it with `warnings.catch_warnings()`.
warnings.filterwarnings("ignore")

# Version of the model-bundle format; presumably written/checked by the bundle
# save/load helpers mentioned in the module docstring (not visible here).
MODEL_BUNDLE_VERSION = 1
# =============================================================================
# General utilities
# =============================================================================
def load_json(path: str) -> List[Dict[str, Any]]:
    """Load a list of records from a UTF-8 JSON file.

    The file may hold either a top-level list, or an object that wraps the
    list under one of the keys ``data``, ``records`` or ``scenarios``
    (checked in that order; the first key whose value is a list wins).

    Raises:
        ValueError: if no list of records can be located.
    """
    with open(path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)
    if isinstance(payload, list):
        return payload
    if isinstance(payload, dict):
        wrapped = next(
            (
                payload[key]
                for key in ("data", "records", "scenarios")
                if isinstance(payload.get(key), list)
            ),
            None,
        )
        if wrapped is not None:
            return wrapped
    raise ValueError(f"Unsupported JSON structure: {path}")
def norm(x: Any) -> str:
    """Return ``x`` as trimmed, lower-cased text with whitespace runs collapsed to one space."""
    text = str(x).strip().lower()
    return re.sub(r"\s+", " ", text)
def load_env_file(path: Optional[str] = None):
    """Load the first existing ``.env`` file into ``os.environ``.

    Candidates are tried in order:
    1. the explicit ``path`` (if given),
    2. ``.env`` in the current working directory,
    3. ``.env`` next to this module.

    Variables that are already set are never overwritten. ``python-dotenv``
    does the parsing when it can be imported and succeeds. Otherwise a
    minimal ``KEY=VALUE`` parser is used: blank lines, ``#`` comments and
    lines without ``=`` are skipped, and surrounding quotes are stripped.

    Returns:
        The path that was loaded, or ``None`` when no candidate exists.
    """
    search_paths = ([path] if path else []) + [
        os.path.join(os.getcwd(), ".env"),
        os.path.join(os.path.dirname(os.path.abspath(__file__)), ".env"),
    ]
    for env_path in search_paths:
        if not env_path or not os.path.exists(env_path):
            continue
        try:
            from dotenv import load_dotenv

            load_dotenv(env_path, override=False)
        except Exception:
            with open(env_path, "r", encoding="utf-8") as env_file:
                for raw_line in env_file:
                    entry = raw_line.strip()
                    if not entry or entry.startswith("#") or "=" not in entry:
                        continue
                    name, _, raw_value = entry.partition("=")
                    os.environ.setdefault(
                        name.strip(), raw_value.strip().strip("'\"")
                    )
        return env_path
    return None
def is_placeholder_data(value: Any) -> bool:
    """Return True when ``value`` is a string telling the caller to "use the API"."""
    if not isinstance(value, str):
        return False
    return "use the api" in norm(value)
def scenario_needs_server_data(s: Dict[str, Any]) -> bool:
    """Return True if any section in ``s["data"]`` is still a server placeholder."""
    sections = s.get("data", {}) or {}
    for section in sections.values():
        if is_placeholder_data(section):
            return True
    return False
def compact_result_for_trace(result: Any, max_value_chars: int = 220) -> Any:
    """Shrink a tool result so it can be embedded in a one-line trace.

    Rules, applied recursively:
    - A string that contains a newline or is longer than ``max_value_chars``
      becomes ``{"chars": <length>, "preview": <first max_value_chars chars>}``.
      Shorter single-line strings are kept.
    - A list or tuple becomes ``{"items": <length>, "preview": [...]}`` holding
      its first two items, each compacted.
    - A dict keeps its keys and has each value compacted.
    - Anything else is returned unchanged.

    Because dict values go through the same rules, a list stored under a key
    gets a single ``items``/``preview`` wrapper, exactly like a top-level list.
    """
    if isinstance(result, str):
        if "\n" in result or len(result) > max_value_chars:
            return {"chars": len(result), "preview": result[:max_value_chars]}
        return result
    if isinstance(result, dict):
        return {
            k: compact_result_for_trace(v, max_value_chars) for k, v in result.items()
        }
    if isinstance(result, (list, tuple)):
        return {
            "items": len(result),
            "preview": [
                compact_result_for_trace(item, max_value_chars) for item in result[:2]
            ],
        }
    return result
def format_tool_call(name: str, args: Dict[str, Any], result: Any) -> str:
    """Render one tool invocation as a single trace line.

    The arguments are serialised in full. The result is first shrunk with
    ``compact_result_for_trace``. Non-ASCII characters are kept verbatim.
    """
    rendered_args = json.dumps(args, ensure_ascii=False)
    compacted = compact_result_for_trace(result)
    rendered_result = json.dumps(compacted, ensure_ascii=False)
    return f"Function: {name}, Arguments: {rendered_args}, Results: {rendered_result}"
def safe_float(x: Any, default=np.nan) -> float:
    """Convert ``x`` to a finite float, falling back to ``default``.

    The following all map to ``default``:
    - ``None``, ``""`` and ``"-"`` (the dash is also treated as missing by
      ``numeric_series``);
    - values ``float()`` cannot parse;
    - NaN and +/-inf.

    No exception escapes.
    """
    try:
        missing = x is None or x == "" or x == "-"
        if missing:
            return default
        value = float(x)
    except Exception:
        return default
    return default if (math.isnan(value) or math.isinf(value)) else value
def read_pipe_csv(text: Any) -> pd.DataFrame:
    """Parse a ``|``-separated table held in a string.

    Non-strings and blank strings give an empty DataFrame. If pandas cannot
    parse the text, one retry is made with CRLF line endings converted to LF.
    If that also fails, an empty DataFrame is returned.
    """
    if not (isinstance(text, str) and text.strip()):
        return pd.DataFrame()
    for candidate in (text, text.replace("\r\n", "\n")):
        try:
            return pd.read_csv(io.StringIO(candidate), sep="|")
        except Exception:
            continue
    return pd.DataFrame()
def numeric_series(df: pd.DataFrame, col: str) -> pd.Series:
    """Return column ``col`` of ``df`` as numbers.

    Both ``"-"`` and unparsable entries become NaN. An empty float Series is
    returned when ``df`` is None or empty, or lacks the column.
    """
    has_column = df is not None and not df.empty and col in df.columns
    if not has_column:
        return pd.Series(dtype=float)
    raw = df[col].replace("-", np.nan)
    return pd.to_numeric(raw, errors="coerce")
def stat_feats(prefix: str, s: pd.Series) -> Dict[str, float]:
    """Summary statistics of a numeric series, keyed ``{prefix}_{stat}``.

    Stats, in key order: mean, min, max, population std (ddof=0), and the
    25th/50th/75th percentiles. Values are coerced to numbers and NaNs are
    dropped first. If nothing remains, every stat is NaN.
    """
    values = pd.to_numeric(s, errors="coerce").dropna()
    names = ("mean", "min", "max", "std", "p25", "p50", "p75")
    if len(values) == 0:
        return {f"{prefix}_{name}": np.nan for name in names}
    computed = (
        float(values.mean()),
        float(values.min()),
        float(values.max()),
        float(values.std(ddof=0)),
        float(np.percentile(values, 25)),
        float(np.percentile(values, 50)),
        float(np.percentile(values, 75)),
    )
    return {f"{prefix}_{name}": stat for name, stat in zip(names, computed)}
def clean_features(d: Dict[str, Any]) -> Dict[str, float]:
    """Coerce every feature value to a model-safe float.

    NaN and +/-inf become the sentinel ``-999.0``. Values that ``float()``
    rejects become ``0.0``. Key order is preserved.
    """

    def _coerce(value: Any) -> float:
        try:
            number = float(value)
        except Exception:
            return 0.0
        if math.isnan(number) or math.isinf(number):
            return -999.0
        return number

    return {key: _coerce(value) for key, value in d.items()}
def parse_answer(ans: Any) -> List[str]:
    """Extract every ``C<number>`` label from an answer string, in order.

    Non-strings give ``[]``. So do the placeholder answers "", "to be
    determined", "none" and "nan" (compared after trimming and lower-casing).
    """
    placeholders = {"", "to be determined", "none", "nan"}
    if isinstance(ans, str) and ans.strip().lower() not in placeholders:
        return re.findall(r"C\d+", ans)
    return []
def iou_score(pred: List[str], truth: List[str]) -> float:
    """Jaccard overlap (intersection over union) of two label collections.

    Returns 1.0 when both are empty and 0.0 when exactly one is empty.
    """
    predicted, expected = set(pred), set(truth)
    union = predicted | expected
    if not union:
        return 1.0
    if not predicted or not expected:
        return 0.0
    return len(predicted & expected) / len(union)
def is_multi_task(s: Dict[str, Any]) -> bool:
    """Heuristically decide whether scenario ``s`` expects several answers.

    Returns True when either of these holds:
    - the normalised ``tag`` contains "multiple";
    - the normalised task description contains "two to four", "2-4" or
      "2 to 4".

    A missing or ``None`` ``task``, ``tag`` or description counts as empty.
    Previously a ``None`` task raised AttributeError.
    """
    tag = norm(s.get("tag") or "")
    desc = norm((s.get("task") or {}).get("description") or "")
    return "multiple" in tag or any(
        phrase in desc for phrase in ("two to four", "2-4", "2 to 4")
    )
def get_options(s: Dict[str, Any]) -> Dict[str, str]:
    """Return the scenario's answer options as ``{option_id: label}``.

    Two option formats are accepted:
    - a list of ``{"id": ..., "label": ...}`` dicts, where entries without a
      truthy ``id`` are skipped and a missing label becomes "";
    - a dict keyed by ``C<number>``, where other keys are skipped.

    A missing or ``None`` ``task``/``options`` yields ``{}``. This matches the
    ``or {}`` guard used by ``build_question_text``.
    """
    opts = (s.get("task") or {}).get("options") or []
    out: Dict[str, str] = {}
    if isinstance(opts, list):
        for o in opts:
            if isinstance(o, dict) and o.get("id"):
                out[str(o["id"])] = str(o.get("label", ""))
    elif isinstance(opts, dict):
        for k, v in opts.items():
            if re.match(r"^C\d+$", str(k)):
                out[str(k)] = str(v)
    return out
def build_question_text(s: Dict[str, Any]) -> str:
    """Render the task description followed by one ``"<id>: <label>"`` line per option.

    Options are ordered by the integer that follows the first character of
    their id (``C2`` before ``C10``). List-style option ids are not validated
    by ``get_options``. Ids whose tail is not an integer are therefore placed
    after the numeric ones in their original order, instead of aborting the
    sort with ValueError. Empty parts are omitted.
    """
    task = s.get("task", {}) or {}
    parts = [str(task.get("description", "")).strip()]

    def _option_order(item: Tuple[str, str]) -> Tuple[int, int]:
        try:
            return (0, int(item[0][1:]))
        except ValueError:
            return (1, 0)  # sorted() is stable, so these keep their order

    for cid, label in sorted(get_options(s).items(), key=_option_order):
        parts.append(f"{cid}: {label}")
    return "\n".join([p for p in parts if p])
def make_completion(labels: List[str], dbg: Dict[str, Any], s: Dict[str, Any]) -> str:
    """Render the final answer text for ``labels``.

    The text has three parts:
    - an **Evidence** section with the model's template signals (the template,
      its probability and up to five top templates, probabilities rounded to
      4 decimals);
    - a **Recommendations** list of the selected options with their labels;
    - a closing ``\\boxed{...}`` holding the ``|``-joined labels.

    A ``template_prob`` or ``top_templates`` entry that is missing or ``None``
    in ``dbg`` is treated as 0.0 or empty. This function also runs on
    fallback paths, where a ``None`` probability used to raise TypeError
    from ``float(None)``.
    """
    opts = get_options(s)
    template_prob = dbg.get("template_prob")
    if template_prob is None:
        template_prob = 0.0
    top_templates = dbg.get("top_templates") or []
    scores = {
        "template": dbg.get("template"),
        "template_prob": round(float(template_prob), 4),
        "top_templates": [
            [tpl, round(float(prob), 4)] for tpl, prob in top_templates[:5]
        ],
    }
    recs = "\n".join(f"- {cid}: {opts.get(cid, '')}" for cid in labels)
    if not recs:
        recs = "- No valid option selected"
    boxed = "|".join(labels)
    return (
        "**Evidence:**\n"
        f"- Derived model signals: {json.dumps(scores, ensure_ascii=False)}\n\n"
        "**Recommendations:**\n"
        f"{recs}\n\n"
        f"\\boxed{{{boxed}}}"
    )
def normalize_prediction_labels(labels: List[str]) -> List[str]:
    """Keep well-formed ``C<number>`` labels, de-duplicated and sorted numerically.

    Each entry is trimmed first. Non-strings and strings that are not exactly
    ``C<digits>`` are dropped.
    """
    valid = [
        item.strip()
        for item in labels
        if isinstance(item, str) and re.match(r"^C\d+$", item.strip())
    ]
    unique = list(dict.fromkeys(valid))
    unique.sort(key=lambda label: int(label[1:]))
    return unique
# System prompt for the tool-calling agent in run_agent_with_trained_model.
# It tells the LLM to:
# - call the trained-model tool;
# - treat its prediction as a prior that may be reranked using the returned
#   option evidence;
# - return bare C-labels in the JSON "final_answer" field.
MODEL_DELEGATION_SYSTEM_PROMPT = """
You are a wireless network optimization assistant.
You must answer by calling the provided trained-model tool. The telemetry and
server data have already been gathered by deterministic code; do not invent or
request telemetry yourself.
Use the trained model prediction as a strong prior, not as an unchangeable
answer. The tool also returns compact option evidence: parsed action, target
cell, change magnitude, current target-cell indicators, selector scores, and
ambiguity groups where options differ only by target, subtype, or magnitude.
Rerank the options when the evidence shows that another candidate better fits
the same troubleshooting action, target cell, threshold subtype, or numerical
change. Be conservative when the evidence is weak.
Final answer rules:
- Return C labels only, such as C5 or C5|C9.
- For multiple labels, sort numerically by C number.
- For a single-answer task, return exactly one label.
- For a multiple-answer task, return two to four labels.
- Do not include prose, markdown, or boxed notation in final_answer.
""".strip()
def get_env_value(names: str) -> str:
    """Return the first non-blank environment value among ``names``.

    ``names`` is a comma- or semicolon-separated list of variable names,
    tried in order. Values are stripped. ``""`` is returned when none is set.
    """
    candidates = (name.strip() for name in re.split(r"[,;]", names))
    values = (os.environ.get(name, "").strip() for name in candidates)
    return next((value for value in values if value), "")
def _time_left(deadline: float) -> float:
if not deadline:
return 10**9
return max(0.0, deadline - time.perf_counter())
def _llm_timeout_left(deadline: float, llm_timeout: float) -> Optional[float]:
    """Timeout for one LLM request, bounded by the overall question deadline.

    Without a deadline, ``llm_timeout`` is returned unchanged. Otherwise the
    result is the smaller of ``llm_timeout`` and the time remaining, never
    below 0.1 s.
    """
    if deadline:
        remaining = _time_left(deadline)
        return 0.1 if remaining <= 0 else max(0.1, min(llm_timeout, remaining))
    return llm_timeout
def _clip_text(text: str, max_chars: int) -> str:
if max_chars <= 0 or len(text) <= max_chars:
return text
head = max_chars // 2
tail = max_chars - head
return (
text[:head]
+ f"\n... [truncated {len(text) - max_chars} chars] ...\n"
+ text[-tail:]
)
def _json_tool_content(result: Dict[str, Any], max_chars: int) -> str:
    """Serialise a tool result to JSON (non-ASCII kept) and clip it with ``_clip_text``."""
    serialized = json.dumps(result, ensure_ascii=False)
    return _clip_text(serialized, max_chars)
def run_agent_with_trained_model(
    llm_client: Optional[OpenAI],
    model_name: str,
    template_model: Pipeline,
    selector_model: Pipeline,
    scenario: Dict[str, Any],
    timeout: float = 60.0,
    max_steps: int = 4,
    max_tool_calls: int = 8,
    temperature: float = 0.0,
    max_output_tokens: int = 900,
    history_chars: int = 24000,
    observation_chars: int = 20000,
    question_timeout: float = 180.0,
) -> Tuple[List[str], Dict[str, Any], str, List[str], bool]:
    """Run a bounded Track A agent loop around the trained ML model tool.

    Flow:
    - The LLM is first forced to call ``run_trained_ml_model``, which runs the
      template/selector pipeline on ``scenario`` and returns the ML labels plus
      option evidence.
    - After that the LLM may answer with JSON ``{"final_answer": "C..."}``.
    - The answer is filtered to valid option ids and checked for
      single/multiple-answer cardinality. An invalid answer triggers one
      retry prompt per remaining step.

    Bounds:
    - ``max_steps`` LLM requests;
    - ``question_timeout`` seconds overall (``timeout`` caps each request);
    - ``max_tool_calls`` trace entries, after which further tool calls are
      answered with a ``tool_limit_reached`` error instead of being run.

    The plain ML prediction is returned instead whenever the agent is
    disabled (``llm_client is None``), runs out of budget, or raises.

    Returns:
        ``(labels, debug_info, completion_text, tool_trace, agent_used)``.
        ``agent_used`` is True only when the LLM's own answer was accepted.
    """
    sid = scenario.get("scenario_id") or scenario.get("ID") or "unknown"
    started = time.perf_counter()
    # A non-positive question_timeout disables the overall deadline (0 = none).
    deadline = started + question_timeout if question_timeout and question_timeout > 0 else 0
    model_calls = 0
    invalid_json_count = 0
    duplicate_skips = 0
    timeout_hit = False
    seen_tool_calls = set()

    def run_ml_tool() -> Dict[str, Any]:
        """Run the trained pipeline on ``scenario`` and package the prior plus evidence."""
        labels, dbg = predict_labels(template_model, selector_model, scenario)
        labels = normalize_prediction_labels(labels)
        context = build_prediction_context(scenario)
        top_templates = dbg.get("top_templates", [])[:5]
        evidence = build_option_evidence(
            selector_model, scenario, context, top_templates, labels
        )
        return {
            "scenario_id": sid,
            "prediction": "|".join(labels),
            "labels": labels,
            "template": dbg.get("template"),
            "template_prob": dbg.get("template_prob"),
            "top_templates": top_templates,
            "options": get_options(scenario),
            "option_evidence": evidence,
        }

    def build_dbg(
        ml_result: Dict[str, Any],
        agent_used: bool,
        final_raw: str = "",
        fallback_reason: str = "",
        changed: bool = False,
    ) -> Dict[str, Any]:
        """Assemble the debug record; the loop counters are read at call time."""
        dbg = {
            "template": ml_result.get("template"),
            "template_prob": ml_result.get("template_prob"),
            "top_templates": ml_result.get("top_templates", []),
            "agent_used": agent_used,
            "agent_changed_prediction": changed,
            "model_calls": model_calls,
            "invalid_json_count": invalid_json_count,
            "duplicate_skips": duplicate_skips,
            "timeout_hit": timeout_hit,
        }
        if final_raw:
            dbg["agent_final_raw"] = final_raw
        if fallback_reason:
            dbg["agent_fallback_reason"] = fallback_reason
        return dbg

    def parse_tool_args(raw: Any) -> Dict[str, Any]:
        """Decode tool-call arguments, treating malformed or non-object JSON as ``{}``.

        The only tool ignores its arguments (``run_ml_tool`` scores the
        closed-over scenario), so bad arguments should not abort the agent.
        """
        try:
            args = json.loads(raw or "{}")
        except (TypeError, ValueError):
            return {}
        return args if isinstance(args, dict) else {}

    def labels_from_text(text: str, ml_result: Optional[Dict[str, Any]]) -> List[str]:
        """Extract and validate the LLM's final labels; ``[]`` means the answer is invalid."""
        parsed: Any = {}
        try:
            parsed = json.loads(text)
        except Exception:
            pass
        # Valid JSON that is not an object (e.g. "C5" or ["C5"]) falls back to
        # scanning the raw text instead of crashing on .get().
        if not isinstance(parsed, dict):
            parsed = {}
        final_text = str(parsed.get("final_answer") or text)
        labels = normalize_prediction_labels(re.findall(r"C\d+", final_text))
        if ml_result:
            valid_ids = set((ml_result.get("options") or {}).keys())
            labels = [label for label in labels if label in valid_ids]
            task_type = (
                (ml_result.get("option_evidence") or {}).get("task_type")
                or ("multiple-answer" if is_multi_task(scenario) else "single-answer")
            )
            bad_cardinality = (
                (task_type == "single-answer" and len(labels) != 1)
                or (task_type == "multiple-answer" and not (2 <= len(labels) <= 4))
            )
            if bad_cardinality:
                return []
        return labels

    if llm_client is None:
        result = run_ml_tool()
        labels = result["labels"]
        dbg = build_dbg(result, False, fallback_reason="agent_disabled")
        return labels, dbg, make_completion(labels, dbg, scenario), [
            format_tool_call("run_trained_ml_model", {"scenario_id": sid}, result)
        ], False
    tools = [
        {
            "type": "function",
            "function": {
                "name": "run_trained_ml_model",
                "description": (
                    "Run the already-trained Track A v4 ML pipeline on the hydrated "
                    "scenario data. Returns the ML prior plus compact option evidence "
                    "for agentic reranking."
                ),
                "parameters": {
                    "type": "object",
                    "properties": {
                        "scenario_id": {
                            "type": "string",
                            "description": "Scenario ID to score.",
                        }
                    },
                    "required": ["scenario_id"],
                    "additionalProperties": False,
                },
            },
        }
    ]
    messages = [
        {"role": "system", "content": MODEL_DELEGATION_SYSTEM_PROMPT},
        {
            "role": "user",
            "content": (
                "Question and options:\n"
                f"{build_question_text(scenario)}\n\n"
                "Call run_trained_ml_model now. Then rerank only if the returned "
                "option evidence shows a better target cell, subtype, or numerical "
                "change than the ML prior. Return strict JSON with "
                '{"final_answer":"C...","used_tool":true,"reason":"short"}.'
            ),
        },
    ]
    tool_trace: List[str] = []
    ml_result: Optional[Dict[str, Any]] = None
    final_content = ""
    try:
        for _step in range(1, max(1, max_steps) + 1):
            if _time_left(deadline) <= 2:
                timeout_hit = True
                break
            # Force the tool until its result exists; afterwards let the LLM
            # answer freely, but in JSON mode.
            force_tool = ml_result is None
            model_calls += 1
            request = {
                "model": model_name,
                "messages": messages,
                "tools": tools,
                "tool_choice": (
                    {
                        "type": "function",
                        "function": {"name": "run_trained_ml_model"},
                    }
                    if force_tool
                    else "auto"
                ),
                "temperature": temperature,
                "max_tokens": max_output_tokens,
                "timeout": _llm_timeout_left(deadline, timeout),
            }
            if not force_tool:
                request["response_format"] = {"type": "json_object"}
            response = llm_client.chat.completions.create(**request)
            msg = response.choices[0].message
            tool_calls = getattr(msg, "tool_calls", None) or []
            if tool_calls:
                messages.append(msg.model_dump(exclude_none=True))
                # Every tool_call_id in this assistant message must get a
                # "tool" reply, otherwise the next chat request is rejected.
                # Calls past the limit get an error payload instead of being
                # skipped (the old early `break` left them unanswered).
                for call in tool_calls:
                    name = call.function.name
                    args = parse_tool_args(call.function.arguments)
                    call_key = (name, json.dumps(args, sort_keys=True))
                    if len(tool_trace) >= max_tool_calls:
                        result = {
                            "error": "tool_limit_reached",
                            "message": "Return final answer using existing evidence.",
                        }
                    elif call_key in seen_tool_calls:
                        duplicate_skips += 1
                        result = {
                            "duplicate_tool_call_skipped": True,
                            "message": "Use existing tool evidence and return final JSON.",
                        }
                    elif name != "run_trained_ml_model":
                        result = {"error": f"unsupported tool {name}"}
                    else:
                        seen_tool_calls.add(call_key)
                        result = run_ml_tool()
                        ml_result = result
                    tool_trace.append(format_tool_call(name, args, result))
                    messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": call.id,
                            "name": name,
                            "content": _json_tool_content(result, observation_chars),
                        }
                    )
                continue
            final_content = msg.content or ""
            labels = labels_from_text(final_content, ml_result)
            if labels:
                ml_labels = normalize_prediction_labels(
                    (ml_result or {}).get("labels", [])
                )
                changed = bool(ml_result) and labels != ml_labels
                if not ml_result:
                    ml_result = run_ml_tool()
                    tool_trace.append(
                        format_tool_call(
                            "run_trained_ml_model",
                            {"scenario_id": sid, "fallback": True},
                            ml_result,
                        )
                    )
                dbg = build_dbg(
                    ml_result,
                    True,
                    final_raw=final_content,
                    changed=changed,
                )
                return labels, dbg, final_content, tool_trace, True
            invalid_json_count += 1
            messages.append(
                {
                    "role": "assistant",
                    "content": _clip_text(final_content, history_chars),
                }
            )
            messages.append(
                {
                    "role": "user",
                    "content": (
                        "The final answer was invalid. Return strict JSON only, "
                        'for example {"final_answer":"C5","used_tool":true,"reason":"short"}.'
                    ),
                }
            )
        if not ml_result:
            ml_result = run_ml_tool()
            tool_trace.append(
                format_tool_call(
                    "run_trained_ml_model",
                    {"scenario_id": sid, "fallback": True},
                    ml_result,
                )
            )
        labels = normalize_prediction_labels(ml_result.get("labels", []))
        reason = (
            "question_timeout"
            if timeout_hit
            else "max_tool_calls_or_steps_reached"
        )
        dbg = build_dbg(ml_result, False, final_raw=final_content, fallback_reason=reason)
        # The returned labels are the ML labels, so the completion must
        # describe them. The rejected LLM text is kept in dbg["agent_final_raw"].
        return labels, dbg, make_completion(labels, dbg, scenario), tool_trace, False
    except Exception as exc:
        result = run_ml_tool()
        labels = normalize_prediction_labels(result["labels"])
        dbg = build_dbg(
            result,
            False,
            fallback_reason=f"{type(exc).__name__}: {exc}",
        )
        tool_trace.append(
            format_tool_call(
                "run_trained_ml_model",
                {"scenario_id": sid, "fallback_after_agent_error": True},
                result,
            )
        )
        return labels, dbg, make_completion(labels, dbg, scenario), tool_trace, False
def scenario_raw_text(
    s: Dict[str, Any], max_chars: int = 12000, section_chars: int = 5000
) -> str:
    """Compact raw text feature for the selector, truncated for speed.

    Concatenates, joined by a space-newline-space separator:
    - the task description;
    - the user-plane table truncated to ``max_chars``;
    - the configuration, signaling, traffic and MR tables, each truncated to
      ``section_chars``;
    - the free-text notes.

    Missing, ``None`` or empty sections are skipped. Before this change a
    ``None`` table raised TypeError on slicing. Non-string sections are
    converted with ``str()`` before truncation.
    """
    task = s.get("task", {}) or {}
    data = s.get("data", {}) or {}

    def _section(key: str, limit: int) -> str:
        value = data.get(key)
        if value is None:
            return ""
        return str(value)[:limit]

    parts = [
        task.get("description", ""),
        _section("user_plane_data", max_chars),
        _section("network_configuration_data", section_chars),
        _section("signaling_plane_data", section_chars),
        _section("traffic_data", section_chars),
        _section("mr_data", section_chars),
        data.get("notes", ""),
    ]
    return " \n ".join(str(x) for x in parts if x)
# =============================================================================
# Column constants
# =============================================================================
# Column headers of the user-plane telemetry table (the pipe-separated
# ``user_plane_data`` section, parsed by get_dataframes).
# COL_SERV holds the serving-cell PCI, which is matched against the ``PCI``
# column of the network configuration table.
COL_SERV = "5G KPI PCell RF Serving PCI"
COL_RSRP = "5G KPI PCell RF Serving SS-RSRP [dBm]"
COL_SINR = "5G KPI PCell RF Serving SS-SINR [dB]"
# Downlink MAC throughput (Mbps); bad_rows selects low-throughput rows on it.
COL_THR = "5G KPI PCell Layer2 MAC DL Throughput [Mbps]"
COL_RB = "5G KPI PCell Layer1 DL RB Num (Including 0)"
COL_MCS = "Avg MCS"
COL_RANK = "Avg Rank"
COL_GRANT = "Grant"
COL_CCE_FAIL = "CCE Fail Rate"
COL_IBLER = "Initial BLER(%)"
COL_RBLER = "Residual BLER(%)"
def get_dataframes(
    s: Dict[str, Any],
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Parse the scenario's five pipe-separated tables.

    Returns ``(user_plane, network_configuration, signaling_plane, traffic,
    mr)``. Any missing or unparsable table comes back as an empty DataFrame.
    """
    data = s.get("data", {}) or {}
    section_keys = (
        "user_plane_data",
        "network_configuration_data",
        "signaling_plane_data",
        "traffic_data",
        "mr_data",
    )
    user, config, signaling, traffic, mr = (
        read_pipe_csv(data.get(key)) for key in section_keys
    )
    return user, config, signaling, traffic, mr
def normalize_server_scenario(
    local_s: Dict[str, Any], remote_s: Dict[str, Any]
) -> Dict[str, Any]:
    """Overlay server-provided fields onto a shallow copy of a local scenario.

    Which remote fields are taken:
    - ``context``, ``tools``, ``expected_output``, ``steps`` and ``answer``
      whenever they are not ``None``;
    - ``data`` only when it is truthy.

    ``local_s`` itself is not modified.
    """
    merged = dict(local_s)
    overrides = {
        key: remote_s[key]
        for key in ("context", "tools", "expected_output", "steps")
        if remote_s.get(key) is not None
    }
    merged.update(overrides)
    if remote_s.get("data"):
        merged["data"] = remote_s["data"]
    if remote_s.get("answer") is not None:
        merged["answer"] = remote_s["answer"]
    return merged
def server_get(
    client: httpx.Client,
    server_url: str,
    endpoint: str,
    scenario_id: str,
    params: Optional[Dict[str, Any]] = None,
) -> Any:
    """GET ``endpoint`` on the scenario server and return the decoded JSON body.

    The scenario is identified through the ``X-Scenario-Id`` header. Trailing
    slashes of ``server_url`` and leading slashes of ``endpoint`` are stripped
    before joining the two with a single ``/``. Non-2xx responses raise via
    ``raise_for_status()``.
    """
    base = server_url.rstrip("/")
    path = endpoint.lstrip("/")
    response = client.get(
        f"{base}/{path}",
        params=params or {},
        headers={"X-Scenario-Id": scenario_id},
    )
    response.raise_for_status()
    return response.json()
def hydrate_scenario_from_server(
    s: Dict[str, Any],
    client: httpx.Client,
    server_url: str,
    try_scenario_endpoint: bool = False,
) -> Tuple[Dict[str, Any], List[str]]:
    """Fill a scenario's data sections from the scenario server.

    Steps:
    1. Optionally try ``/scenario``. If the merged result has no placeholder
       sections left, return it immediately. If it still has placeholders,
       it becomes the base for the remaining steps; previously the fetched
       context/answer/data were discarded in that case.
    2. Fetch the user-plane, configuration, KPI (traffic) and MR sections from
       their endpoints.
    3. If the signaling section is still a placeholder but user-plane data is
       present, fetch the signaling event log at the last user-plane
       ``Timestamp``.

    Every server call, successful or failed, is recorded as a trace line via
    ``format_tool_call``. Individual failures do not abort hydration.

    Returns:
        ``(hydrated_scenario, tool_call_trace)``. ``s`` is not mutated.

    Raises:
        ValueError: if ``s`` has neither ``scenario_id`` nor ``ID``.
    """
    sid = s.get("scenario_id") or s.get("ID")
    if not sid:
        raise ValueError("Cannot hydrate scenario without scenario_id/ID")
    tool_calls: List[str] = []
    # Scenario the fetched sections are merged into.
    base = s
    if try_scenario_endpoint:
        try:
            remote = server_get(client, server_url, "/scenario", sid)
            tool_calls.append(
                format_tool_call("get_scenario", {"scenario_id": sid}, remote)
            )
            hydrated = normalize_server_scenario(s, remote)
            if not scenario_needs_server_data(hydrated):
                return hydrated, tool_calls
            base = hydrated
        except Exception as exc:
            tool_calls.append(
                format_tool_call(
                    "get_scenario",
                    {"scenario_id": sid},
                    {"error": f"{type(exc).__name__}: {exc}"},
                )
            )
    data = dict((base.get("data") or {}))
    endpoint_map = [
        (
            "get_user_plane_data",
            "/user-plane-data",
            "User Plane Data",
            "user_plane_data",
        ),
        (
            "get_config_data",
            "/config-data",
            "Network Configuration Data",
            "network_configuration_data",
        ),
        ("get_kpi_data", "/get_kpi_data", "Traffic Data", "traffic_data"),
        ("get_mr_data", "/get_mr_data", "MR Data", "mr_data"),
    ]
    for func_name, endpoint, response_key, data_key in endpoint_map:
        try:
            result = server_get(client, server_url, endpoint, sid)
            tool_calls.append(format_tool_call(func_name, {}, result))
            # Only a JSON object can carry the section. For a str body ``in``
            # would be a substring test and the indexing below would raise.
            if isinstance(result, dict) and response_key in result:
                data[data_key] = result[response_key]
        except Exception as exc:
            tool_calls.append(
                format_tool_call(
                    func_name,
                    {},
                    {"error": f"{type(exc).__name__}: {exc}"},
                )
            )
    if is_placeholder_data(
        data.get("signaling_plane_data")
    ) and not is_placeholder_data(data.get("user_plane_data")):
        user_df = read_pipe_csv(data.get("user_plane_data"))
        if not user_df.empty and "Timestamp" in user_df.columns:
            time_arg = str(user_df["Timestamp"].iloc[-1])
            try:
                result = server_get(
                    client,
                    server_url,
                    "/signaling-plane-event-log",
                    sid,
                    {"time": time_arg},
                )
                tool_calls.append(
                    format_tool_call(
                        "get_signaling_plane_event_log",
                        {"time": time_arg},
                        result,
                    )
                )
                if isinstance(result, str):
                    data["signaling_plane_data"] = result
            except Exception as exc:
                tool_calls.append(
                    format_tool_call(
                        "get_signaling_plane_event_log",
                        {"time": time_arg},
                        {"error": f"{type(exc).__name__}: {exc}"},
                    )
                )
    hydrated = dict(base)
    hydrated["data"] = data
    return hydrated, tool_calls
def neighbor_col_pairs(user: pd.DataFrame) -> List[Tuple[str, str]]:
    """List ``(pci_column, rsrp_column)`` pairs for neighbour ranks 1..5 present in ``user``.

    A rank is included only when both its PCI column and its filtered Tx
    BRSRP column exist. A None or empty frame yields ``[]``.
    """
    if user is None or user.empty:
        return []
    candidates = (
        (
            f"Measurement PCell Neighbor Cell Top Set(Cell Level) Top {rank} PCI",
            f"Measurement PCell Neighbor Cell Top Set(Cell Level) Top {rank} Filtered Tx BRSRP [dBm]",
        )
        for rank in range(1, 6)
    )
    return [
        (pci_col, rsrp_col)
        for pci_col, rsrp_col in candidates
        if pci_col in user.columns and rsrp_col in user.columns
    ]
def bad_rows(user: pd.DataFrame) -> pd.DataFrame:
    """Select the low-throughput rows of the user-plane table.

    A row is "bad" when its DL MAC throughput (``COL_THR``, Mbps) is below 100
    or at/below the table's 35th throughput percentile.

    Edge cases:
    - a None or empty table gives an empty DataFrame;
    - a table without ``COL_THR`` gives a copy of all rows;
    - a table with no numeric throughput gives all rows plus a helper
      ``_thr`` column;
    - if the selection were empty, the (up to) 3 lowest-throughput rows are
      used instead.

    Selected rows carry the numeric throughput in the helper ``_thr`` column.
    """
    if user is None or user.empty:
        return pd.DataFrame()
    if COL_THR not in user.columns:
        return user.copy()
    frame = user.copy()
    frame["_thr"] = numeric_series(frame, COL_THR)
    throughput = frame["_thr"]
    if throughput.notna().any():
        cutoff = float(np.nanpercentile(throughput.dropna(), 35))
        selected = frame[(throughput < 100) | (throughput <= cutoff)].copy()
        if selected.empty:
            selected = frame.nsmallest(min(3, len(frame)), "_thr").copy()
        return selected
    return frame.copy()
def build_cell_maps(config: pd.DataFrame) -> Tuple[Dict[str, int], Dict[int, str]]:
    """Map between ``"<gNodeB ID>_<Cell ID>"`` cell keys and integer PCIs.

    Returns ``(key_to_pci, pci_to_key)``. Rows whose IDs or PCI cannot be
    parsed as numbers are skipped. When keys or PCIs collide, later rows
    overwrite earlier ones.
    """
    key_to_pci: Dict[str, int] = {}
    pci_to_key: Dict[int, str] = {}
    if config is not None and not config.empty:
        for _, row in config.iterrows():
            try:
                gnb = int(float(row.get("gNodeB ID")))
                cell = int(float(row.get("Cell ID")))
                pci = int(float(row.get("PCI")))
            except Exception:
                continue
            cell_key = f"{gnb}_{cell}"
            key_to_pci[cell_key] = pci
            pci_to_key[pci] = cell_key
    return key_to_pci, pci_to_key
def haversine_m(lon1: float, lat1: float, lon2: float, lat2: float) -> float:
    """Great-circle distance in metres between two (lon, lat) points given in degrees.

    Uses the haversine formula on a sphere of radius 6,371,000 m.
    Floating-point rounding can push ``sqrt(a)`` marginally above 1 for
    (near-)antipodal points, which made ``math.asin`` raise ValueError, so
    the value is clamped to 1. NaN inputs still yield NaN.
    """
    radius = 6371000.0
    dlat = math.radians(lat2 - lat1)
    dlon = math.radians(lon2 - lon1)
    a = (
        math.sin(dlat / 2) ** 2
        + math.cos(math.radians(lat1))
        * math.cos(math.radians(lat2))
        * math.sin(dlon / 2) ** 2
    )
    root = math.sqrt(a)
    if root > 1.0:  # NaN compares False here, so NaN propagates unchanged
        root = 1.0
    return 2 * radius * math.asin(root)
def bearing_deg(lon1: float, lat1: float, lon2: float, lat2: float) -> float:
    """Initial bearing in degrees, in [0, 360), from point 1 to point 2.

    Points are given as (lon, lat) in degrees. 0 means north and 90 means east.
    """
    phi1 = math.radians(lat1)
    phi2 = math.radians(lat2)
    delta_lambda = math.radians(lon2 - lon1)
    east = math.sin(delta_lambda) * math.cos(phi2)
    north = math.cos(phi1) * math.sin(phi2) - math.sin(phi1) * math.cos(
        phi2
    ) * math.cos(delta_lambda)
    heading = math.degrees(math.atan2(east, north))
    return (heading + 360) % 360
def angle_delta_deg(a: float, b: float) -> float:
    """Smallest absolute difference between two angles in degrees, in [0, 180]."""
    wrapped = abs(a - b) % 360
    return min(wrapped, 360 - wrapped)
# =============================================================================
# Action parsing and templates
# =============================================================================
# Canonical ordering of action names. answer_to_template lists the actions of
# an answer in this order, so the same action mix always yields the same
# template string. The names are the values returned by label_to_action
# (which additionally returns "other" for unrecognised labels).
ACTION_ORDER = [
    "server",
    "insufficient",
    "pdcch",
    "inc_power",
    "dec_power",
    "tilt_down",
    "tilt_up",
    "azimuth",
    "inc_a3",
    "dec_a3",
    "a2a5",
    "neighbor",
]
def label_to_action(label: str) -> str:
    """Classify an option label into one of the ``ACTION_ORDER`` action names.

    The label is normalised with ``norm`` and matched against keyword rules
    in priority order; the first matching rule wins. Each rule lists
    alternative keyword groups, and a group matches when all of its keywords
    occur. Returns "other" when no rule matches.
    """
    text = norm(label)
    rules = (
        ("insufficient", (("insufficient data",), ("more data",))),
        ("server", (("server",), ("transmission issue",), ("transmission issues",))),
        ("pdcch", (("pdcch",), ("2sym",))),
        ("inc_power", (("increase transmission power",),)),
        ("dec_power", (("decrease transmission power",),)),
        ("tilt_down", (("press down",), ("down the tilt",))),
        ("tilt_up", (("lift", "tilt"),)),
        ("azimuth", (("azimuth",),)),
        ("inc_a3", (("increase a3",),)),
        ("dec_a3", (("decrease a3",),)),
        (
            "a2a5",
            (
                ("covinterfreqa2",),
                ("covinterfreqa5",),
                ("a2rsrpthld",),
                ("a5rsrpthld",),
            ),
        ),
        ("neighbor", (("neighbor relationship",),)),
    )
    for action, alternatives in rules:
        if any(all(word in text for word in group) for group in alternatives):
            return action
    return "other"
def answer_to_template(s: Dict[str, Any]) -> str:
    """Describe a scenario's ground-truth answer as an action template string.

    Each answered option id is mapped to its label's action via
    ``label_to_action``, and "other" actions are dropped. The remaining
    actions are listed with repeats in ``ACTION_ORDER`` order and joined by
    ``|``. Returns "other" when nothing remains.
    """
    options = get_options(s)
    answered = parse_answer(s.get("answer", ""))
    action_counts = Counter(
        action
        for action in (label_to_action(options.get(cid, "")) for cid in answered)
        if action and action != "other"
    )
    ordered = [
        action for action in ACTION_ORDER for _ in range(action_counts.get(action, 0))
    ]
    return "|".join(ordered) if ordered else "other"
def template_to_actions(template: str) -> List[str]:
    """Split a ``|``-joined template into its action names.

    An empty template or "other" gives ``[]``; empty segments are dropped.
    """
    if template and template != "other":
        return list(filter(None, template.split("|")))
    return []
def extract_gnb_cell(label: str) -> Optional[str]:
    """Return the first ``<gNodeB ID>_<Cell ID>`` token in ``label``, or None.

    A token is 5-8 digits, an underscore, then 1-3 digits.
    """
    match = re.search(r"\b(\d{5,8}_\d{1,3})\b", str(label))
    if match is None:
        return None
    return match.group(1)
def extract_all_gnb_cells(label: str) -> List[str]:
    """Return every ``<gNodeB ID>_<Cell ID>`` token in ``label``, in order of appearance."""
    cell_pattern = re.compile(r"\b\d{5,8}_\d{1,3}\b")
    return cell_pattern.findall(str(label))
def parse_pdcch_symbols(value: Any) -> float:
    """Return N, as a float, from the first "N sym"/"Nsym" mention in ``value``; NaN if absent."""
    found = re.search(r"(\d+)\s*sym", norm(value))
    return safe_float(found.group(1), np.nan) if found else np.nan
def parse_option_semantics(label: str) -> Dict[str, Any]:
    """Extract option facts without deciding whether the option is correct.

    Returned keys:
    - ``action``: from ``label_to_action``.
    - ``target_cell``: the first referenced cell (or None).
    - ``cells``: all referenced cells.
    - ``amount`` and ``unit``: the change magnitude. Unit is "sym" for PDCCH,
      "dbm" for power, "degrees" for tilt/azimuth, "db" for A3/A2/A5; amount
      is NaN and unit "" when absent.
    - ``threshold_types``: the A2/A5 threshold subtypes found, and
      ``threshold_type``, the same list ``|``-joined.
    - ``pdcch_symbols``: the PDCCH symbol count.
    - ``towards_ue``: 1.0 when the text mentions "towards the ue", else 0.0.

    Amounts are searched in a lower-cased copy of the label where cell IDs
    are replaced by "<cell>".
    """
    raw = str(label or "")
    lowered = norm(raw)
    masked = re.sub(r"\b\d{5,8}_\d{1,3}\b", "<cell>", lowered)
    action = label_to_action(raw)
    cells = extract_all_gnb_cells(raw)
    amount_rules = {
        "inc_power": (r"\bby\s*([-+]?\d+(?:\.\d+)?)\s*dbm\b", "dbm"),
        "dec_power": (r"\bby\s*([-+]?\d+(?:\.\d+)?)\s*dbm\b", "dbm"),
        "tilt_down": (r"\bby\s*([-+]?\d+(?:\.\d+)?)\s*degrees?\b", "degrees"),
        "tilt_up": (r"\bby\s*([-+]?\d+(?:\.\d+)?)\s*degrees?\b", "degrees"),
        "azimuth": (r"\bby\s*([-+]?\d+(?:\.\d+)?)\s*degrees?\b", "degrees"),
        "inc_a3": (r"\bby\s*([-+]?\d+(?:\.\d+)?)\s*db\b", "db"),
        "dec_a3": (r"\bby\s*([-+]?\d+(?:\.\d+)?)\s*db\b", "db"),
        "a2a5": (r"\bby\s*([-+]?\d+(?:\.\d+)?)\s*db\b", "db"),
    }
    threshold_markers = (
        ("covinterfreqa2rsrpthld", "a2"),
        ("covinterfreqa5rsrpthld1", "a5_1"),
        ("covinterfreqa5rsrpthld2", "a5_2"),
    )
    amount = np.nan
    unit = ""
    threshold_types: List[str] = []
    pdcch_symbols = np.nan
    if action == "pdcch":
        pdcch_symbols = parse_pdcch_symbols(raw)
        amount, unit = pdcch_symbols, "sym"
    elif action in amount_rules:
        if action == "a2a5":
            threshold_types = [kind for needle, kind in threshold_markers if needle in lowered]
        pattern, unit_name = amount_rules[action]
        hit = re.search(pattern, masked)
        if hit:
            amount = safe_float(hit.group(1), np.nan)
            unit = unit_name
    return {
        "action": action,
        "target_cell": cells[0] if cells else None,
        "cells": cells,
        "amount": amount,
        "unit": unit,
        "threshold_types": threshold_types,
        "threshold_type": "|".join(threshold_types),
        "pdcch_symbols": pdcch_symbols,
        "towards_ue": 1.0 if "towards the ue" in lowered else 0.0,
    }
def extract_cell_geometry_stats(s: Dict[str, Any]) -> Dict[str, Dict[str, float]]:
    """Per-cell geometry statistics for cells related to low-throughput rows.

    Each row from ``bad_rows(user)`` is related to its serving cell (the
    ``COL_SERV`` PCI) and to every neighbour PCI listed in the row. PCIs are
    mapped to ``"<gNodeB ID>_<Cell ID>"`` keys through the configuration
    table. For each related cell that has a configuration row, the UE
    positions of its related rows are compared with the cell's location,
    mechanical azimuth, mechanical downtilt and height.

    Returns:
        ``{cell_key: features}``. The features are passed through
        ``clean_features`` (NaN/inf -> -999.0) and contain:
        - the number of related rows;
        - the mean UE distance (metres, via ``haversine_m``);
        - azimuth-error mean/p50/max (degrees);
        - desired-downtilt mean/p50;
        - current-tilt-error mean/p50.

        Returns ``{}`` when the user-plane or configuration table is missing,
        empty, or lacks a required column. Cells with no related row that has
        a valid UE position are omitted.
    """
    user, config, _, _, _ = get_dataframes(s)
    if user is None or user.empty or config is None or config.empty:
        return {}
    need_user = {"Longitude", "Latitude", COL_SERV}
    need_cfg = {
        "gNodeB ID",
        "Cell ID",
        "PCI",
        "Longitude",
        "Latitude",
        "Mechanical Azimuth",
        "Mechanical Downtilt",
        "Height",
    }
    if not need_user.issubset(user.columns) or not need_cfg.issubset(config.columns):
        return {}
    _, p2g = build_cell_maps(config)
    # Per-cell configuration keyed like build_cell_maps; unparsable numeric
    # fields become NaN (safe_float), rows with unparsable IDs are skipped.
    cfg_by_key: Dict[str, Dict[str, float]] = {}
    for _, r in config.iterrows():
        try:
            key = f"{int(float(r.get('gNodeB ID')))}_{int(float(r.get('Cell ID')))}"
            cfg_by_key[key] = {
                "lon": safe_float(r.get("Longitude"), np.nan),
                "lat": safe_float(r.get("Latitude"), np.nan),
                "azimuth": safe_float(r.get("Mechanical Azimuth"), np.nan),
                "downtilt": safe_float(r.get("Mechanical Downtilt"), np.nan),
                "height": safe_float(r.get("Height"), np.nan),
            }
        except Exception:
            continue
    # For every cell key, collect the bad rows that mention it as serving or
    # neighbour cell. One row can be appended to several cells.
    related_rows: Dict[str, List[pd.Series]] = defaultdict(list)
    b = bad_rows(user)
    for _, r in b.iterrows():
        spci = safe_float(r.get(COL_SERV), np.nan)
        if not np.isnan(spci):
            key = p2g.get(int(spci))
            if key:
                related_rows[key].append(r)
        for pci_c, _ in neighbor_col_pairs(user):
            npci = safe_float(r.get(pci_c), np.nan)
            if np.isnan(npci):
                continue
            key = p2g.get(int(npci))
            if key:
                related_rows[key].append(r)
    out: Dict[str, Dict[str, float]] = {}
    for key, cfg in cfg_by_key.items():
        rows = related_rows.get(key, [])
        if not rows:
            continue
        h_errors, desired_tilts, tilt_errors, distances = [], [], [], []
        for r in rows:
            try:
                ulon = safe_float(r.get("Longitude"), np.nan)
                ulat = safe_float(r.get("Latitude"), np.nan)
                if np.isnan(ulon) or np.isnan(ulat):
                    continue
                dist = haversine_m(cfg["lon"], cfg["lat"], ulon, ulat)
                brg = bearing_deg(cfg["lon"], cfg["lat"], ulon, ulat)
                # Horizontal error: angle between the cell->UE bearing and the
                # mechanical azimuth, in [0, 180] degrees.
                hdiff = angle_delta_deg(brg, cfg["azimuth"])
                # Depression angle from the antenna to the UE,
                # atan2(height, distance), with distance floored at 0.1.
                # NOTE(review): assumes Height is in metres like the haversine
                # distance; confirm against the configuration schema.
                desired_tilt = math.degrees(math.atan2(cfg["height"], max(dist, 0.1)))
                h_errors.append(hdiff)
                desired_tilts.append(desired_tilt)
                tilt_errors.append(abs(cfg["downtilt"] - desired_tilt))
                distances.append(dist)
            except Exception:
                continue
        if h_errors:
            out[key] = clean_features(
                {
                    "geom_related_bad_count": float(len(h_errors)),
                    "geom_distance_mean": float(np.mean(distances)),
                    "geom_azimuth_error_mean": float(np.mean(h_errors)),
                    "geom_azimuth_error_p50": float(np.percentile(h_errors, 50)),
                    "geom_azimuth_error_max": float(np.max(h_errors)),
                    "geom_desired_downtilt_mean": float(np.mean(desired_tilts)),
                    "geom_desired_downtilt_p50": float(np.percentile(desired_tilts, 50)),
                    "geom_current_tilt_error_mean": float(np.mean(tilt_errors)),
                    "geom_current_tilt_error_p50": float(np.percentile(tilt_errors, 50)),
                }
            )
    return out
# =============================================================================
# Scenario-level features for template classifier
# =============================================================================
def estimate_mainlobe_flags(user: pd.DataFrame, config: pd.DataFrame) -> pd.Series:
    """Flag user-plane rows whose position looks outside the serving cell's main lobe.

    For each row, the serving cell is looked up by PCI (``COL_SERV``) in
    ``config``. The row is flagged 1.0 when either of these holds, else 0.0:
    - the horizontal offset between the cell->UE bearing and the cell's
      mechanical azimuth exceeds 50 degrees;
    - the vertical metric computed below exceeds 6 degrees in magnitude.

    A row gets NaN when its serving PCI is not in ``config`` or any value
    cannot be parsed.

    Distance, bearing and angle difference use the module's shared
    ``haversine_m`` / ``bearing_deg`` / ``angle_delta_deg`` helpers, which
    previously were duplicated inline here with the same formulas.

    Returns:
        A Series aligned to ``user.index``. It is empty if either frame is
        None or empty, and all-NaN if a required column is missing.
    """
    if user is None or user.empty or config is None or config.empty:
        return pd.Series(dtype=float)
    need_user = {"Longitude", "Latitude", COL_SERV}
    need_cfg = {
        "PCI",
        "Longitude",
        "Latitude",
        "Mechanical Azimuth",
        "Mechanical Downtilt",
        "Height",
    }
    if not need_user.issubset(user.columns) or not need_cfg.issubset(config.columns):
        return pd.Series([np.nan] * len(user), index=user.index)
    cfg = config.copy()
    cfg["PCI"] = pd.to_numeric(cfg["PCI"], errors="coerce")
    cfg = cfg.set_index("PCI", drop=False)
    flags = []
    for _, r in user.iterrows():
        try:
            pci = int(float(r[COL_SERV]))
            if pci not in cfg.index:
                flags.append(np.nan)
                continue
            # NOTE(review): duplicated PCIs make cfg.loc return several rows;
            # float() then raises and the row is flagged NaN.
            c = cfg.loc[pci]
            ulon, ulat = float(r["Longitude"]), float(r["Latitude"])
            clon, clat = float(c["Longitude"]), float(c["Latitude"])
            az = float(c["Mechanical Azimuth"])
            downtilt = float(c["Mechanical Downtilt"])
            height = float(c["Height"])
            hdiff = angle_delta_deg(bearing_deg(clon, clat, ulon, ulat), az)
            dist = haversine_m(clon, clat, ulon, ulat)
            # Elevation of the UE seen from the antenna (negative = below the
            # horizon), minus the mechanical downtilt.
            # NOTE(review): extract_cell_geometry_stats measures
            # |downtilt - atan2(height, dist)|, whereas this expression equals
            # -(depression + downtilt), so the two angles add up here. Confirm
            # the intended sign before changing it: models trained with this
            # code presumably saw features computed this way.
            tilt_angle = math.degrees(math.atan2(-height, max(dist, 0.1)))
            final_tilt = tilt_angle - downtilt
            flags.append(1.0 if (abs(hdiff) > 50 or abs(final_tilt) > 6) else 0.0)
        except Exception:
            flags.append(np.nan)
    return pd.Series(flags, index=user.index)
_SIG_SERV_RSRP_RE = re.compile(r"ServCellRSRP:?\s*(-?\d+)")
_SIG_NCELL_RSRP_RE = re.compile(r"NCellRSRP:?\s*(-?\d+)")


def _scenario_feature_key(prefix: str, col: str) -> str:
    """Build ``prefix`` + a snake_case rendering of a raw column name."""
    return prefix + re.sub(r"[^a-zA-Z0-9]+", "_", col).strip("_").lower()


def _scenario_user_features(
    user: pd.DataFrame, config: pd.DataFrame
) -> Dict[str, float]:
    """Build UE-trace features.

    Covers KPI stats, ``bad_rows`` stats, serving-PCI spread, neighbour
    margins and main-lobe fractions.
    """
    feats: Dict[str, float] = {"n_user_rows": float(len(user))}
    for prefix, col in (
        ("thr", COL_THR),
        ("rsrp", COL_RSRP),
        ("sinr", COL_SINR),
        ("rb", COL_RB),
        ("mcs", COL_MCS),
        ("rank", COL_RANK),
        ("grant", COL_GRANT),
        ("ccefail", COL_CCE_FAIL),
        ("ibler", COL_IBLER),
        ("rbler", COL_RBLER),
    ):
        feats.update(stat_feats(prefix, numeric_series(user, col)))

    b = bad_rows(user)
    b_rsrp = numeric_series(b, COL_RSRP)
    b_sinr = numeric_series(b, COL_SINR)
    b_mcs = numeric_series(b, COL_MCS)
    b_ibler = numeric_series(b, COL_IBLER)
    feats["bad_frac"] = float(len(b) / max(1, len(user)))
    for prefix, ser in (
        ("bad_thr", numeric_series(b, COL_THR)),
        ("bad_rsrp", b_rsrp),
        ("bad_sinr", b_sinr),
        ("bad_rb", numeric_series(b, COL_RB)),
        ("bad_mcs", b_mcs),
        ("bad_rank", numeric_series(b, COL_RANK)),
        ("bad_ccefail", numeric_series(b, COL_CCE_FAIL)),
        ("bad_ibler", b_ibler),
        ("bad_rbler", numeric_series(b, COL_RBLER)),
    ):
        feats.update(stat_feats(prefix, ser))

    # Threshold fractions over the bad rows; NaN when the metric is all-missing.
    feats["bad_rsrp_lt_m95_frac"] = (
        float((b_rsrp < -95).mean()) if b_rsrp.notna().any() else np.nan
    )
    feats["bad_rsrp_lt_m90_frac"] = (
        float((b_rsrp < -90).mean()) if b_rsrp.notna().any() else np.nan
    )
    feats["bad_sinr_lt_0_frac"] = (
        float((b_sinr < 0).mean()) if b_sinr.notna().any() else np.nan
    )
    feats["bad_sinr_lt_5_frac"] = (
        float((b_sinr < 5).mean()) if b_sinr.notna().any() else np.nan
    )
    feats["good_rsrp_bad_sinr_frac"] = (
        float(((b_rsrp > -90) & (b_sinr < 5)).mean()) if len(b) else np.nan
    )
    feats["weak_rsrp_good_sinr_frac"] = (
        float(((b_rsrp < -95) & (b_sinr > 8)).mean()) if len(b) else np.nan
    )
    feats["low_mcs_frac"] = (
        float((b_mcs < 10).mean()) if b_mcs.notna().any() else np.nan
    )
    feats["high_bler_frac"] = (
        float((b_ibler > 15).mean()) if b_ibler.notna().any() else np.nan
    )

    if COL_SERV in user.columns:
        serv_all = pd.to_numeric(user[COL_SERV], errors="coerce")
        serv_bad = pd.to_numeric(
            b.get(COL_SERV, pd.Series(dtype=float)), errors="coerce"
        )
        feats["n_serving_pci_all"] = float(serv_all.nunique(dropna=True))
        feats["n_serving_pci_bad"] = float(serv_bad.nunique(dropna=True))
        feats["dominant_bad_serv_frac"] = (
            float(serv_bad.value_counts(normalize=True).iloc[0])
            if serv_bad.notna().any()
            else np.nan
        )

    # neighbor_col_pairs only depends on ``user``: evaluate it once (and only
    # when there are bad rows) rather than once per bad row. ``list`` guards
    # against it returning a one-shot iterator.
    neighbor_pairs = list(neighbor_col_pairs(user)) if len(b) else []
    stronger_counts, margins, neighbor_counts = [], [], []
    for _, r in b.iterrows():
        srv = safe_float(r.get(COL_RSRP))
        row_stronger, row_count, best = 0, 0, -999.0
        for _pci_col, rsrp_col in neighbor_pairs:
            nrsrp = safe_float(r.get(rsrp_col))
            if np.isnan(nrsrp):
                continue
            row_count += 1
            best = max(best, nrsrp)
            # A neighbour counts as "stronger" when it beats serving by > 2 dB.
            if not np.isnan(srv) and nrsrp > srv + 2:
                row_stronger += 1
        stronger_counts.append(row_stronger)
        neighbor_counts.append(row_count)
        margins.append(best - srv if row_count and not np.isnan(srv) else np.nan)
    feats["bad_neighbor_stronger_mean"] = (
        float(np.nanmean(stronger_counts)) if stronger_counts else np.nan
    )
    feats["bad_neighbor_count_mean"] = (
        float(np.nanmean(neighbor_counts)) if neighbor_counts else np.nan
    )
    feats["bad_best_neighbor_margin_mean"] = (
        float(np.nanmean(margins)) if margins else np.nan
    )
    feats["bad_best_neighbor_margin_max"] = (
        float(np.nanmax(margins)) if margins else np.nan
    )

    outside = estimate_mainlobe_flags(user, config)
    if len(outside):
        feats["outside_mainlobe_frac_all"] = float(
            pd.to_numeric(outside, errors="coerce").mean()
        )
        feats["outside_mainlobe_frac_bad"] = (
            float(pd.to_numeric(outside.loc[b.index], errors="coerce").mean())
            if len(b)
            else np.nan
        )
    return feats


def _scenario_config_features(config: pd.DataFrame) -> Dict[str, float]:
    """Build configuration-table statistics across all configured cells."""
    feats: Dict[str, float] = {"n_config_cells": float(len(config))}
    for col in (
        "Transmission Power",
        "Mechanical Downtilt",
        "Digital Tilt",
        "Height",
        "IntraFreqHoA3Offset [0.5dB]",
    ):
        if col in config.columns:
            key = _scenario_feature_key("cfg_", col)
            feats.update(stat_feats(key, numeric_series(config, col)))
    if "PdcchOccupiedSymbolNum" in config.columns:
        feats["cfg_pdcch_1sym_frac"] = float(
            config["PdcchOccupiedSymbolNum"]
            .astype(str)
            .str.lower()
            .str.contains("1sym")
            .mean()
        )
    return feats


def _scenario_signaling_features(signaling: pd.DataFrame) -> Dict[str, float]:
    """Build signaling event counts and ServCell/NCell RSRP margins from event text."""
    names = signaling.get("Event Name", pd.Series(dtype=object)).tolist()
    contents = signaling.get("Event Content", pd.Series(dtype=object)).tolist()
    event_names = " ".join("" if pd.isna(x) else str(x) for x in names).lower()
    event_content = " ".join("" if pd.isna(x) else str(x) for x in contents).lower()
    text = event_names + " " + event_content
    feats: Dict[str, float] = {"sig_rows": float(len(signaling))}
    for key, needle in (
        ("sig_a3_count", "eventa3"),
        ("sig_a2_count", "eventa2"),
        ("sig_ho_attempt_count", "handoverattempt"),
        ("sig_reest_count", "reestablish"),
        ("sig_ra_attempt_count", "randomaccessattempt"),
        ("sig_ra_success_count", "randomaccesssuc"),
    ):
        feats[key] = float(text.count(needle))
    sig_margins = []
    for content in contents:
        txt = "" if pd.isna(content) else str(content)
        sm = _SIG_SERV_RSRP_RE.search(txt)
        nm = _SIG_NCELL_RSRP_RE.search(txt)
        if sm and nm:
            sig_margins.append(float(nm.group(1)) - float(sm.group(1)))
    feats["sig_a3_margin_mean"] = (
        float(np.mean(sig_margins)) if sig_margins else np.nan
    )
    feats["sig_a3_margin_max"] = float(np.max(sig_margins)) if sig_margins else np.nan
    return feats


def _scenario_traffic_features(traffic: pd.DataFrame) -> Dict[str, float]:
    """Build cell-level traffic KPI statistics."""
    feats: Dict[str, float] = {"traffic_rows": float(len(traffic))}
    for col in (
        "Uplink PRB utilization(%)",
        "Downlink PRB utilization(%)",
        "Uplink PRB Interference(dBm)",
        "User Uplink Throughput(Mbps)",
        "User Downlink Throughput(Mbps)",
        "Downlink Weak Coversge Ratio",
        "TA>1KM Ratio",
        "Uplink CCE utilization(%)",
        "Downlink CCE utilization(%)",
        "Uplink CCE Allocation Success Rate(%)",
        "Downlink CCE Allocation Success Rate(%)",
    ):
        if col in traffic.columns:
            key = _scenario_feature_key("traffic_", col)
            feats.update(stat_feats(key, numeric_series(traffic, col)))
    return feats


def _scenario_mr_features(mr: pd.DataFrame) -> Dict[str, float]:
    """Build measurement-report stats and best-neighbour margins (top 3 neighbours)."""
    feats: Dict[str, float] = {"mr_rows": float(len(mr))}
    for col in (
        "Serving RSRP(dBm)",
        "Throughput(Mbps)",
        "Neighbor 1 RSRP(dBm)",
        "Neighbor 2 RSRP(dBm)",
        "Neighbor 3 RSRP(dBm)",
    ):
        if col in mr.columns:
            key = _scenario_feature_key("mr_", col)
            feats.update(stat_feats(key, numeric_series(mr, col)))
    neigh_margins = []
    for _, r in mr.iterrows():
        sr = safe_float(r.get("Serving RSRP(dBm)"))
        best, found = -999.0, False
        for i in range(1, 4):
            nr = safe_float(r.get(f"Neighbor {i} RSRP(dBm)"))
            if not np.isnan(nr):
                best = max(best, nr)
                found = True
        if found and not np.isnan(sr):
            neigh_margins.append(best - sr)
    feats["mr_best_neighbor_margin_mean"] = (
        float(np.mean(neigh_margins)) if neigh_margins else np.nan
    )
    feats["mr_neighbor_stronger_frac"] = (
        float(np.mean([m > 2 for m in neigh_margins])) if neigh_margins else np.nan
    )
    return feats


def extract_scenario_features(s: Dict[str, Any]) -> Dict[str, float]:
    """Build the scenario-level feature dict used by the template classifier.

    Features are grouped by source table:

    - always: ``is_multi`` and ``num_base_stations``;
    - when UE-trace rows exist: trace KPIs, bad-row and neighbour features;
    - when the other tables are present and non-empty: config, signaling,
      traffic and MR features.

    Missing signaling/traffic/MR tables are encoded as zero row counts. A
    missing UE trace short-circuits with ``missing_user = 1.0``.

    Args:
        s: raw scenario record; its tables are obtained via ``get_dataframes``.

    Returns:
        The feature dict after ``clean_features``.
    """
    user, config, signaling, traffic, mr = get_dataframes(s)
    feats: Dict[str, float] = {}
    feats["is_multi"] = float(is_multi_task(s))
    # ``or {}`` also tolerates an explicit ``"context": null`` in the record.
    ctx = (s.get("context") or {}).get("wireless_network_information", {}) or {}
    feats["num_base_stations"] = safe_float(ctx.get("num_base_stations"), np.nan)
    if user is None or user.empty:
        feats["missing_user"] = 1.0
        return clean_features(feats)
    feats["missing_user"] = 0.0
    feats.update(_scenario_user_features(user, config))
    if config is not None and not config.empty:
        feats.update(_scenario_config_features(config))
    if signaling is not None and not signaling.empty:
        feats.update(_scenario_signaling_features(signaling))
    else:
        feats["sig_rows"] = 0.0
    if traffic is not None and not traffic.empty:
        feats.update(_scenario_traffic_features(traffic))
    else:
        feats["traffic_rows"] = 0.0
    if mr is not None and not mr.empty:
        feats.update(_scenario_mr_features(mr))
    else:
        feats["mr_rows"] = 0.0
    return clean_features(feats)
# =============================================================================
# Generic target-cell features for selector model
# =============================================================================
def extract_cell_stats(s: Dict[str, Any]) -> Dict[str, Dict[str, float]]:
    """Aggregate per-cell evidence for the option selector.

    Cells are keyed as ``"<gNodeB ID>_<Cell ID>"``, the format built from the
    config and traffic rows below. PCIs in the UE trace and MR rows are mapped
    to that key via the ``p2g`` map returned by ``build_cell_maps(config)``.
    Per cell the result may contain:

    - ``bad_serv_*`` / ``bad_nei_*``: counts, sums and means over the
      ``bad_rows`` subset of the UE trace where the cell served or appeared as
      a neighbour, plus ``bad_serv_frac`` / ``bad_nei_frac``.
    - ``cfg_*``: parameters from the config table.
    - ``traffic_*``: KPIs from the traffic table.
    - ``mr_*``: counts, sums and means from measurement reports.

    Each ``*_mean`` is divided by the number of samples in which that metric
    was present, not by the total row count. Missing values therefore neither
    crash the aggregation nor drag the mean towards 0 (for example, a missing
    RSRP no longer counts as 0 dBm). A mean is omitted when no sample had a
    value.

    Args:
        s: raw scenario record; tables are obtained via ``get_dataframes``.

    Returns:
        Mapping of cell key to a feature dict (after ``clean_features``).
    """
    user, config, signaling, traffic, mr = get_dataframes(s)
    _, p2g = build_cell_maps(config)
    stats: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float))
    # Samples that actually contributed to each "<metric>_sum", keyed by
    # (cell key, sum name). Kept outside ``stats`` so no new features leak out.
    valid_n: Dict[Tuple[str, str], float] = defaultdict(float)

    def add_sample(cell_key: str, sum_name: str, value: float) -> None:
        """Accumulate ``value`` into a per-cell sum, counting it only if present."""
        # The sum key is created even for a missing value, so "<metric>_sum"
        # stays present with 0.0 exactly as before.
        stats[cell_key][sum_name] += 0.0 if np.isnan(value) else value
        if not np.isnan(value):
            valid_n[(cell_key, sum_name)] += 1.0

    if user is not None and not user.empty:
        b = bad_rows(user)
        # Depends only on ``user``: compute once instead of once per bad row.
        neighbor_pairs = list(neighbor_col_pairs(user)) if len(b) else []
        for _, r in b.iterrows():
            spci = safe_float(r.get(COL_SERV))
            srv_rsrp = safe_float(r.get(COL_RSRP))
            serv_key = p2g.get(int(spci)) if not np.isnan(spci) else None
            if serv_key:
                stats[serv_key]["bad_serv_count"] += 1.0
                for sum_name, col in (
                    ("bad_serv_rsrp_sum", COL_RSRP),
                    ("bad_serv_sinr_sum", COL_SINR),
                    ("bad_serv_thr_sum", COL_THR),
                    ("bad_serv_mcs_sum", COL_MCS),
                    ("bad_serv_bler_sum", COL_IBLER),
                ):
                    add_sample(serv_key, sum_name, safe_float(r.get(col)))
            for pci_col, rsrp_col in neighbor_pairs:
                npci = safe_float(r.get(pci_col))
                if np.isnan(npci):
                    continue
                nei_key = p2g.get(int(npci))
                if not nei_key:
                    continue
                cell = stats[nei_key]
                cell["bad_nei_count"] += 1.0
                nrsrp = safe_float(r.get(rsrp_col))
                if np.isnan(nrsrp):
                    continue
                cell["bad_nei_rsrp_sum"] += nrsrp
                valid_n[(nei_key, "bad_nei_rsrp_sum")] += 1.0
                if np.isnan(srv_rsrp):
                    continue
                cell["bad_nei_margin_sum"] += nrsrp - srv_rsrp
                valid_n[(nei_key, "bad_nei_margin_sum")] += 1.0
                # "Stronger" = neighbour beats the serving cell by more than 2 dB.
                if nrsrp > srv_rsrp + 2:
                    cell["bad_nei_stronger_count"] += 1.0
        n_bad = max(1, len(b))
        for key in list(stats.keys()):
            stats[key]["bad_serv_frac"] = stats[key].get("bad_serv_count", 0.0) / n_bad
            stats[key]["bad_nei_frac"] = stats[key].get("bad_nei_count", 0.0) / n_bad

    if config is not None and not config.empty:
        cfg_fields = (
            ("cfg_power", "Transmission Power"),
            ("cfg_downtilt", "Mechanical Downtilt"),
            ("cfg_digital_tilt", "Digital Tilt"),
            ("cfg_height", "Height"),
            ("cfg_a3_offset", "IntraFreqHoA3Offset [0.5dB]"),
            ("cfg_a2_threshold", "CovInterFreqA2RsrpThld [dBm]"),
            ("cfg_a5_threshold_1", "CovInterFreqA5RsrpThld1 [dBm]"),
            ("cfg_a5_threshold_2", "CovInterFreqA5RsrpThld2 [dBm]"),
            ("cfg_max_power", "Max Transmit Power"),
        )
        for _, r in config.iterrows():
            try:
                key = f"{int(float(r.get('gNodeB ID')))}_{int(float(r.get('Cell ID')))}"
                cell = stats[key]
                for name, col in cfg_fields:
                    cell[name] = safe_float(r.get(col), np.nan)
                cell["cfg_pdcch_symbols"] = parse_pdcch_symbols(
                    r.get("PdcchOccupiedSymbolNum", "")
                )
                cell["cfg_pdcch_1sym"] = (
                    1.0 if "1sym" in norm(r.get("PdcchOccupiedSymbolNum", "")) else 0.0
                )
            except Exception:
                # Best effort: rows without usable gNodeB/Cell ids are skipped.
                continue

    if traffic is not None and not traffic.empty:
        traffic_fields = (
            ("traffic_dl_prb", "Downlink PRB utilization(%)"),
            ("traffic_ul_prb", "Uplink PRB utilization(%)"),
            ("traffic_ul_interf", "Uplink PRB Interference(dBm)"),
            ("traffic_dl_tp", "User Downlink Throughput(Mbps)"),
            ("traffic_weak_cov", "Downlink Weak Coversge Ratio"),
            ("traffic_ta_gt_1km", "TA>1KM Ratio"),
            ("traffic_cce_dl", "Downlink CCE utilization(%)"),
            ("traffic_cce_succ_dl", "Downlink CCE Allocation Success Rate(%)"),
        )
        for _, r in traffic.iterrows():
            try:
                key = f"{int(float(r.get('gNodeB_ID')))}_{int(float(r.get('Cell_ID')))}"
                cell = stats[key]
                for name, col in traffic_fields:
                    cell[name] = safe_float(r.get(col), np.nan)
            except Exception:
                continue

    if mr is not None and not mr.empty:
        for _, r in mr.iterrows():
            spci = safe_float(r.get("Serving PCI"))
            serv_key = p2g.get(int(spci)) if not np.isnan(spci) else None
            if serv_key:
                stats[serv_key]["mr_serv_count"] += 1.0
                add_sample(
                    serv_key, "mr_serv_rsrp_sum", safe_float(r.get("Serving RSRP(dBm)"))
                )
                add_sample(serv_key, "mr_thr_sum", safe_float(r.get("Throughput(Mbps)")))
            for i in range(1, 4):
                npci = safe_float(r.get(f"Neighbor {i} PCI"))
                if np.isnan(npci):
                    continue
                nei_key = p2g.get(int(npci))
                if nei_key:
                    stats[nei_key]["mr_nei_count"] += 1.0
                    add_sample(
                        nei_key,
                        "mr_nei_rsrp_sum",
                        safe_float(r.get(f"Neighbor {i} RSRP(dBm)")),
                    )

    mean_specs = (
        ("bad_serv_rsrp_sum", "bad_serv_rsrp_mean"),
        ("bad_serv_sinr_sum", "bad_serv_sinr_mean"),
        ("bad_serv_thr_sum", "bad_serv_thr_mean"),
        ("bad_serv_mcs_sum", "bad_serv_mcs_mean"),
        ("bad_serv_bler_sum", "bad_serv_bler_mean"),
        ("bad_nei_rsrp_sum", "bad_nei_rsrp_mean"),
        ("bad_nei_margin_sum", "bad_nei_margin_mean"),
        ("mr_serv_rsrp_sum", "mr_serv_rsrp_mean"),
        ("mr_thr_sum", "mr_thr_mean"),
        ("mr_nei_rsrp_sum", "mr_nei_rsrp_mean"),
    )
    out: Dict[str, Dict[str, float]] = {}
    for key, d in stats.items():
        dd = dict(d)
        for sum_name, mean_name in mean_specs:
            n = valid_n.get((key, sum_name), 0.0)
            if n:
                # n > 0 guarantees the sum key was created by add_sample / above.
                dd[mean_name] = dd[sum_name] / n
        out[key] = clean_features(dd)
    return out
def option_feature_dict(
    s: Dict[str, Any],
    cid: str,
    label: str,
    template: str,
    action: str,
    context: Optional[Dict[str, Any]] = None,
) -> Dict[str, float]:
    """Build the selector-model feature dict for one (option, template) pair.

    The dict combines:

    - every scenario feature re-keyed with a ``scen_`` prefix;
    - one-hot ``action_<action>`` / ``template_<template>`` indicators and how
      often ``action`` occurs in ``template``;
    - parsed option semantics (amount, towards-UE flag, PDCCH symbols,
      A2/A5 threshold types);
    - ``target_*`` statistics for the cell returned by
      ``extract_gnb_cell(label)``, plus "what-if" values describing the
      parameter after the option is applied;
    - pair aggregates when ``label`` names two or more cells;
    - the numeric part of ``cid`` as a weak positional feature.

    Args:
        s: raw scenario record.
        cid: option id; ``cid[1:]`` is parsed as an integer when possible,
            otherwise ``cid_num`` is 0.0.
        label: option text that cells and semantics are parsed from.
        template: candidate answer template.
        action: action parsed from ``label``.
        context: optional precomputed values (as built by
            ``build_prediction_context``). Each missing entry is recomputed
            from ``s``.

    Returns:
        The feature dict after ``clean_features``.
    """
    feats = {}
    context = context or {}
    scen = context.get("scenario_features")
    if scen is None:
        scen = extract_scenario_features(s)
    # Prefix scenario features to share with selector.
    for k, v in scen.items():
        feats[f"scen_{k}"] = v
    is_multi = context.get("is_multi")
    if is_multi is None:
        is_multi = is_multi_task(s)
    feats["is_multi"] = float(is_multi)
    feats[f"action_{action}"] = 1.0
    feats[f"template_{template}"] = 1.0
    template_actions = template_to_actions(template)
    template_action_counts = Counter(template_actions)
    # How well this option's action fits the candidate template.
    feats["action_in_template"] = float(action in template_action_counts)
    feats["template_action_count"] = float(len(template_actions))
    feats["template_this_action_count"] = float(template_action_counts.get(action, 0))
    semantics = (context.get("option_semantics") or {}).get(cid)
    if semantics is None:
        semantics = parse_option_semantics(label)
    amount = safe_float(semantics.get("amount"), np.nan)
    feats["option_has_amount"] = float(not np.isnan(amount))
    feats["option_amount"] = amount
    feats["option_amount_abs"] = abs(amount) if not np.isnan(amount) else np.nan
    feats["option_towards_ue"] = safe_float(semantics.get("towards_ue"), 0.0)
    pdcch_symbols = safe_float(semantics.get("pdcch_symbols"), np.nan)
    feats["option_pdcch_symbols"] = pdcch_symbols
    threshold_types = set(semantics.get("threshold_types") or [])
    feats["option_threshold_a2"] = float("a2" in threshold_types)
    feats["option_threshold_a5_1"] = float("a5_1" in threshold_types)
    feats["option_threshold_a5_2"] = float("a5_2" in threshold_types)
    feats["option_threshold_combined"] = float(len(threshold_types) > 1)
    target = extract_gnb_cell(label)
    targets = extract_all_gnb_cells(label)
    feats["has_target_cell"] = float(target is not None)
    feats["num_cells_in_label"] = float(len(targets))
    roles = context.get("cell_stats")
    if roles is None:
        roles = extract_cell_stats(s)
    geom_roles = context.get("cell_geometry_stats")
    if geom_roles is None:
        geom_roles = extract_cell_geometry_stats(s)
    # Target-cell features are only emitted when the target has cell stats.
    # Geometry keys are copied after the role keys, so they win on a name clash.
    if target and target in roles:
        for k, v in roles[target].items():
            feats[f"target_{k}"] = v
        for k, v in geom_roles.get(target, {}).items():
            feats[f"target_{k}"] = v
        target_stats = roles[target]
        geom_stats = geom_roles.get(target, {})
        serv_rsrp = safe_float(target_stats.get("bad_serv_rsrp_mean"), np.nan)
        nei_margin = safe_float(target_stats.get("bad_nei_margin_mean"), np.nan)
        ul_interf = safe_float(target_stats.get("traffic_ul_interf"), np.nan)
        # Distances to fixed reference levels: -95 (serving RSRP), 2 (neighbour
        # margin) and -115 (UL interference), as hard-coded below.
        feats["target_rsrp_deficit_to_m95"] = (
            -95.0 - serv_rsrp if not np.isnan(serv_rsrp) else np.nan
        )
        feats["target_neighbor_excess_margin"] = (
            nei_margin - 2.0 if not np.isnan(nei_margin) else np.nan
        )
        feats["target_ul_interference_over_m115"] = (
            ul_interf + 115.0 if not np.isnan(ul_interf) else np.nan
        )
        # What-if features: the target's parameter after applying this option.
        if not np.isnan(amount):
            if action in {"inc_power", "dec_power"}:
                current = safe_float(target_stats.get("cfg_power"), np.nan)
                delta = amount if action == "inc_power" else -amount
                proposed = current + delta if not np.isnan(current) else np.nan
                feats["target_proposed_power"] = proposed
                max_power = safe_float(target_stats.get("cfg_max_power"), np.nan)
                feats["target_power_margin_after"] = (
                    max_power - proposed
                    if not np.isnan(max_power) and not np.isnan(proposed)
                    else np.nan
                )
            elif action in {"tilt_down", "tilt_up"}:
                current = safe_float(target_stats.get("cfg_downtilt"), np.nan)
                delta = amount if action == "tilt_down" else -amount
                proposed = current + delta if not np.isnan(current) else np.nan
                feats["target_proposed_downtilt"] = proposed
                desired = safe_float(geom_stats.get("geom_desired_downtilt_p50"), np.nan)
                # Positive improvement = proposed tilt is closer to the geometric target.
                if not np.isnan(proposed) and not np.isnan(desired):
                    current_error = abs(current - desired)
                    proposed_error = abs(proposed - desired)
                    feats["target_proposed_tilt_error"] = proposed_error
                    feats["target_tilt_error_improvement"] = (
                        current_error - proposed_error
                    )
            elif action == "azimuth":
                required = safe_float(geom_stats.get("geom_azimuth_error_p50"), np.nan)
                if not np.isnan(required):
                    feats["target_required_azimuth_change"] = required
                    feats["target_azimuth_change_abs_error"] = abs(amount - required)
            elif action in {"inc_a3", "dec_a3"}:
                # cfg_a3_offset comes from "IntraFreqHoA3Offset [0.5dB]"; halve to get dB.
                current_half_db = safe_float(target_stats.get("cfg_a3_offset"), np.nan)
                current_db = (
                    current_half_db / 2.0 if not np.isnan(current_half_db) else np.nan
                )
                delta = amount if action == "inc_a3" else -amount
                feats["target_current_a3_offset_db"] = current_db
                feats["target_proposed_a3_offset_db"] = (
                    current_db + delta if not np.isnan(current_db) else np.nan
                )
            elif action == "a2a5":
                # Each requested threshold is lowered by ``amount``.
                if "a2" in threshold_types:
                    current = safe_float(target_stats.get("cfg_a2_threshold"), np.nan)
                    feats["target_proposed_a2_threshold"] = (
                        current - amount if not np.isnan(current) else np.nan
                    )
                if "a5_1" in threshold_types:
                    current = safe_float(target_stats.get("cfg_a5_threshold_1"), np.nan)
                    feats["target_proposed_a5_threshold_1"] = (
                        current - amount if not np.isnan(current) else np.nan
                    )
                if "a5_2" in threshold_types:
                    current = safe_float(target_stats.get("cfg_a5_threshold_2"), np.nan)
                    feats["target_proposed_a5_threshold_2"] = (
                        current - amount if not np.isnan(current) else np.nan
                    )
            elif action == "pdcch":
                current = safe_float(target_stats.get("cfg_pdcch_symbols"), np.nan)
                feats["target_current_pdcch_symbols"] = current
                feats["target_pdcch_symbol_delta"] = (
                    pdcch_symbols - current
                    if not np.isnan(pdcch_symbols) and not np.isnan(current)
                    else np.nan
                )
    else:
        feats["target_missing"] = 1.0
    # Neighbor pair options: aggregate both cells.
    if len(targets) >= 2:
        vals = []
        for key in targets:
            r = roles.get(key, {})
            vals.append(r.get("bad_serv_frac", 0.0) + r.get("bad_nei_frac", 0.0))
        feats["pair_role_sum"] = float(np.sum(vals))
        feats["pair_role_max"] = float(np.max(vals)) if vals else 0.0
    # Option id as weak positional feature.
    try:
        feats["cid_num"] = float(int(cid[1:]))
    except Exception:
        feats["cid_num"] = 0.0
    return clean_features(feats)
def build_prediction_context(s: Dict[str, Any]) -> Dict[str, Any]:
    """Precompute the per-scenario values reused while scoring options.

    The returned dict holds:

    - ``scenario_features``: scenario-level features;
    - ``cell_stats`` and ``cell_geometry_stats``: per-cell statistics;
    - ``is_multi``: the task type;
    - ``options``: the raw options mapping;
    - ``option_actions`` and ``option_semantics``: the parsed action and
      semantics of every option.

    Functions such as ``option_feature_dict`` read these keys from their
    ``context`` argument instead of recomputing them.
    """
    options = get_options(s)
    context: Dict[str, Any] = {}
    context["scenario_features"] = extract_scenario_features(s)
    context["cell_stats"] = extract_cell_stats(s)
    context["cell_geometry_stats"] = extract_cell_geometry_stats(s)
    context["is_multi"] = is_multi_task(s)
    context["options"] = options
    context["option_actions"] = {
        option_id: label_to_action(text) for option_id, text in options.items()
    }
    context["option_semantics"] = {
        option_id: parse_option_semantics(text) for option_id, text in options.items()
    }
    return context
def predict_template_probs(
    model: Pipeline, s: Dict[str, Any], context: Optional[Dict[str, Any]] = None
) -> List[Tuple[str, float]]:
    """Return ``(template, probability)`` pairs sorted by probability, highest first.

    Scenario features are taken from ``context["scenario_features"]`` when
    available, otherwise they are extracted from ``s``. Class labels come from
    the pipeline step named ``"clf"`` and are converted to ``str``.
    """
    ctx = context or {}
    features = ctx.get("scenario_features")
    if features is None:
        features = extract_scenario_features(s)
    estimator = model.named_steps["clf"]
    probabilities = model.predict_proba([features])[0]
    return sorted(
        ((str(label), float(prob)) for label, prob in zip(estimator.classes_, probabilities)),
        key=lambda pair: pair[1],
        reverse=True,
    )
def selector_score(model: Pipeline, feats: Dict[str, float]) -> float:
    """Score one option feature dict with the selector model.

    This is a single-row wrapper around ``selector_scores``. It returns the
    probability of class ``1``, or 0.0 when the model has no class ``1`` or
    prediction fails for any reason.
    """
    batch = selector_scores(model, [feats])
    return batch[0] if batch else 0.0
def selector_scores(model: Pipeline, features: List[Dict[str, float]]) -> List[float]:
    """Score a batch of option feature dicts in a single ``predict_proba`` call.

    Returns the class-``1`` probability for every input row. Falls back to a
    list of 0.0 (one per row) when the ``"clf"`` step has no class ``1`` or
    when prediction raises. An empty input yields an empty list.
    """
    if not features:
        return []
    fallback = [0.0] * len(features)
    try:
        matrix = model.predict_proba(features)
        class_list = list(model.named_steps["clf"].classes_)
        if 1 not in class_list:
            return fallback
        positive_col = class_list.index(1)
        return [float(prob_row[positive_col]) for prob_row in matrix]
    except Exception:
        return fallback
def jsonable_number(value: Any, digits: int = 3) -> Optional[float]:
    """Coerce ``value`` to a rounded float that is safe to emit as strict JSON.

    Args:
        value: anything ``safe_float`` accepts.
        digits: number of decimal places passed to ``round``.

    Returns:
        The rounded float, or None when ``value`` is unparseable or
        non-finite. NaN, +inf and -inf are all rejected: ``json.dumps`` would
        otherwise write the non-standard tokens ``NaN`` / ``Infinity``, which
        strict JSON parsers refuse.
    """
    x = safe_float(value, np.nan)
    if not np.isfinite(x):
        return None
    return round(float(x), digits)
# Cell statistics surfaced in option evidence, in output order.
_COMPACT_EVIDENCE_KEYS: Tuple[str, ...] = (
    "bad_serv_frac",
    "bad_nei_frac",
    "bad_serv_rsrp_mean",
    "bad_serv_sinr_mean",
    "bad_serv_thr_mean",
    "bad_nei_margin_mean",
    "mr_serv_count",
    "mr_nei_count",
    "mr_thr_mean",
    "cfg_power",
    "cfg_max_power",
    "cfg_downtilt",
    "cfg_a3_offset",
    "cfg_a2_threshold",
    "cfg_a5_threshold_1",
    "cfg_a5_threshold_2",
    "cfg_pdcch_symbols",
    "traffic_dl_prb",
    "traffic_weak_cov",
    "traffic_cce_dl",
    "traffic_cce_succ_dl",
    "geom_related_bad_count",
    "geom_distance_mean",
    "geom_azimuth_error_p50",
    "geom_desired_downtilt_p50",
    "geom_current_tilt_error_p50",
)


def compact_cell_evidence(stats: Dict[str, float]) -> Dict[str, Optional[float]]:
    """Pick a fixed subset of cell statistics and make them JSON-friendly.

    Only keys listed in ``_COMPACT_EVIDENCE_KEYS`` that are present in
    ``stats`` are considered. A value is kept only when ``jsonable_number``
    returns a number for it.
    """
    evidence: Dict[str, Optional[float]] = {}
    for name in _COMPACT_EVIDENCE_KEYS:
        if name not in stats:
            continue
        number = jsonable_number(stats[name])
        if number is not None:
            evidence[name] = number
    return evidence
def proposed_change_summary(
    action: str, semantics: Dict[str, Any], target_stats: Dict[str, float]
) -> Dict[str, Optional[float]]:
    """Describe the parameter change an option would make on its target cell.

    Args:
        action: parsed option action. Handled values are ``inc_power``,
            ``dec_power``, ``tilt_down``, ``tilt_up``, ``azimuth``,
            ``inc_a3``, ``dec_a3``, ``a2a5`` and ``pdcch``.
        semantics: parsed option semantics. ``amount`` is required;
            ``threshold_types`` / ``pdcch_symbols`` feed the ``a2a5`` /
            ``pdcch`` actions.
        target_stats: merged cell and geometry statistics of the target cell.

    Returns:
        JSON-friendly current/proposed values for the action, with
        uncomputable entries dropped. The dict is empty when the option has no
        numeric amount or the action is not handled.
    """
    amount = safe_float(semantics.get("amount"), np.nan)
    if np.isnan(amount):
        return {}

    def stat(name: str) -> float:
        return safe_float(target_stats.get(name), np.nan)

    def known(*values: float) -> bool:
        return not any(np.isnan(v) for v in values)

    summary: Dict[str, Optional[float]] = {}
    if action in {"inc_power", "dec_power"}:
        current = stat("cfg_power")
        ceiling = stat("cfg_max_power")
        step = amount if action == "inc_power" else -amount
        proposed = current + step if known(current) else np.nan
        summary["current_power"] = jsonable_number(current)
        summary["proposed_power"] = jsonable_number(proposed)
        summary["power_margin_after"] = (
            jsonable_number(ceiling - proposed) if known(ceiling, proposed) else None
        )
    elif action in {"tilt_down", "tilt_up"}:
        current = stat("cfg_downtilt")
        step = amount if action == "tilt_down" else -amount
        desired = stat("geom_desired_downtilt_p50")
        proposed = current + step if known(current) else np.nan
        summary["current_downtilt"] = jsonable_number(current)
        summary["proposed_downtilt"] = jsonable_number(proposed)
        summary["desired_downtilt_p50"] = jsonable_number(desired)
        if known(current, proposed, desired):
            # Positive = the proposed tilt lands closer to the geometric target.
            summary["tilt_error_improvement"] = jsonable_number(
                abs(current - desired) - abs(proposed - desired)
            )
    elif action == "azimuth":
        required = stat("geom_azimuth_error_p50")
        summary["required_azimuth_change_p50"] = jsonable_number(required)
        summary["azimuth_change_abs_error"] = (
            jsonable_number(abs(amount - required)) if known(required) else None
        )
    elif action in {"inc_a3", "dec_a3"}:
        half_db = stat("cfg_a3_offset")
        current_db = half_db / 2.0 if known(half_db) else np.nan
        step = amount if action == "inc_a3" else -amount
        summary["current_a3_offset_db"] = jsonable_number(current_db)
        summary["proposed_a3_offset_db"] = (
            jsonable_number(current_db + step) if known(current_db) else None
        )
    elif action == "a2a5":
        requested = set(semantics.get("threshold_types") or [])
        for threshold, cfg_name, current_name, proposed_name in (
            ("a2", "cfg_a2_threshold", "current_a2_threshold", "proposed_a2_threshold"),
            (
                "a5_1",
                "cfg_a5_threshold_1",
                "current_a5_threshold_1",
                "proposed_a5_threshold_1",
            ),
            (
                "a5_2",
                "cfg_a5_threshold_2",
                "current_a5_threshold_2",
                "proposed_a5_threshold_2",
            ),
        ):
            if threshold not in requested:
                continue
            current = stat(cfg_name)
            summary[current_name] = jsonable_number(current)
            summary[proposed_name] = (
                jsonable_number(current - amount) if known(current) else None
            )
    elif action == "pdcch":
        current = stat("cfg_pdcch_symbols")
        proposed = safe_float(semantics.get("pdcch_symbols"), np.nan)
        summary["current_pdcch_symbols"] = jsonable_number(current)
        summary["proposed_pdcch_symbols"] = jsonable_number(proposed)
        summary["pdcch_symbol_delta"] = (
            jsonable_number(proposed - current) if known(proposed, current) else None
        )
    return {name: value for name, value in summary.items() if value is not None}
def build_option_evidence(
    selector_model: Pipeline,
    s: Dict[str, Any],
    context: Dict[str, Any],
    top_templates: List[Tuple[str, float]],
    ml_labels: List[str],
) -> Dict[str, Any]:
    """Summarise per-option evidence for the candidate options of one scenario.

    Scoring works as follows:

    - every option whose action is not ``"other"`` is scored by the selector
      once for each template in ``top_templates``, in a single batch;
    - an option's ``selector_score`` is its best score across templates;
    - ``score_detail`` keeps up to three of its best (template, action)
      scores.

    Each row also carries parsed semantics, compact target-cell evidence and a
    proposed-change summary. Options sharing an action and the same cell list
    are reported in ``ambiguity_groups``, together with the fields that tell
    them apart.

    Args:
        selector_model: fitted option-selector pipeline.
        s: raw scenario record.
        context: output of ``build_prediction_context(s)``.
        top_templates: ``(template, probability)`` pairs; only the templates
            are used here.
        ml_labels: option ids chosen by the model; flagged on each row.

    Returns:
        Dict with ``task_type``, ``candidate_options`` (ordered by the numeric
        part of the option id; ids without one sort last) and
        ``ambiguity_groups``.
    """
    opts = context["options"]
    option_actions = context["option_actions"]
    option_semantics = context.get("option_semantics") or {}
    roles = context.get("cell_stats") or {}
    geom_roles = context.get("cell_geometry_stats") or {}
    ml_label_set = set(ml_labels)

    def option_order(item: Tuple[str, str]) -> Tuple[int, Any]:
        """Order ids like "C12" numerically; unparseable ids go last, by text."""
        option_id = item[0]
        try:
            return (0, int(option_id[1:]))
        except (TypeError, ValueError):
            return (1, str(option_id))

    # Collect each distinct (template, action, option) triple once, then score
    # them all in a single selector call.
    feature_rows: List[Dict[str, float]] = []
    feature_keys: List[Tuple[str, str, str]] = []
    seen_keys = set()
    for tpl, _ in top_templates:
        for cid, label in opts.items():
            action = option_actions.get(cid, "other")
            if action == "other":
                continue
            key = (tpl, action, cid)
            if key in seen_keys:
                continue
            seen_keys.add(key)
            feature_keys.append(key)
            feature_rows.append(option_feature_dict(s, cid, label, tpl, action, context))

    score_by_option: Dict[str, float] = defaultdict(float)
    score_detail: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
    scores = selector_scores(selector_model, feature_rows)
    for (tpl, action, cid), score in zip(feature_keys, scores):
        score_by_option[cid] = max(score_by_option.get(cid, 0.0), float(score))
        score_detail[cid].append(
            {
                "template": tpl,
                "action": action,
                "selector_score": round(float(score), 5),
            }
        )

    rows = []
    for cid, label in sorted(opts.items(), key=option_order):
        action = option_actions.get(cid, "other")
        semantics = option_semantics.get(cid) or parse_option_semantics(label)
        target = semantics.get("target_cell")
        target_stats = dict(roles.get(target, {}) if target else {})
        if target:
            target_stats.update(geom_roles.get(target, {}))
        row = {
            "id": cid,
            "label": label,
            "action": action,
            "target_cell": target,
            "cells": semantics.get("cells") or [],
            "amount": jsonable_number(semantics.get("amount")),
            "unit": semantics.get("unit") or "",
            "threshold_type": semantics.get("threshold_type") or "",
            "pdcch_symbols": jsonable_number(semantics.get("pdcch_symbols")),
            "towards_ue": bool(semantics.get("towards_ue")),
            "selector_score": round(float(score_by_option.get(cid, 0.0)), 5),
            "in_ml_prediction": cid in ml_label_set,
            "target_evidence": compact_cell_evidence(target_stats),
            "proposed_change": proposed_change_summary(action, semantics, target_stats),
        }
        details = score_detail.get(cid, [])
        if details:
            row["score_detail"] = sorted(
                details, key=lambda d: d["selector_score"], reverse=True
            )[:3]
        rows.append(row)

    # Options that apply the same action to the same cells differ only in
    # amount / threshold type / PDCCH symbols; report which of those vary.
    ambiguity_groups = []
    grouped: Dict[Tuple[str, Tuple[str, ...]], List[Dict[str, Any]]] = defaultdict(list)
    for row in rows:
        cells = tuple(row.get("cells") or [])
        if not cells or row["action"] in {"other", "server", "insufficient"}:
            continue
        grouped[(row["action"], cells)].append(row)
    for (action, cells), group_rows in grouped.items():
        if len(group_rows) < 2:
            continue
        distinguishes = [
            field
            for field in ("amount", "threshold_type", "pdcch_symbols")
            if len({r.get(field) for r in group_rows}) > 1
        ]
        ambiguity_groups.append(
            {
                "action": action,
                "cells": list(cells),
                "candidate_ids": [r["id"] for r in group_rows],
                "distinguishes": distinguishes,
                "contains_ml_prediction": any(
                    r["id"] in ml_label_set for r in group_rows
                ),
            }
        )
    return {
        "task_type": "multiple-answer" if context.get("is_multi") else "single-answer",
        "candidate_options": rows,
        "ambiguity_groups": ambiguity_groups,
    }
def option_diversity_key(row: Dict[str, Any]) -> Tuple[Any, ...]:
    """Return the near-duplicate key of a ranked option row.

    The key is ``(action, first cell as a 0/1-tuple, threshold type or "")``,
    taken from ``row`` and its ``semantics`` dict (treated as empty when
    missing or falsy).
    """
    sem = row.get("semantics") or {}
    first_cell = tuple(sem.get("cells") or [])[:1]
    threshold = sem.get("threshold_type") or ""
    return (row.get("action"), first_cell, threshold)
def select_ranked_options(
    ranked: List[Dict[str, Any]], template: str, is_multi: bool
) -> List[str]:
    """Pick option ids from scored rows.

    Rows are ordered by ``(score, cid)`` descending.

    Single-answer tasks get the top row only. Multi-answer tasks get options
    selected in three passes:

    1. Fill the template: for every action in
       ``template_to_actions(template)``, take the best rows with that
       action, as many as the action occurs in the template.
    2. Top up to the target count with the best remaining rows that do not
       repeat an ``option_diversity_key``. The target count is the number of
       template actions when that is 2-4, else 2.
    3. If still short, allow repeated keys.

    In passes 2-3, ``server`` / ``insufficient`` / ``other`` rows are skipped
    while there are more rows than the target count. At most 4 ids are
    returned.
    """
    ordered = sorted(
        ranked,
        key=lambda item: (float(item.get("score", 0.0)), item["cid"]),
        reverse=True,
    )
    if not ordered:
        return []
    if not is_multi:
        return [ordered[0]["cid"]]

    template_actions = template_to_actions(template)
    quota = len(template_actions) if 2 <= len(template_actions) <= 4 else 2
    chosen: List[str] = []
    chosen_ids = set()
    seen_keys = set()

    def take(item: Dict[str, Any], key: Tuple[Any, ...]) -> None:
        chosen.append(item["cid"])
        chosen_ids.add(item["cid"])
        seen_keys.add(key)

    for action, wanted in Counter(template_actions).items():
        matches = [
            item
            for item in ordered
            if item["action"] == action and item["cid"] not in chosen_ids
        ]
        for item in matches[:wanted]:
            take(item, option_diversity_key(item))

    low_value_actions = {"server", "insufficient", "other"}
    for allow_duplicate in (False, True):
        if len(chosen) >= quota:
            break
        for item in ordered:
            if len(chosen) >= quota:
                break
            if item["cid"] in chosen_ids:
                continue
            if item["action"] in low_value_actions and len(ordered) > quota:
                continue
            key = option_diversity_key(item)
            if key in seen_keys and not allow_duplicate:
                continue
            take(item, key)
    return chosen[:4]
def ranked_options_with_selector(
    selector: Pipeline,
    s: Dict[str, Any],
    template: str,
    context: Dict[str, Any],
) -> List[Dict[str, Any]]:
    """Score every non-"other" option of ``s`` for ``template`` with the selector.

    Options, actions and semantics are taken from ``context`` when present and
    derived from ``s`` / the option label otherwise. Each option is scored
    individually with ``selector_score``.

    Returns:
        Unsorted rows with keys ``cid``, ``action``, ``score`` and
        ``semantics``, in option order.
    """
    options = context.get("options") or get_options(s)
    actions = context.get("option_actions") or {
        option_id: label_to_action(text) for option_id, text in options.items()
    }
    semantics_by_id = context.get("option_semantics") or {}
    rows: List[Dict[str, Any]] = []
    for option_id, text in options.items():
        option_action = actions.get(option_id, "other")
        if option_action == "other":
            continue
        features = option_feature_dict(s, option_id, text, template, option_action, context)
        score = selector_score(selector, features)
        rows.append(
            {
                "cid": option_id,
                "action": option_action,
                "score": score,
                "semantics": semantics_by_id.get(option_id)
                or parse_option_semantics(text),
            }
        )
    return rows
def ranked_options_from_scores(
    context: Dict[str, Any],
    template: str,
    score_lookup: Dict[Tuple[str, str, str], float],
) -> List[Dict[str, Any]]:
    """Build ranked option rows from precomputed selector scores.

    Scores are looked up by ``(template, action, cid)``; missing triples score
    0.0. Options whose action is ``"other"`` (or unknown) are skipped.

    Returns:
        Rows with keys ``cid``, ``action``, ``score`` and ``semantics``, in
        option order.
    """
    options = context.get("options") or {}
    actions = context.get("option_actions") or {}
    semantics_by_id = context.get("option_semantics") or {}
    candidates = (
        (option_id, text, actions.get(option_id, "other"))
        for option_id, text in options.items()
    )
    return [
        {
            "cid": option_id,
            "action": option_action,
            "score": score_lookup.get((template, option_action, option_id), 0.0),
            "semantics": semantics_by_id.get(option_id) or parse_option_semantics(text),
        }
        for option_id, text, option_action in candidates
        if option_action != "other"
    ]
def selected_rank_score(labels: List[str], ranked: List[Dict[str, Any]]) -> float:
    """Score a selection as ``mean + 0.25 * min`` of its selector scores.

    Labels absent from ``ranked`` count as 0.0. An empty selection scores -1.0.
    """
    if not labels:
        return -1.0
    score_of: Dict[str, float] = {}
    for entry in ranked:
        score_of[entry["cid"]] = float(entry.get("score", 0.0))
    picked = np.array([score_of.get(option_id, 0.0) for option_id in labels])
    return float(picked.mean() + 0.25 * picked.min())
def choose_options_with_selector(
    selector: Pipeline,
    s: Dict[str, Any],
    template: str,
    context: Optional[Dict[str, Any]] = None,
) -> List[str]:
    """Rank options for ``template`` with the selector and pick the answer ids.

    The task type comes from ``context["is_multi"]`` when set (not None),
    otherwise from ``is_multi_task(s)``.
    """
    ctx = context or {}
    multi = ctx.get("is_multi")
    if multi is None:
        multi = is_multi_task(s)
    candidates = ranked_options_with_selector(selector, s, template, ctx)
    return select_ranked_options(candidates, template, bool(multi))
def choose_options_from_scores(
    s: Dict[str, Any],
    template: str,
    context: Dict[str, Any],
    score_lookup: Dict[Tuple[str, str, str], float],
) -> List[str]:
    """Pick answer ids for ``template`` from precomputed selector scores.

    The task type comes from ``context["is_multi"]`` when it is set (not
    None), otherwise from ``is_multi_task(s)``. This matches
    ``choose_options_with_selector``. ``is_multi_task`` is evaluated only
    when needed; a ``dict.get`` default would be computed on every call.
    """
    is_multi = context.get("is_multi")
    if is_multi is None:
        is_multi = is_multi_task(s)
    ranked = ranked_options_from_scores(context, template, score_lookup)
    return select_ranked_options(ranked, template, bool(is_multi))
def predict_labels(
    template_model: Pipeline, selector_model: Pipeline, s: Dict[str, Any]
) -> Tuple[List[str], Dict[str, Any]]:
    """Predict the answer option ids for one scenario.

    Each of the five most probable templates is tried in turn:

    - options are ranked by the selector for that template and a selection
      is made;
    - selections of the wrong size are discarded (exactly one id for
      single-answer tasks, 2-4 for multi-answer tasks);
    - the remaining selections are scored as selector rank score plus
      ``0.05 *`` template probability, and the highest wins (first on ties).

    If no template yields a valid selection, the most probable template's
    selection is returned with a rank score of -1.0.

    Returns:
        ``(labels, info)``, where ``info`` holds ``template``,
        ``template_prob``, ``selector_rank_score`` and ``top_templates``
        (the top five ``(template, probability)`` pairs).
    """
    context = build_prediction_context(s)
    template_probs = predict_template_probs(template_model, s, context)
    top_five = template_probs[:5]
    multi = context["is_multi"]
    shortlist = []
    for tpl, prob in top_five:
        ranked = ranked_options_with_selector(selector_model, s, tpl, context)
        labels = select_ranked_options(ranked, tpl, multi)
        if not labels:
            continue
        size_ok = 2 <= len(labels) <= 4 if multi else len(labels) == 1
        if not size_ok:
            continue
        rank_score = selected_rank_score(labels, ranked)
        shortlist.append((rank_score + 0.05 * prob, rank_score, prob, tpl, labels))
    if shortlist:
        best = sorted(shortlist, key=lambda entry: entry[0], reverse=True)[0]
        _, rank_score, prob, tpl, labels = best
    else:
        tpl, prob = template_probs[0]
        labels = choose_options_with_selector(selector_model, s, tpl, context)
        rank_score = -1.0
    return labels, {
        "template": tpl,
        "template_prob": prob,
        "selector_rank_score": rank_score,
        "top_templates": top_five,
    }
def predict_labels_batch(
    template_model: Pipeline,
    selector_model: Pipeline,
    scenarios: List[Dict[str, Any]],
    top_k_templates: int = 5,
) -> List[Tuple[List[str], Dict[str, Any]]]:
    """Batched counterpart of ``predict_labels``.

    Template probabilities for all scenarios come from one
    ``predict_proba`` call. Option rows for every
    (scenario, top-k template, actionable option) triple are then scored in a
    single ``selector_scores`` call. Candidate selection per scenario then
    follows the same rules as the single-scenario path.

    Args:
        template_model: Fitted template classifier. It must expose a ``clf``
            step with ``classes_``.
        selector_model: Fitted option selector.
        scenarios: Scenario records.
        top_k_templates: How many most-probable templates to try per
            scenario. ``None`` keeps all of them. Values below 1 are raised
            to 1, because the fallback path needs at least one template.

    Returns:
        One ``(labels, debug)`` pair per scenario, in input order.
    """
    if not scenarios:
        return []
    if top_k_templates is not None and top_k_templates < 1:
        # An empty top-k list would make the fallback ``probs[0]`` raise.
        top_k_templates = 1

    contexts = [build_prediction_context(s) for s in scenarios]
    scenario_features = [ctx["scenario_features"] for ctx in contexts]
    template_proba = template_model.predict_proba(scenario_features)
    template_classes = [str(cls) for cls in template_model.named_steps["clf"].classes_]

    all_probs: List[List[Tuple[str, float]]] = []
    feature_rows: List[Dict[str, float]] = []
    feature_keys: List[Tuple[int, str, str, str]] = []
    seen_keys = set()
    for i, (s, ctx, row) in enumerate(zip(scenarios, contexts, template_proba)):
        probs = [(cls, float(p)) for cls, p in zip(template_classes, row)]
        probs.sort(key=lambda z: z[1], reverse=True)
        top_probs = probs[:top_k_templates]
        all_probs.append(top_probs)
        opts = ctx["options"]
        option_actions = ctx["option_actions"]
        for tpl, _ in top_probs:
            for cid, label in opts.items():
                action = option_actions.get(cid, "other")
                if action == "other":
                    continue
                key = (i, tpl, action, cid)
                if key in seen_keys:
                    continue
                seen_keys.add(key)
                feature_keys.append(key)
                feature_rows.append(option_feature_dict(s, cid, label, tpl, action, ctx))

    # Skip the selector call when nothing is scoreable. Fitted feature
    # vectorizers may reject an empty batch (e.g. sklearn's DictVectorizer
    # raises on empty input).
    scores = selector_scores(selector_model, feature_rows) if feature_rows else []
    score_lookups: List[Dict[Tuple[str, str, str], float]] = [
        {} for _ in scenarios
    ]
    for (i, tpl, action, cid), score in zip(feature_keys, scores):
        score_lookups[i][(tpl, action, cid)] = score

    out: List[Tuple[List[str], Dict[str, Any]]] = []
    for s, ctx, probs, lookup in zip(scenarios, contexts, all_probs, score_lookups):
        is_multi = bool(ctx["is_multi"])
        candidates = []
        for tpl, p in probs:
            ranked = ranked_options_from_scores(ctx, tpl, lookup)
            labs = select_ranked_options(ranked, tpl, is_multi)
            if not labs:
                continue
            # Reject label sets whose size is impossible for this question type.
            if is_multi and not (2 <= len(labs) <= 4):
                continue
            if not is_multi and len(labs) != 1:
                continue
            rank_score = selected_rank_score(labs, ranked)
            candidates.append((rank_score + 0.05 * p, rank_score, p, tpl, labs))
        if candidates:
            candidates.sort(reverse=True, key=lambda x: x[0])
            _, rank_score, p, tpl, labs = candidates[0]
        else:
            tpl, p = probs[0]
            labs = choose_options_from_scores(s, tpl, ctx, lookup)
            rank_score = -1.0
        out.append(
            (
                labs,
                {
                    "template": tpl,
                    "template_prob": p,
                    "selector_rank_score": rank_score,
                    "top_templates": probs,
                },
            )
        )
    return out
def write_submission(path: str, rows: List[Dict[str, str]]):
    """Write ``rows`` to ``path`` as CSV with header ``ID,Track A,Track B``.

    Missing parent directories are created and an existing file is
    overwritten. Each row must be a dict keyed by the header names.
    """
    parent = os.path.dirname(path)
    os.makedirs(parent if parent else ".", exist_ok=True)
    header = ["ID", "Track A", "Track B"]
    # newline="" lets the csv module control line endings itself.
    with open(path, "w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=header)
        writer.writeheader()
        writer.writerows(rows)
def result_csv_path(path: str) -> str:
    """Return an output path whose file name is exactly ``result.csv``.

    An empty ``path`` maps to ``"result.csv"``. A path already ending in
    ``result.csv`` is returned unchanged. Any other file name is replaced by
    ``result.csv`` in the same directory (``.`` when there is none), and a
    notice is printed.

    The comparison is case-sensitive: ``Result.CSV`` is rewritten. On a
    case-sensitive filesystem it would otherwise produce a file that is not
    named ``result.csv``.
    """
    if not path:
        return "result.csv"
    directory, basename = os.path.split(path)
    if basename == "result.csv":
        return path
    fixed = os.path.join(directory or ".", "result.csv")
    print(f"Output filename must be result.csv; writing to {fixed} instead of {path}")
    return fixed
def write_debug(path: str, rows: List[Dict[str, Any]]):
    """Dump ``rows`` as indented JSON to ``path``; do nothing if ``path`` is empty.

    Missing parent directories are created. Non-ASCII characters are escaped,
    because ``json.dump`` defaults to ``ensure_ascii=True``.
    """
    if path:
        folder = os.path.dirname(path) or "."
        os.makedirs(folder, exist_ok=True)
        with open(path, "w", encoding="utf-8") as fh:
            json.dump(rows, fh, indent=2)
def write_json(path: str, rows: List[Dict[str, Any]]):
    """Dump ``rows`` as indented UTF-8 JSON to ``path``; no-op for an empty ``path``.

    Missing parent directories are created. Unlike ``write_debug``,
    non-ASCII text is written as-is (``ensure_ascii=False``).
    """
    if not path:
        return
    target_dir = os.path.dirname(path)
    os.makedirs(target_dir if target_dir else ".", exist_ok=True)
    with open(path, "w", encoding="utf-8") as out_file:
        json.dump(rows, out_file, ensure_ascii=False, indent=2)
def save_model_bundle(
    path: str,
    template_model: Pipeline,
    selector_model: Pipeline,
    metadata: Optional[Dict[str, Any]] = None,
):
    """Pickle both models plus metadata into a versioned bundle at ``path``.

    The pickled object is a dict with keys ``version``
    (``MODEL_BUNDLE_VERSION``), ``template_model``, ``selector_model`` and
    ``metadata``. A falsy ``metadata`` is stored as ``{}``. Missing parent
    directories are created, and the highest available pickle protocol is
    used.
    """
    parent = os.path.dirname(path)
    os.makedirs(parent or ".", exist_ok=True)
    payload = dict(
        version=MODEL_BUNDLE_VERSION,
        template_model=template_model,
        selector_model=selector_model,
        metadata=metadata if metadata else {},
    )
    with open(path, "wb") as fh:
        pickle.dump(payload, fh, protocol=pickle.HIGHEST_PROTOCOL)
def load_model_bundle(path: str) -> Tuple[Pipeline, Pipeline, Dict[str, Any]]:
    """Load a bundle written by ``save_model_bundle``.

    Security: ``pickle.load`` can execute arbitrary code while unpickling, so
    only load bundles from trusted sources.

    Returns:
        ``(template_model, selector_model, metadata)``, where ``metadata`` is
        a fresh dict (``{}`` when absent or falsy).

    Raises:
        ValueError: If the file does not hold a dict, or its ``version``
            differs from ``MODEL_BUNDLE_VERSION``.
        KeyError: If a model entry is missing from the bundle.
    """
    with open(path, "rb") as fh:
        payload = pickle.load(fh)
    if not isinstance(payload, dict):
        raise ValueError(f"Unsupported model bundle format: {path}")
    found_version = payload.get("version")
    if found_version != MODEL_BUNDLE_VERSION:
        raise ValueError(
            f"Unsupported model bundle version {found_version} in {path}"
        )
    template_model = payload["template_model"]
    selector_model = payload["selector_model"]
    metadata = dict(payload.get("metadata") or {})
    return template_model, selector_model, metadata