# mix / main.py — Hugging Face Space file (commit fb8d24d, "change range limit")
"""
MetaMix API β€” FastAPI wrapper around BOxCrete
=============================================
Designed to run as a Hugging Face Space (Docker SDK).
Port: 7860
User-adjustable inputs: cement, fly_ash, slag, water, hrwr, plus
fine_aggregate (default 1375), coarse_aggregate (default 510), and
temp_c (default 21). Fixed for all requests: mrwr (0), material_source (0).
"""
from __future__ import annotations
import logging
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from boxcrete.models import SustainableConcreteModel
from boxcrete.utils import load_concrete_strength, get_bounds
logger = logging.getLogger(__name__)
# ── Global state ──────────────────────────────────────────────────────────────
# Fitted model and dataset singletons, populated once at startup by lifespan().
_model: Optional[SustainableConcreteModel] = None
_data = None
# Full 1..90-day curve evaluated by /predict; STRENGTH_DAYS is the subset
# handed to SustainableConcreteModel(strength_days=...) when fitting.
CURVE_DAYS = list(range(1, 91))
STRENGTH_DAYS = [1, 3, 7, 14, 28, 56, 90]
# Dataset source and local cache location inside the Space container.
DATA_URL = (
    "https://raw.githubusercontent.com/facebookresearch/"
    "SustainableConcrete/main/data/boxcrete_data.csv"
)
DATA_PATH = Path("/app/boxcrete_data.csv")
# Fixed at training-data means for columns not exposed as user inputs
# Only truly non-user columns are fixed
FIXED_DEFAULTS = {
    "MRWR (kg/m3)": 0.0,
    "Material Source": 0.0,
}
def _ensure_data() -> None:
    """Ensure the BOxCrete CSV exists at DATA_PATH, downloading it if needed.

    The download goes to a temporary sibling file which is atomically renamed
    into place on success. The previous implementation wrote directly to
    DATA_PATH, so an interrupted download could leave a truncated file that
    the `DATA_PATH.exists()` guard would treat as complete on every later
    startup.
    """
    if DATA_PATH.exists():
        return
    logger.info("Downloading boxcrete_data.csv from GitHub...")
    import urllib.request

    tmp_path = DATA_PATH.with_suffix(".csv.tmp")
    try:
        urllib.request.urlretrieve(DATA_URL, tmp_path)
        # Atomic on POSIX: DATA_PATH only ever appears fully written.
        tmp_path.replace(DATA_PATH)
    finally:
        # Remove any partial temp file left behind by a failed download;
        # after a successful replace() the temp path no longer exists.
        tmp_path.unlink(missing_ok=True)
    logger.info("Download complete.")
