ConcPre / inference.py
Heterogeneity2025's picture
Upload 32 files
7a1376e verified
Raw
History Blame Contribute Delete
23.8 kB
"""Inference layer for the deployed concrete-strength app.
Single source of truth for:
* loading the trained hierarchical GNN + tabular-ANN checkpoints,
* Function 1 (forward): a (possibly incomplete) mix design -> compressive strength,
* Function 2 (inverse): a target strength -> several detailed mix designs.
Imported by ``streamlit_app.py`` and runnable as a CLI for verification:
# forward (fields left out are treated as "not measured")
python app/inference.py predict --cement 500 --water 175 --age 28 \
--coarse 950 --fine 750
# inverse
python app/inference.py suggest --target 120 --k 5
python app/inference.py suggest --target 40 --k 5 --no-fibre
The checkpoints were trained with the *curing-only* schema and the
``mortar_capacity`` strength head; that is NOT stored in the checkpoint, so we
rebuild the model with exactly those settings here (see CURING_ONLY_* below).
"""
from __future__ import annotations
import argparse
import json
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Sequence
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
# --- Locate the concrete_gnn package and put its parent on sys.path. ---------
# Search order: next to this file (self-contained HF Space bundle), then the
# sibling ../Hybrid directory (repo layout), then the parent directory.
_HERE = Path(__file__).resolve().parent
_PKG_ROOT = None
for _cand in (_HERE, _HERE.parent / "Hybrid", _HERE.parent):
    if not (_cand / "concrete_gnn" / "__init__.py").exists():
        continue
    _PKG_ROOT = _cand
    sys.path.insert(0, str(_PKG_ROOT))
    break
else:
    # No candidate directory contained the package: fail loudly at import time.
    raise RuntimeError(
        "Missing local concrete_gnn package. Deploy the contents of "
        "app/space_bundle, not app/ alone; the Space root must contain "
        "concrete_gnn/__init__.py next to streamlit_app.py."
    )
from concrete_gnn import ( # noqa: E402
ConcreteMixDataset,
IntegratedMultiscaleModel,
SchemaSpec,
Standardizer,
StrengthHead,
TABULAR_DIM,
TabularANN,
collate_multiscale,
collect_strength_predictions,
fit_encoder_standardizers,
strength_column,
)
# ---------------------------------------------------------------------------
# Schema used at training time (curing-only globals). MUST match train_real_data.py.
# ---------------------------------------------------------------------------
# Names passed as SchemaSpec(glob=...) by make_schema(). Three curing fields
# followed by two placeholder entries.
# NOTE(review): the placeholders presumably pad the tuple to the width the
# checkpoints were trained with — confirm against train_real_data.py.
CURING_ONLY_GLOBALS = (
    "relative_humidity", "temperature_C", "curing_age_days",
    "_placeholder_0", "_placeholder_1",
)
# Names passed as SchemaSpec(mortar_global=...) by make_schema(); same layout
# as above with mortar-prefixed curing fields.
CURING_ONLY_MORTAR_GLOBALS = (
    "mortar_curing_relative_humidity", "mortar_curing_temperature_C",
    "mortar_curing_age_days", "_placeholder_0", "_placeholder_1",
)
def make_schema() -> SchemaSpec:
    """Return the SchemaSpec built from the curing-only global name tuples."""
    schema_kwargs = {
        "glob": CURING_ONLY_GLOBALS,
        "mortar_global": CURING_ONLY_MORTAR_GLOBALS,
    }
    return SchemaSpec(**schema_kwargs)
