# cetane-ysi-predictor / src/streamlit_app.py
# Author: SalZa2004
# Last commit: "Update src/streamlit_app.py" (f716875, verified)
import sys
import importlib.util
import streamlit as st
import pandas as pd
import numpy as np
from rdkit import Chem
from huggingface_hub import hf_hub_download
# -------------------------
# CONFIG
# -------------------------
# Hugging Face Hub repositories that host the trained artifacts.
REPO_ID_CN = "SalZa2004/CetaneV2"  # cetane-number model, selector and shared_features.py
REPO_ID_YSI = "SalZa2004/YSI_Predictor"  # yield-sooting-index model and selector

# Page-level settings, applied before any element is rendered.
st.set_page_config(
    layout="wide",
    page_title="Fuel Property Predictor",
    page_icon="⛽",
)
# -------------------------
# LOAD shared_features
# -------------------------
@st.cache_resource
def load_shared_features():
    """Download ``shared_features.py`` from the CN repository and import it.

    The executed module is registered in ``sys.modules`` as
    ``"shared_features"`` and also aliased as ``"main"`` and ``"__main__"``.
    This lets ``joblib.load`` resolve classes such as ``FeatureSelector``
    that were pickled under those module names. Presumably the models were
    trained from a script run as ``main``/``__main__`` (TODO confirm).

    Because of ``st.cache_resource``, the body (and so the registrations)
    runs only on a cache miss.

    Returns
    -------
    module
        The executed ``shared_features`` module.

    Raises
    ------
    Exception
        Whatever the download or the module's own top-level code raises.
        On an import failure the ``"shared_features"`` entry is removed again.
    """
    path = hf_hub_download(REPO_ID_CN, "shared_features.py")
    spec = importlib.util.spec_from_file_location("shared_features", path)
    shared = importlib.util.module_from_spec(spec)

    # Register BEFORE executing, as in the importlib "import a source file
    # directly" recipe. Code that looks the module up in sys.modules while it
    # is being imported (e.g. dataclasses resolving string annotations,
    # typing.get_type_hints, pickle) needs the entry to exist already.
    sys.modules["shared_features"] = shared
    try:
        spec.loader.exec_module(shared)
    except BaseException:
        # Don't leave a half-initialised module importable.
        sys.modules.pop("shared_features", None)
        raise

    # Aliases used when unpickling the joblib artifacts. Note this replaces
    # the interpreter's real "__main__" entry.
    sys.modules["main"] = shared
    sys.modules["__main__"] = shared
    globals()["FeatureSelector"] = shared.FeatureSelector
    return shared
# -------------------------
# LOAD CN model
# -------------------------
@st.cache_resource
def load_model_cn(shared):
    """Fetch and unpickle the cetane-number model and its feature selector.

    ``shared`` is not read here. Taking it as an argument forces the caller
    to load ``shared_features`` first, presumably so that the classes the
    pickles reference are registered before ``joblib.load`` runs.

    Returns
    -------
    tuple
        ``(model, selector)`` as stored in ``model.joblib`` and
        ``selector.joblib`` of ``REPO_ID_CN``.
    """
    import joblib

    artifact_paths = [
        hf_hub_download(REPO_ID_CN, artifact)
        for artifact in ("model.joblib", "selector.joblib")
    ]
    # NOTE(security): joblib.load unpickles, which can execute arbitrary
    # code; only point REPO_ID_CN at a repository you trust.
    model, selector = (joblib.load(p) for p in artifact_paths)
    return model, selector
# -------------------------
# LOAD YSI model
# -------------------------
@st.cache_resource
def load_model_ysi(shared):
    """Fetch and unpickle the YSI model and its feature selector.

    ``shared`` is unused in the body. Requiring it means ``shared_features``
    has been loaded (and its classes registered) before unpickling,
    presumably what the pickled objects depend on.

    Returns
    -------
    tuple
        ``(model, selector)`` from ``model.joblib`` / ``selector.joblib``
        in ``REPO_ID_YSI``.
    """
    import joblib

    # Download both artifacts first, then deserialize them in the same order.
    downloaded = {
        name: hf_hub_download(REPO_ID_YSI, name)
        for name in ("model.joblib", "selector.joblib")
    }
    # NOTE(security): joblib.load unpickles remote files and can run
    # arbitrary code; only use a trusted REPO_ID_YSI.
    loaded_model = joblib.load(downloaded["model.joblib"])
    loaded_selector = joblib.load(downloaded["selector.joblib"])
    return loaded_model, loaded_selector
# -------------------------
# HELPERS
# -------------------------
def validate_smiles(smiles):
    """Return True if *smiles* is a non-blank string RDKit can parse.

    Non-string values (None, NaN, numbers coming from a CSV column) and
    blank or whitespace-only strings are rejected without calling RDKit.
    ``Chem.MolFromSmiles`` raises on non-string arguments, and an empty
    SMILES parses to a zero-atom molecule rather than ``None``.

    The string is passed to RDKit unmodified, so a SMILES accepted here is
    exactly what the featurizer will later receive.
    """
    if not isinstance(smiles, str) or not smiles.strip():
        return False
    return Chem.MolFromSmiles(smiles) is not None