def _fit_model() -> SustainableConcreteModel:
    """Load the BOxCrete dataset, fit the GWP and strength models, return the model.

    Side effect: populates the module-level ``_data`` with the loaded dataset.
    Called exactly once at startup from the lifespan handler.
    """
    global _data
    _ensure_data()
    logger.info("Loading BOxCrete dataset...")
    try:
        # Assumes this version of load_concrete_strength accepts an explicit
        # data_path keyword; older signatures raise TypeError (handled below).
        _data = load_concrete_strength(data_path=str(DATA_PATH))
    except TypeError:
        # Fallback for library versions that take no path and read the CSV
        # bundled inside the installed package: copy our downloaded CSV into
        # the first site-packages data directory, then retry without args.
        import shutil
        import site
        for sp in site.getsitepackages():
            dest = Path(sp) / "data" / "boxcrete_data.csv"
            dest.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy(DATA_PATH, dest)
            logger.info("Copied CSV to %s", dest)
            break  # one copy into the first site-packages entry is enough
        _data = load_concrete_strength()
    # Attach optimization bounds derived from the dataset's input columns.
    _data.bounds = get_bounds(_data.X_columns)
    model = SustainableConcreteModel(strength_days=STRENGTH_DAYS)
    logger.info("Fitting GWP model...")
    model.fit_gwp_model(_data)
    logger.info("Fitting strength model...")
    model.fit_strength_model(_data)
    logger.info("Model ready.")
    return model
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: fit the model once before serving any requests."""
    global _model
    _model = _fit_model()
    yield
# ── App ───────────────────────────────────────────────────────────────────────
app = FastAPI(title="MetaMix API", lifespan=lifespan)
# CORS: allow only the production frontend and local development origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "https://pavements.design",
        "https://www.pavements.design",
        "http://localhost:3000",
    ],
    allow_methods=["POST", "GET"],
    allow_headers=["Content-Type"],
)
# ── Schemas ───────────────────────────────────────────────────────────────────
class MixRequest(BaseModel):
    """User-supplied concrete mix proportions (kg/m3 unless noted).

    Aggregate and temperature fields default to training-data means so
    clients may omit them. Bounds presumably mirror training-data ranges
    (compare /data-ranges) — TODO confirm against the dataset.
    """
    cement: float = Field(..., ge=0, le=895)
    fly_ash: float = Field(0.0, ge=0, le=534)
    slag: float = Field(0.0, ge=0, le=1199)
    water: float = Field(..., ge=43, le=443)
    hrwr: float = Field(0.0, ge=0, le=14, description="High-range water reducer kg/m3")
    fine_aggregate: float = Field(1375.0, ge=0, le=2357)
    coarse_aggregate: float = Field(510.0, ge=0, le=1356)
    temp_c: float = Field(21.0, ge=4, le=22, description="Curing temperature C")
class StrengthPoint(BaseModel):
    """One point on the predicted strength curve (MPa) with a 95% interval."""
    day: int
    mean: float
    lower: float
    upper: float
class PredictionResponse(BaseModel):
    """Prediction payload: GWP (kg CO2e/m3) and strength (MPa), each with 95% bounds."""
    gwp: float
    gwp_lower: float
    gwp_upper: float
    strength_28d: float
    strength_28d_lower: float
    strength_28d_upper: float
    strength_curve: list[StrengthPoint]
# ── Helpers ───────────────────────────────────────────────────────────────────
def _build_comp_tensor(req: MixRequest, comp_cols: list[str]) -> torch.Tensor:
user_inputs = {
"Cement (kg/m3)": req.cement,
"Fly Ash (kg/m3)": req.fly_ash,
"Slag (kg/m3)": req.slag,
"Water (kg/m3)": req.water,
"HRWR (kg/m3)": req.hrwr,
"Fine Aggregate (kg/m3)": req.fine_aggregate,
"Coarse Aggregates (kg/m3)": req.coarse_aggregate,
"Temp (C)": req.temp_c,
}
# Merge user inputs with fixed defaults (MRWR, Material Source stay fixed)
mapping = {**FIXED_DEFAULTS, **user_inputs}
values = [mapping.get(col, 0.0) for col in comp_cols]
return torch.tensor(values, dtype=torch.float64).unsqueeze(0)
def _posterior_stats(model, X: torch.Tensor) -> tuple[float, float, float]:
with torch.no_grad():
posterior = model.posterior(X)
mean = posterior.mean.squeeze().item()
std = posterior.variance.squeeze().item() ** 0.5
return mean, mean - 1.96 * std, mean + 1.96 * std
# ── Endpoints ─────────────────────────────────────────────────────────────────
@app.get("/health")
async def health():
    """Liveness probe; also reports whether model fitting has completed."""
    ready = _model is not None
    return {"status": "ok", "model_ready": ready}
@app.get("/columns")
async def columns():
    """Expose the model's input column layout (the last column is time)."""
    if _data is None:
        raise HTTPException(status_code=503, detail="Model not yet initialised")
    cols = list(_data.X_columns)
    return {
        "X_columns": cols,
        "composition_columns": cols[:-1],
        "time_column": cols[-1],
    }
@app.get("/data-ranges")
async def data_ranges():
    """Report min/max/mean of each composition column in the GWP training data."""
    if _data is None:
        raise HTTPException(status_code=503, detail="Model not yet initialised")
    comp_cols = list(_data.X_columns[:-1])
    result = {}
    try:
        gX, gY, _, _ = _data.gwp_data
        for idx, col in enumerate(comp_cols):
            column = gX[:, idx]
            result[col] = {
                "min": round(float(column.min()), 3),
                "max": round(float(column.max()), 3),
                "mean": round(float(column.mean()), 3),
            }
    except Exception as e:  # best-effort debug endpoint: report, don't 500
        result["error"] = str(e)
    return result