# ---------------------------------------------------------------------------
# Column layout (normalized names the ConcreteMixDataset loader expects).
# ---------------------------------------------------------------------------
# Placeholder target column; build_input_frame() and Predictor.predict_df()
# fill it with 0.0 when the caller does not supply one.
TARGET_COL = "compressive_strength_mpa"
# Binder + water + aggregate + fibre amounts: a missing value means "none of
# this component", i.e. a genuine 0.0 (not a masked/unknown channel).
# Note the fibre geometry/property columns (length, diameter, tensile strength,
# modulus) are in this list too, so they also default to 0.0 when absent.
AMOUNT_COLS = [
    "cement_kg_m3", "slag_kg_m3", "fly_ash_kg_m3", "silica_fume_kg_m3",
    "metakaolin_kg_m3", "limestone_powder_kg_m3", "other_scm_kg_m3",
    "water_kg_m3", "superplasticizer_kg_m3",
    "coarse_aggregate_kg_m3", "fine_aggregate_kg_m3",
    "fibre_content_kg_m3", "fibre_length_mm", "fibre_diameter_mm",
    "fibre_tensile_strength_mpa", "fibre_modulus_gpa",
]
# Maskable descriptors: a missing value means "not reported" -> NaN -> masked.
MASKABLE_COLS = [
    "max_coarse_aggregate_size_mm", "max_fine_aggregate_size_mm",
    "curing_temperature_c",
    "cement_CaO_pct", "cement_SiO2_pct", "cement_Al2O3_pct", "cement_Fe2O3_pct",
    "cement_MgO_pct", "cement_SO3_pct", "cement_alkali_pct", "cement_LOI_pct",
    "scm_CaO_pct", "scm_SiO2_pct", "scm_Al2O3_pct", "scm_Fe2O3_pct",
    "scm_MgO_pct", "scm_LOI_pct",
    "cement_grade_mpa", "curing_humidity_pct", "specimen_size_mm",
]
# String-valued columns; build_input_frame() stores them as str, or NaN when
# not supplied.
CATEGORICAL_COLS = ["cement_type_norm", "fibre_type_norm", "curing_regime_norm"]
# Supplementary cementitious materials (used by the inverse "exclude SCMs" filter;
# the CLI passes this list as exclude_scms when --no-scm is given).
SCM_COLS = [
    "slag_kg_m3", "fly_ash_kg_m3", "silica_fume_kg_m3",
    "metakaolin_kg_m3", "limestone_powder_kg_m3", "other_scm_kg_m3",
]
# age has its own default (0 d would imply no strength); everything else 0.0.
AGE_COL = "age_days"
DEFAULT_AGE = 28.0
# Design knobs the inverse "refine" step is allowed to perturb
# (read by Predictor._refine).
REFINE_KNOBS = [
    "cement_kg_m3", "water_kg_m3", "superplasticizer_kg_m3",
    "slag_kg_m3", "fly_ash_kg_m3", "silica_fume_kg_m3",
]
# Columns shown / clustered for inverse diversity (the design vector):
# every amount column plus the age column.
DESIGN_COLS = AMOUNT_COLS + [AGE_COL]
# ---------------------------------------------------------------------------
# Embodied-carbon inventory.
# Indicative cradle-to-gate (A1-A3) emission factors, kg CO2e per kg of
# material, drawn from the ICE database / EPD literature. Cement dominates;
# SCMs are treated as low-burden industrial by-products; aggregates and water
# are small. These are editable defaults — users should substitute their own
# suppliers' EPD values where available.
# Keys are mix column names; embodied_carbon() multiplies each mix amount by
# its factor and uses this mapping when its ``factors`` argument is None.
# The fibre geometry/property columns have no entry and so contribute nothing.
# ---------------------------------------------------------------------------
EMBODIED_CARBON_FACTORS: Dict[str, float] = {
    "cement_kg_m3": 0.90,
    "slag_kg_m3": 0.083,
    "fly_ash_kg_m3": 0.008,
    "silica_fume_kg_m3": 0.014,
    "metakaolin_kg_m3": 0.330,
    "limestone_powder_kg_m3": 0.017,
    "other_scm_kg_m3": 0.050,
    "water_kg_m3": 0.0003,
    "superplasticizer_kg_m3": 1.880,
    "coarse_aggregate_kg_m3": 0.005,
    "fine_aggregate_kg_m3": 0.005,
    "fibre_content_kg_m3": 1.500,
}
# Curing-process carbon: extra kg CO2e per m³ from the energy used to accelerate
# curing. Standard (ambient/moist) curing adds nothing; heated regimes add the
# energy needed to raise and hold temperature. Values are LCA-literature-based
# central estimates and remain editable in the app (fuel mix, cycle duration and
# temperature cause a wide real-world spread, roughly 10-40 kg CO2e/m³ for steam):
# * steam ~25 — Liu et al. precast steam-curing review: ~10 L boiler oil/h
#   x 2.50 kg CO2/L over a typical ~9 h cycle ≈ 25 kg/m³; and a
#   block LCA of "up to ~10 kg CO2/tonne" ≈ 24 kg/m³ at
#   2400 kg/m³ (PMC9024602; ScienceDirect S0959652619332299).
# * heat ~20 — electric/oven heat curing: +~12 kWh/m³ at 45°C vs ambient,
#   and a heated chamber ~20 kg CO2e/m³ at a ~0.2 kg CO2/kWh
#   grid (PMC10053802).
# * autoclave ~50 — high-pressure saturated steam (~180°C), the most
#   energy-intensive regime; ~2x atmospheric steam (direct
#   dense-concrete figures are scarce — least-certain value).
# embodied_carbon() looks the mix's ``curing_regime_norm`` up in this mapping
# (when its ``curing_factors`` argument is None); a regime that is not listed
# adds 0.
CURING_CARBON_FACTORS: Dict[str, float] = {
    "standard": 0.0,
    "steam": 25.0,
    "heat": 20.0,
    "autoclave": 50.0,
    "other": 0.0,
}
def _is_missing_value(value: object) -> bool:
    """Return True for None, the empty string, and NaN-like scalars.

    Covers Python/NumPy NaN of any float width as well as ``pd.NA``/``pd.NaT``.
    Non-scalar values (for which ``pd.isna`` is not a single boolean) are
    reported as not missing.
    """
    if value is None:
        return True
    if isinstance(value, str):
        return value == ""
    try:
        return bool(pd.isna(value))
    except (TypeError, ValueError):
        return False