def predict_cn(smiles, model, selector, shared):
    """Predict the cetane number (CN) for a single SMILES string.

    The SMILES is featurized with ``shared.featurize_df``, reduced with
    ``selector.transform`` and scored by ``model.predict``.

    Returns
    -------
    float or None
        The first prediction as a Python float, or ``None`` when
        ``featurize_df`` yields ``None``.
    """
    features = shared.featurize_df([smiles], return_df=False)
    if features is not None:
        prediction = model.predict(selector.transform(features))
        return float(prediction[0])
    return None
def predict_ysi(smiles, model, selector, shared):
    """Predict the yield sooting index (YSI) for a single SMILES string.

    Returns the first value of ``model.predict`` on the selected features as
    a Python float. Returns ``None`` if ``shared.featurize_df`` returns
    ``None``.
    """
    raw = shared.featurize_df([smiles], return_df=False)
    if raw is None:
        # No features; leave reporting to the caller.
        return None
    selected = selector.transform(raw)
    scores = model.predict(selected)
    return float(scores[0])
# -------------------------
# UI
# -------------------------
def _scatter_predictions(mask, preds):
    """Expand predictions for valid rows into one float per input row (NaN if invalid)."""
    mask = np.asarray(mask, dtype=bool)
    out = np.full(len(mask), np.nan)
    # preds holds exactly one value per True entry in mask, in row order.
    out[mask] = preds
    return out


def _render_single_tab(shared, model_cn, selector_cn, model_ysi, selector_ysi):
    """Render the single-SMILES form and show CN/YSI metrics when submitted."""
    smiles = st.text_input("Enter SMILES:", placeholder="CCCCCCC")
    if not st.button("🔮 Predict", type="primary"):
        return
    if not validate_smiles(smiles):
        st.error("Invalid SMILES.")
        return
    cn = predict_cn(smiles, model_cn, selector_cn, shared)
    ysi = predict_ysi(smiles, model_ysi, selector_ysi, shared)
    # predict_* return None when featurization fails; formatting None with
    # ":.2f" would raise TypeError and crash the page.
    if cn is None or ysi is None:
        st.error("❌ Could not compute molecular features for this SMILES.")
        return
    st.metric("Cetane Number", f"{cn:.2f}")
    st.metric("Yield Sooting Index", f"{ysi:.2f}")


def _render_batch_tab(shared, model_cn, selector_cn, model_ysi, selector_ysi):
    """Render the CSV uploader, append CN/YSI prediction columns and offer a download."""
    uploaded = st.file_uploader("CSV file", type=["csv"])
    if uploaded is None:
        return
    try:
        df = pd.read_csv(uploaded)
    except ValueError as e:
        # Covers pandas ParserError/EmptyDataError and UnicodeDecodeError,
        # which are all ValueError subclasses.
        st.error(f"❌ Could not read CSV: {e}")
        return

    # First column whose name contains "smiles" (case-insensitive).
    smiles_cols = [c for c in df.columns if "smiles" in str(c).lower()]
    if not smiles_cols:
        st.error("❌ No column with 'smiles' in its name was found in the CSV.")
        return
    col = smiles_cols[0]

    vals = df[col].tolist()
    mask = [validate_smiles(s) for s in vals]
    valid = [s for s, ok in zip(vals, mask) if ok]

    if valid:
        # Both models start from the same raw feature matrix, so featurize
        # once and let each selector pick its own columns.
        features = shared.featurize_df(valid, return_df=False)
        if features is None:
            st.error("❌ Feature computation failed for the uploaded SMILES.")
            return
        preds_cn = model_cn.predict(selector_cn.transform(features))
        preds_ysi = model_ysi.predict(selector_ysi.transform(features))
    else:
        # Nothing to featurize; every row gets NaN instead of crashing on an
        # empty feature matrix.
        st.warning(f"No valid SMILES found in column '{col}'.")
        preds_cn = preds_ysi = np.empty(0)

    df["Predicted_CN"] = _scatter_predictions(mask, preds_cn)
    df["Predicted_YSI"] = _scatter_predictions(mask, preds_ysi)
    st.dataframe(df)
    st.download_button(
        "Download CSV",
        df.to_csv(index=False),
        "fuel_predictions.csv",
    )


def main():
    """Build the app: load both models, then render single and batch tabs.

    Model-loading errors are shown in the UI, and the script run is stopped
    via ``st.stop()``.
    """
    st.title("⛽ Fuel Property Predictor")
    st.markdown("""
Predict **Cetane Number (CN)** and **Yield Sooting Index (YSI)** from SMILES strings.
**Input options:**
- Single SMILES string
- CSV file with SMILES column
""")
    try:
        shared = load_shared_features()
        model_cn, selector_cn = load_model_cn(shared)
        model_ysi, selector_ysi = load_model_ysi(shared)
    except Exception as e:  # top-level UI boundary: report any load failure
        st.error(f"❌ Failed to load models: {e}")
        st.stop()

    artifacts = (shared, model_cn, selector_cn, model_ysi, selector_ysi)
    tab1, tab2 = st.tabs(["Single Prediction", "Batch Prediction"])
    # Widgets created inside the helpers render into the active tab.
    with tab1:
        _render_single_tab(*artifacts)
    with tab2:
        _render_batch_tab(*artifacts)


if __name__ == "__main__":
    main()