@app.get("/sample-data")
async def sample_data():
    """Return a debugging peek at the raw CSV: columns, head rows, stats,
    and a sample of the last column when its name looks like a strength column.
    """
    if _data is None:
        raise HTTPException(status_code=503, detail="Model not yet initialised")
    try:
        import pandas as pd
        df = pd.read_csv(str(DATA_PATH))
        # The old check tested both `"Strength" in name` and
        # `"strength" in name.lower()`; the case-insensitive test subsumes
        # the exact one, so only it is kept. Behavior is unchanged.
        last_col = df.columns[-1]
        strength_sample = (
            df.iloc[:, -1].head(10).tolist()
            if "strength" in last_col.lower()
            else "unknown"
        )
        return {
            "columns": list(df.columns),
            "first_5_rows": df.head(5).to_dict(orient="records"),
            "strength_column_sample": strength_sample,
            "all_column_stats": df.describe().to_dict(),
        }
    except Exception as e:  # best-effort debug endpoint: report, don't 500
        return {"error": str(e)}
@app.post("/debug")
async def debug(req: MixRequest):
    """Echo the composition tensor built from a request, keyed by column name."""
    if _data is None:
        raise HTTPException(status_code=503, detail="Model not yet initialised")
    comp_cols = list(_data.X_columns[:-1])
    tensor_values = _build_comp_tensor(req, comp_cols).squeeze().tolist()
    return {
        "mapping": dict(zip(comp_cols, tensor_values)),
        "tensor": tensor_values,
    }
@app.post("/predict", response_model=PredictionResponse)
async def predict(req: MixRequest):
    """Predict GWP (kg CO2e/m3) and the 1-90 day strength curve (MPa).

    All reported lower/upper values are 95% confidence bounds from the
    Gaussian posteriors; 28-day strength is also surfaced as top-level fields.
    """
    if _model is None or _data is None:
        raise HTTPException(status_code=503, detail="Model not yet initialised")
    comp_cols = list(_data.X_columns[:-1])
    comp = _build_comp_tensor(req, comp_cols)
    gwp_mean, gwp_lo, gwp_hi = _posterior_stats(_model.gwp_model, comp)
    # The GWP model stores values as negatives internally (see notebook:
    # title uses -gwp_pred). Negate to get positive kg CO2e/m3.
    gwp_mean, gwp_lo, gwp_hi = -gwp_mean, -gwp_hi, -gwp_lo  # note: lo/hi swap after negation
    gwp_lo = max(gwp_lo, 0.0)  # negative emissions are not physically meaningful
    # Strength model was trained on psi; convert to MPa for output
    PSI_TO_MPA = 1.0 / 145.038
    curve: list[StrengthPoint] = []
    # Day 28 is always in CURVE_DAYS (1..90), so these zeros are overwritten.
    str28_mean = str28_lo = str28_hi = 0.0
    for day in CURVE_DAYS:
        # Append the curing-time feature — the model's last input column.
        t = torch.tensor([[float(day)]], dtype=torch.float64)
        X_t = torch.cat([comp, t], dim=-1)
        mean, lo, hi = _posterior_stats(_model.strength_model, X_t)
        # Clamp mean and lower bound at zero; upper bound left unclamped.
        mean = max(mean * PSI_TO_MPA, 0.0)
        lo = max(lo * PSI_TO_MPA, 0.0)
        hi = hi * PSI_TO_MPA
        curve.append(StrengthPoint(
            day=day,
            mean=round(mean, 2),
            lower=round(lo, 2),
            upper=round(hi, 2),
        ))
        if day == 28:
            # Keep unrounded values; rounding happens once in the response.
            str28_mean, str28_lo, str28_hi = mean, lo, hi
    return PredictionResponse(
        gwp=round(gwp_mean, 2),
        gwp_lower=round(gwp_lo, 2),
        gwp_upper=round(gwp_hi, 2),
        strength_28d=round(str28_mean, 2),
        strength_28d_lower=round(str28_lo, 2),
        strength_28d_upper=round(str28_hi, 2),
        strength_curve=curve,
    )
# ── Entry point ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn
    # Port 7860 is the port Hugging Face Spaces expects a Docker app to bind.
    uvicorn.run(app, host="0.0.0.0", port=7860)