def embodied_carbon(
    mix: Dict[str, float],
    factors: Optional[Dict[str, float]] = None,
    curing_factors: Optional[Dict[str, float]] = None,
) -> "tuple[float, pd.DataFrame]":
    """Cradle-to-gate embodied carbon (kg CO2e per m³ of concrete) for one mix.

    Sums per-ingredient material carbon (amount x factor) plus a
    curing-process term keyed on the mix's ``curing_regime_norm``.

    Args:
        mix: Mapping of column name -> amount. Amounts that are absent, None,
            ``""`` or NaN-like (including NumPy float32 NaN and ``pd.NA``)
            count as 0.
        factors: Per-ingredient emission factors; defaults to
            ``EMBODIED_CARBON_FACTORS``. Only columns listed here are summed.
        curing_factors: Per-regime curing carbon (added once per m³); defaults
            to ``CURING_CARBON_FACTORS``. Unknown or missing regimes add 0.

    Returns:
        ``(total, breakdown)`` where ``total`` is the unrounded sum and
        ``breakdown`` is a DataFrame with one row per factor (amount, factor,
        embodied carbon, each rounded to 2 decimals for display) plus a
        ``curing (<regime>)`` row when the curing term is non-zero.
    """
    factors = EMBODIED_CARBON_FACTORS if factors is None else factors
    curing_factors = CURING_CARBON_FACTORS if curing_factors is None else curing_factors

    rows = []
    total = 0.0
    for col, fac in factors.items():
        raw = mix.get(col, 0.0)
        # Treat every flavour of "missing" as an absent component rather than
        # letting a NaN poison the running total.
        amt = 0.0 if _is_missing_value(raw) else float(raw)
        fac = float(fac)
        co2 = amt * fac
        total += co2
        rows.append({
            "ingredient": col,
            "amount_kg_m3": round(amt, 2),
            "factor_kgco2e_per_kg": fac,
            "embodied_carbon_kgco2e_m3": round(co2, 2),
        })

    # Curing-process energy (per m³, not per kg).
    regime = mix.get("curing_regime_norm")
    if _is_missing_value(regime) or not regime:
        cure_co2 = 0.0
    else:
        # ``or 0.0`` also covers a regime explicitly mapped to None.
        cure_co2 = float(curing_factors.get(regime, 0.0) or 0.0)
    if cure_co2:
        total += cure_co2
        rows.append({
            "ingredient": f"curing ({regime})",
            "amount_kg_m3": np.nan,
            "factor_kgco2e_per_kg": np.nan,
            "embodied_carbon_kgco2e_m3": round(cure_co2, 2),
        })
    return total, pd.DataFrame(rows)
def _is_unset(value: object) -> bool:
    """Return True for None, the empty string, and NaN-like scalars (NaN, pd.NA)."""
    if value is None:
        return True
    if isinstance(value, str):
        return value == ""
    try:
        return bool(pd.isna(value))
    except (TypeError, ValueError):
        return False


