SamChYe's picture
Publish EdgeEDA agent
aa677e3 verified
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

from edgeeda.orfs.metrics import flatten_metrics, pick_first
@dataclass
class RewardComponents:
    """Raw metric values that went into (or were missing from) a reward.

    Each field holds the value that was resolved from the metrics object, or
    ``None`` when no matching key could be found.
    """

    # Worst slack; the reward prefers larger (less negative) values.
    wns: Optional[float]
    # Design area; the reward prefers smaller values.
    area: Optional[float]
    # Power; the reward prefers smaller values.
    power: Optional[float]
# Keys tried (in order; first present wins) when none of the caller-supplied
# WNS candidates resolve to a usable value.
# NOTE(review): floorplan/globalplace precede globalroute here; if the intent is
# "latest flow stage first", globalroute probably belongs before cts — confirm.
_FALLBACK_WNS_KEYS: Tuple[str, ...] = (
    "timing__setup__ws",
    "finish__timing__setup__ws",
    "route__timing__setup__ws",
    "cts__timing__setup__ws",
    "detailedplace__timing__setup__ws",
    "floorplan__timing__setup__ws",
    "globalplace__timing__setup__ws",
    "globalroute__timing__setup__ws",
    "placeopt__timing__setup__ws",
)

# Keys tried (in order; first present wins) when none of the caller-supplied
# area candidates resolve. Standard-cell area keys are preferred over die area.
_FALLBACK_AREA_KEYS: Tuple[str, ...] = (
    "synth__design__instance__area__stdcell",
    "floorplan__design__instance__area__stdcell",
    "globalplace__design__instance__area__stdcell",
    "detailedplace__design__instance__area__stdcell",
    "cts__design__instance__area__stdcell",
    "finish__design__instance__area__stdcell",
    "floorplan__design__die__area",
    "placeopt__design__die__area",
    "detailedplace__design__die__area",
    "cts__design__die__area",
    "globalroute__design__die__area",
    "finish__design__die__area",
)

# Lower clamp applied before taking logs so zero/negative values cannot hit log(0).
_LOG_FLOOR = 1e-9


def _finite_or_none(value: Optional[float]) -> Optional[float]:
    """Return *value* if it is a finite number, otherwise ``None``.

    NaN and +/-inf are treated as missing: ``max(nan, floor)`` returns ``nan``,
    so without this check a NaN metric would silently turn the reward into NaN.
    """
    if value is None or not math.isfinite(value):
        return None
    return value


def compute_reward(
    metrics_obj: Dict[str, Any],
    wns_candidates: list[str],
    area_candidates: list[str],
    power_candidates: list[str],
    weights: Dict[str, float],
) -> Tuple[Optional[float], RewardComponents, Dict[str, Any]]:
    """Collapse a design's metrics into a single scalar reward.

    Reward convention:
      - Want larger WNS (less negative / more positive)
      - Want smaller area, smaller power
    Scalar reward = w_wns * WNS - w_area * log(area) - w_power * log(power)
    (log makes it less sensitive across designs)

    Args:
        metrics_obj: Raw metrics object; it is flattened with
            ``flatten_metrics`` before any lookup.
        wns_candidates: Flattened keys tried in order for WNS; if none yields a
            finite value, a built-in list of stage-prefixed setup-slack keys
            is tried.
        area_candidates: Flattened keys tried in order for area; if none
            yields a finite value, built-in stdcell/die area keys are tried.
        power_candidates: Flattened keys tried in order for power (no
            built-in fallback).
        weights: Optional ``"wns"`` (default 1.0), ``"area"`` (default 0.0)
            and ``"power"`` (default 0.0) weights; values are passed to
            ``float()``.

    Returns:
        ``(reward, components, flat)`` where ``reward`` is ``None`` if no
        metric resolved to a finite value, ``components`` holds the values
        used (``None`` for missing or non-finite), and ``flat`` is the
        flattened metrics mapping.
    """
    flat = flatten_metrics(metrics_obj)

    # Non-finite picks are discarded so that a NaN primary value still lets the
    # fallback keys be consulted. Fallbacks are passed as fresh lists to match
    # the list[str] type of the caller-supplied candidates.
    wns = _finite_or_none(pick_first(flat, wns_candidates))
    if wns is None:
        wns = _finite_or_none(pick_first(flat, list(_FALLBACK_WNS_KEYS)))

    area = _finite_or_none(pick_first(flat, area_candidates))
    if area is None:
        area = _finite_or_none(pick_first(flat, list(_FALLBACK_AREA_KEYS)))

    power = _finite_or_none(pick_first(flat, power_candidates))

    if wns is None and area is None and power is None:
        return None, RewardComponents(None, None, None), flat

    w_wns = float(weights.get("wns", 1.0))
    w_area = float(weights.get("area", 0.0))
    w_power = float(weights.get("power", 0.0))

    # A missing metric contributes 0.0 (i.e. it is scored as if log(value)
    # were 0). NOTE(review): with a non-zero weight this biases comparisons
    # between runs that do and do not report the metric. Likewise a reported
    # value of 0 is clamped to _LOG_FLOOR, giving log ~= -20.7 and hence a
    # large bonus — confirm this is acceptable for failed/empty stages.
    area_term = 0.0 if area is None else math.log(max(area, _LOG_FLOOR))
    power_term = 0.0 if power is None else math.log(max(power, _LOG_FLOOR))
    wns_term = 0.0 if wns is None else wns

    reward = (w_wns * wns_term) - (w_area * area_term) - (w_power * power_term)
    return reward, RewardComponents(wns, area, power), flat