def build_input_frame(mixes: Sequence[Dict[str, float]]) -> pd.DataFrame:
    """Turn a list of partial mix dicts into a full normalized DataFrame.

    A value counts as "unspecified" when the key is absent or the value is
    None, ``""`` or NaN-like (so rows round-tripped through a DataFrame/CSV
    behave the same as freshly typed ones):

    * amount columns        -> 0.0 (component absent),
    * maskable descriptors  -> NaN (not reported -> masked),
    * categoricals          -> NaN (default bucket), never the string "nan",
    * age                   -> ``DEFAULT_AGE``; an age of 0 also falls back to it,
    * target                -> 0.0 placeholder.

    Args:
        mixes: Sequence of mappings from column name to value.

    Returns:
        DataFrame with one row per mix containing every amount, age, maskable,
        categorical and target column.
    """
    def _number(raw: object, default: float) -> float:
        return default if _is_unset(raw) else float(raw)

    rows = []
    for mix in mixes:
        row: Dict[str, object] = {}
        for c in AMOUNT_COLS:
            row[c] = _number(mix.get(c), 0.0)
        # ``or`` keeps the original rule that 0 days is replaced by the default.
        row[AGE_COL] = _number(mix.get(AGE_COL), DEFAULT_AGE) or DEFAULT_AGE
        for c in MASKABLE_COLS:
            v = mix.get(c)
            row[c] = np.nan if _is_unset(v) else float(v)
        for c in CATEGORICAL_COLS:
            v = mix.get(c)
            # Guard against str(nan) == "nan" leaking in as a category label.
            row[c] = np.nan if _is_unset(v) else str(v)
        row[TARGET_COL] = _number(mix.get(TARGET_COL), 0.0)  # placeholder
        rows.append(row)
    return pd.DataFrame(rows)
def add_derived(df: pd.DataFrame) -> pd.DataFrame:
    """Append water/binder ratio and SCM fraction for display.

    Binder is cement plus the six SCM columns. Missing (NaN/None/non-numeric)
    binder amounts count as 0, matching the file-wide rule that an absent
    amount means "none of this component". The binder denominator is floored
    at 1.0 kg/m³ to avoid division by zero.

    The SCM fraction is computed from the actual SCM total rather than from
    the floored binder, so a mix with no binder reports 0.0 instead of 1.0.

    Args:
        df: Frame containing the cement, SCM and water columns.

    Returns:
        A copy of ``df`` with ``water_binder_ratio`` and ``scm_fraction``
        columns added, each rounded to 3 decimals. The input is not modified.
    """
    df = df.copy()

    def _amount(col: str) -> pd.Series:
        return pd.to_numeric(df[col], errors="coerce").fillna(0.0)

    cement = _amount("cement_kg_m3")
    scm_total = _amount("slag_kg_m3")
    for col in (
        "fly_ash_kg_m3", "silica_fume_kg_m3", "metakaolin_kg_m3",
        "limestone_powder_kg_m3", "other_scm_kg_m3",
    ):
        scm_total = scm_total + _amount(col)

    binder = (cement + scm_total).clip(lower=1.0)
    # Water stays NaN when unknown so the ratio is visibly undefined.
    water = pd.to_numeric(df["water_kg_m3"], errors="coerce")
    df["water_binder_ratio"] = (water / binder).round(3)
    df["scm_fraction"] = (scm_total / binder).round(3)
    return df
# ---------------------------------------------------------------------------
# Predictor
# ---------------------------------------------------------------------------
def _detect_head_kind(state: dict) -> str:
keys = list(state.keys())
if any(k.startswith("strength_head.mortar_eff") for k in keys):
return "mortar_capacity"
if any(k.startswith("strength_head.") for k in keys):
return "physics"
return "free"
@dataclass
class Predictor:
    """Loads the trained checkpoints and serves forward / inverse queries.

    On construction it reads the optional ``model_config.json`` from
    ``checkpoint_dir`` and loads ``hierarchical.pt`` (GNN) and
    ``tabular_ann.pt`` (tabular ANN). ``feature_bounds`` from the config, when
    present, drive ``out_of_range`` and the clipping in the refine step.
    """

    # Directory holding hierarchical.pt, tabular_ann.pt and, optionally,
    # model_config.json.
    checkpoint_dir: Path
    # Device used as torch.load's map_location and for inference.
    device: torch.device = field(default_factory=lambda: torch.device("cpu"))
    # Passed to ConcreteMixDataset and used to seed the refine-step RNG.
    seed: int = 17
    # Replaced by the contents of model_config.json when that file exists.
    config: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Build the schema, read the config file and load both models."""
        self.schema = make_schema()
        cfg_path = Path(self.checkpoint_dir) / "model_config.json"
        if cfg_path.exists():
            self.config = json.loads(cfg_path.read_text())
        self.gnn, self.std = self._load("hierarchical.pt", kind="gnn")
        self.tabular, _ = self._load("tabular_ann.pt", kind="tabular")
        self.bounds: dict = self.config.get("feature_bounds", {})

    def _load(self, fname: str, kind: str):
        """Load one checkpoint and return ``(model, target_standardizer)``.

        Args:
            fname: Checkpoint file name inside ``checkpoint_dir``.
            kind: ``"gnn"`` builds an IntegratedMultiscaleModel; any other
                value builds a TabularANN.

        Returns:
            The base model wrapped in ``StrengthHead`` (moved to ``device``,
            in eval mode) and the ``Standardizer`` built from the checkpoint's
            ``target_mean`` / ``target_std``.
        """
        path = Path(self.checkpoint_dir) / fname
        # SECURITY NOTE(review): weights_only=False unpickles arbitrary
        # objects, so this must only ever be pointed at trusted checkpoint
        # files shipped with the app, never at user-supplied uploads.
        ck = torch.load(path, map_location=self.device, weights_only=False)
        std = Standardizer(
            mean=torch.tensor(float(ck["target_mean"])),
            std=torch.tensor(float(ck["target_std"])),
        )
        state = ck["base_model_state_dict"]
        if kind == "gnn":
            # The head kind is not stored in the checkpoint: prefer the config
            # value, else infer it from the parameter names.
            head_kind = self.config.get("strength_head_kind") or _detect_head_kind(state)
            base = IntegratedMultiscaleModel(schema=self.schema, strength_head_kind=head_kind)
        else:
            base = TabularANN(in_dim=TABULAR_DIM, schema=self.schema)
        base.load_state_dict(state)
        model = StrengthHead(base, std).to(self.device).eval()
        return model, std

    # ---- forward ---- #
    @torch.no_grad()
    def predict_df(self, df: pd.DataFrame, which=("gnn", "tabular")) -> Dict[str, np.ndarray]:
        """Predict strength for every row of a normalized frame.

        Args:
            df: Frame in the layout produced by ``build_input_frame``; a 0.0
                target column is added to a copy when missing.
            which: Model names to run, each ``"gnn"`` or ``"tabular"``.

        Returns:
            Mapping of model name -> NumPy array of predictions, in row order.
        """
        df = df.copy()
        if TARGET_COL not in df.columns:
            df[TARGET_COL] = 0.0
        ds = ConcreteMixDataset(df, None, 0.0, self.seed, self.std, schema=self.schema)
        dl = DataLoader(ds, batch_size=128, shuffle=False, collate_fn=collate_multiscale)
        models = {"gnn": self.gnn, "tabular": self.tabular}
        out: Dict[str, np.ndarray] = {}
        for name in which:
            pred, _ = collect_strength_predictions(models[name], dl, self.device)
            out[name] = pred.numpy()
        return out

    def predict_strength(self, mix: Dict[str, float]) -> Dict[str, float]:
        """Single partial mix -> {'gnn': MPa, 'tabular': MPa}."""
        df = build_input_frame([mix])
        res = self.predict_df(df)
        return {k: float(v[0]) for k, v in res.items()}

    def age_curve(self, mix: Dict[str, float], ages=None) -> pd.DataFrame:
        """Predict strength across curing ages holding the rest of the mix fixed.

        Args:
            mix: Partial mix; its own age (if any) is overridden per row.
            ages: Iterable of ages in days; defaults to 3, 7, 28 and 56.

        Returns:
            DataFrame with ``age_days``, ``gnn_mpa`` and ``tabular_mpa``.
        """
        # Materialize once: ``ages`` is iterated twice below, which would
        # silently drop the age column for a one-shot iterator.
        ages = [3, 7, 28, 56] if ages is None else list(ages)
        mixes = [{**mix, AGE_COL: float(a)} for a in ages]
        res = self.predict_df(build_input_frame(mixes))
        return pd.DataFrame({"age_days": ages, "gnn_mpa": res["gnn"], "tabular_mpa": res["tabular"]})

    def out_of_range(self, mix: Dict[str, float]) -> List[str]:
        """Names of supplied fields that fall outside the training p1-p99 range.

        Only numeric values are checked; besides Python ``int``/``float`` this
        includes NumPy scalars (e.g. ``np.int64``/``np.float32`` taken from a
        DataFrame row). Fields without an entry in ``self.bounds`` are skipped.
        """
        numeric_types = (int, float, np.integer, np.floating)
        flags = []
        for c, v in mix.items():
            b = self.bounds.get(c)
            if not b or not isinstance(v, numeric_types):
                continue
            if v < b["p01"] or v > b["p99"]:
                flags.append(c)
        return flags

    # ---- inverse ---- #
    def suggest_mixes(
        self,
        target: float,
        index: pd.DataFrame,
        k: int = 5,
        tol: Optional[float] = None,
        allow_fibre: bool = True,
        require_coarse_aggregate: bool = True,
        exclude_scms: Optional[List[str]] = None,
        curing_regimes: Optional[List[str]] = None,
        age: Optional[float] = None,
        domain: str = "any",  # "any" | "uhpc" | "normal"
        refine: bool = True,
    ) -> pd.DataFrame:
        """Suggest up to ``k`` diverse mixes whose predicted strength is near ``target``.

        Args:
            target: Target strength; compared against the ``pred_gnn`` column.
            index: Candidate mixes; must contain ``pred_gnn`` plus the columns
                used by whichever filters are active.
            k: Number of suggestions to return.
            tol: Half-width of the candidate band around ``target``; defaults
                to ``max(5.0, 10% of target)``.
            allow_fibre: When False, drop rows with a positive fibre content.
            require_coarse_aggregate: When True, keep only rows with a
                positive coarse-aggregate amount.
            exclude_scms: SCM columns whose combined amount must be zero.
            curing_regimes: Allowed curing regimes; blank regimes count as
                ``"standard"``.
            age: When given, keep only rows whose ``age_days`` is close to it.
            domain: ``"uhpc"`` keeps rows whose ``source`` is ``"UHPC"``,
                ``"normal"`` keeps the rest, anything else keeps all.
            refine: When True, perturb each seed and keep its best variant.

        Returns:
            The ``k`` closest-to-target rows after assembly. If the filters
            remove every row, the empty filtered frame is returned as-is.
        """
        df = index.copy()
        if not allow_fibre:
            df = df[df["fibre_content_kg_m3"].fillna(0) <= 0]
        # Most practical concrete contains coarse aggregate; binder/sand-only
        # mixes (UHPC, some mortars) are misleading as ordinary suggestions, so
        # require it by default. Users targeting UHPC can switch this off.
        if require_coarse_aggregate and "coarse_aggregate_kg_m3" in df.columns:
            df = df[pd.to_numeric(df["coarse_aggregate_kg_m3"], errors="coerce").fillna(0) > 0]
        # Drop mixes that contain any SCM the user can't use.
        present = [c for c in (exclude_scms or []) if c in df.columns]
        if present:
            df = df[df[present].fillna(0).sum(axis=1) <= 0]
        # Keep only the curing regimes the user prefers. Records with no recorded
        # regime (most non-UHPC datasets leave it blank) are assumed to be
        # standard moist/ambient curing — otherwise any preference would collapse
        # the pool to UHPC, the only dataset that documents a curing regime.
        if curing_regimes and "curing_regime_norm" in df.columns:
            regime = df["curing_regime_norm"].fillna("standard").replace("", "standard")
            df = df[regime.isin(curing_regimes)]
        if domain == "uhpc":
            df = df[df["source"] == "UHPC"]
        elif domain == "normal":
            df = df[df["source"] != "UHPC"]
        if age is not None:
            df = df[np.isclose(pd.to_numeric(df["age_days"], errors="coerce"), age)]
        if df.empty:
            return df
        df = df.copy()
        df["err"] = (df["pred_gnn"] - target).abs()
        # Keep the candidate pool within a tolerance band so even the most
        # "diverse" pick stays near the target; widen if too few rows qualify.
        base_tol = tol if tol is not None else max(5.0, 0.10 * target)
        need = max(k * 4, 12)
        pool, mult = df[df["err"] <= base_tol], 1.0
        while len(pool) < need and mult < 8:
            mult *= 2
            pool = df[df["err"] <= base_tol * mult]
        if pool.empty:
            pool = df.nsmallest(need, "err")
        pool = pool.sort_values("err").head(200)
        seeds = self._diversify(pool, k)
        if refine and len(seeds) > 0:
            seeds = self._refine(seeds, target)
        out = self._assemble(seeds, target)
        # Return the k closest-to-target after refinement.
        order = out["pred_gnn"].sub(target).abs().sort_values().index
        return out.loc[order].head(k).reset_index(drop=True)

    def _diversify(self, df: pd.DataFrame, k: int) -> pd.DataFrame:
        """Greedy max-min selection over the normalized design vector.

        Starts from row 0 and repeatedly adds the row whose distance to its
        nearest already-chosen row is largest. Returns ``df`` unchanged when
        it has at most ``k`` rows.
        """
        if len(df) <= k:
            return df
        X = df[DESIGN_COLS].fillna(0.0).to_numpy(dtype=float)
        # The epsilon keeps constant columns from dividing by zero.
        mu, sigma = X.mean(0), X.std(0) + 1e-6
        Z = (X - mu) / sigma
        chosen = [0]  # closest-to-target row is first (df is sorted by err)
        while len(chosen) < k:
            d = np.min(
                np.linalg.norm(Z[:, None, :] - Z[chosen][None, :, :], axis=2), axis=1
            )
            d[chosen] = -1.0  # never re-pick an already chosen row
            chosen.append(int(np.argmax(d)))
        return df.iloc[chosen]

    def _refine(self, seeds: pd.DataFrame, target: float, n_per_seed: int = 32) -> pd.DataFrame:
        """Perturb a few knobs per seed, predict, keep the closest-to-target variant.

        Each seed yields itself plus ``n_per_seed`` variants in which every
        ``REFINE_KNOBS`` amount is scaled by a uniform factor in [0.85, 1.15]
        and clipped via ``_clip``. All candidates are scored in one batch with
        the GNN. Columns absent from ``seeds`` are treated as not reported.

        Returns:
            One row per seed: the best candidate's inputs, its ``pred_gnn``,
            and the seed's ``source`` / ``measured`` values.
        """
        rng = np.random.default_rng(self.seed)
        candidates: List[Dict[str, float]] = []
        owners: List[int] = []
        seed_rows = seeds.reset_index(drop=True)
        input_cols = AMOUNT_COLS + [AGE_COL] + MASKABLE_COLS + CATEGORICAL_COLS
        for i, row in seed_rows.iterrows():
            base = {}
            for c in input_cols:
                # ``get`` tolerates an index that omits optional descriptor
                # columns; a missing column is the same as a blank cell.
                v = row.get(c)
                base[c] = None if pd.isna(v) else v
            candidates.append(dict(base))
            owners.append(i)  # keep the seed itself
            for _ in range(n_per_seed):
                cand = dict(base)
                for knob in REFINE_KNOBS:
                    cur = float(base.get(knob) or 0.0)
                    if cur <= 0 and knob in ("slag_kg_m3", "fly_ash_kg_m3", "silica_fume_kg_m3"):
                        continue  # don't invent an SCM that wasn't there
                    factor = float(rng.uniform(0.85, 1.15))
                    cand[knob] = self._clip(knob, cur * factor)
                candidates.append(cand)
                owners.append(i)
        preds = self.predict_df(build_input_frame(candidates), which=("gnn",))["gnn"]
        owner_idx = np.asarray(owners)
        best_rows = []
        for i in range(len(seed_rows)):
            sel = np.where(owner_idx == i)[0]
            err = np.abs(preds[sel] - target)
            best = sel[int(np.argmin(err))]
            best_rows.append({**candidates[best], "pred_gnn": float(preds[best]),
                              "source": seed_rows.loc[i].get("source", "refined"),
                              "measured": seed_rows.loc[i].get("measured", np.nan)})
        return pd.DataFrame(best_rows)

    def _clip(self, col: str, value: float) -> float:
        """Clamp ``value`` to the column's ``min``/``max`` bounds, or to >= 0 when unbounded."""
        b = self.bounds.get(col)
        if b:
            return float(min(max(value, b["min"]), b["max"]))
        return float(max(value, 0.0))

    def _assemble(self, seeds: pd.DataFrame, target: float) -> pd.DataFrame:
        """Add predictions (if absent), derived ratios and display rounding.

        Returns a frame whose leading columns are the prediction, target,
        measured value, source, water/binder ratio and the design columns (as
        available), followed by every remaining column.
        """
        df = seeds.copy().reset_index(drop=True)
        if "pred_gnn" not in df.columns:
            df["pred_gnn"] = self.predict_df(build_input_frame(df.to_dict("records")),
                                             which=("gnn",))["gnn"]
        df = add_derived(df)
        for c in DESIGN_COLS:  # tidy kg/mm/day values for display
            if c in df.columns:
                df[c] = pd.to_numeric(df[c], errors="coerce").round(1)
        df["pred_gnn"] = df["pred_gnn"].round(1)
        df["target_mpa"] = float(target)
        front = (["pred_gnn", "target_mpa", "measured", "source",
                  "water_binder_ratio"] + DESIGN_COLS)
        cols = [c for c in front if c in df.columns] + \
               [c for c in df.columns if c not in front]
        return df[cols]
# ---------------------------------------------------------------------------
# CLI (verification)
# ---------------------------------------------------------------------------
def _default_ckpt_dir() -> Path:
    """Return the first checkpoint directory that contains ``hierarchical.pt``.

    Looks next to this file first, then under ``../Hybrid/outputs``. When
    neither holds the checkpoint, the directory next to this file is returned
    anyway.
    """
    bundled = _HERE / "checkpoints_full_rich"
    repo_local = _HERE.parent / "Hybrid" / "outputs" / "checkpoints_full_rich"
    usable = (d for d in (bundled, repo_local) if (d / "hierarchical.pt").exists())
    return next(usable, bundled)
def _default_index() -> Optional[Path]:
    """Return the ``inverse_index.csv`` next to this file, or None if it is absent."""
    bundled_index = _HERE / "inverse_index.csv"
    if bundled_index.exists():
        return bundled_index
    return None
def main() -> None:
    """Command-line entry point for verifying the forward and inverse paths.

    ``predict`` builds a partial mix from the supplied flags (flags that are
    left out are omitted from the mix) and prints both model predictions plus
    any out-of-range warning. ``suggest`` reads the inverse index CSV and
    prints the suggested mixes.

    The ``--index`` requirement is checked before the checkpoints are loaded,
    so a missing index is reported immediately instead of after (or masked by)
    the model load.
    """
    ap = argparse.ArgumentParser(description=__doc__)
    sub = ap.add_subparsers(dest="cmd", required=True)

    p = sub.add_parser("predict", help="forward: mix -> strength")
    p.add_argument("--checkpoint-dir", default=str(_default_ckpt_dir()))
    for flag in ("--cement", "--slag", "--fly-ash", "--silica-fume",
                 "--water", "--sp", "--coarse", "--fine"):
        p.add_argument(flag, type=float)
    p.add_argument("--age", type=float, default=28.0)

    s = sub.add_parser("suggest", help="inverse: strength -> mixes")
    s.add_argument("--checkpoint-dir", default=str(_default_ckpt_dir()))
    s.add_argument("--index", default=str(_default_index() or ""))
    s.add_argument("--target", type=float, required=True)
    s.add_argument("--k", type=int, default=5)
    s.add_argument("--tol", type=float, default=None)
    s.add_argument("--no-fibre", action="store_true")
    s.add_argument("--no-scm", action="store_true")
    s.add_argument("--no-refine", action="store_true")

    args = ap.parse_args()

    # Fail fast: validate cheap preconditions before the expensive model load.
    if args.cmd != "predict" and not args.index:
        ap.error("suggest needs --index (run app/build_inverse_index.py first)")

    pred = Predictor(Path(args.checkpoint_dir))
    if args.cmd == "predict":
        mix = {
            "cement_kg_m3": args.cement, "slag_kg_m3": args.slag,
            "fly_ash_kg_m3": args.fly_ash, "silica_fume_kg_m3": args.silica_fume,
            "water_kg_m3": args.water, "superplasticizer_kg_m3": args.sp,
            "coarse_aggregate_kg_m3": args.coarse, "fine_aggregate_kg_m3": args.fine,
            "age_days": args.age,
        }
        # Flags that were not given stay out of the mix entirely.
        mix = {k: v for k, v in mix.items() if v is not None}
        res = pred.predict_strength(mix)
        print(f"input: {mix}")
        print(f"GNN : {res['gnn']:.1f} MPa")
        print(f"tabular : {res['tabular']:.1f} MPa")
        flags = pred.out_of_range(mix)
        if flags:
            print(f"[warning] outside training range: {flags}")
    else:
        index = pd.read_csv(args.index)
        out = pred.suggest_mixes(
            args.target, index, k=args.k, tol=args.tol,
            allow_fibre=not args.no_fibre,
            exclude_scms=SCM_COLS if args.no_scm else None,
            refine=not args.no_refine,
        )
        pd.set_option("display.width", 220, "display.max_columns", 40)
        print(out.to_string(index=False))


if __name__ == "__main__":
    main()