# Uploaded by nithyanatarajan ("Upload folder using huggingface_hub").
# Commit 75a0982 (verified).
import streamlit as st
import pandas as pd
import numpy as np
from huggingface_hub import hf_hub_download
import joblib
import os
# ---------- Page Config ----------
# Wide layout with the sidebar collapsed; title and icon label the browser tab.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="collapsed",
    page_title="Predictive Maintenance",
    page_icon="🔧",
)
# ---------- Custom CSS ----------
# Stylesheet injected as raw HTML: adds top padding to the main container and
# restyles every st.button as a blue, rounded button with a darker hover state.
_CUSTOM_CSS = """
<style>
.block-container {
padding-top: 2rem;
}
div.stButton > button {
background-color: #007bff;
color: white;
border: none;
padding: 0.75rem 3rem;
border-radius: 0.5rem;
cursor: pointer;
}
div.stButton > button:hover {
background-color: #0069d9;
}
</style>
"""
# unsafe_allow_html is required so the <style> tag is emitted instead of escaped.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# ---------- Constants ----------
# A predicted failure probability at or above this value is reported as
# "maintenance required" by both the single and the batch prediction paths.
THRESHOLD = 0.4

# Raw sensor columns that every input row must provide; the derived features
# are computed from these.
REQUIRED_COLS = [
    "engine_rpm",
    "lub_oil_pressure",
    "fuel_pressure",
    "coolant_pressure",
    "lub_oil_temp",
    "coolant_temp",
]
# ---------- Helper Functions ----------
def slugify(name: str) -> str:
    """Return *name* with every underscore replaced by a hyphen.

    Intended to turn a model name into the hyphenated slug form used for the
    Hugging Face repository id. All other characters are left untouched.
    """
    underscore_to_hyphen = str.maketrans("_", "-")
    return name.translate(underscore_to_hyphen)
def add_derived_features(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* extended with the engineered feature columns.

    Added columns:
      * ``rpm_x_fuel_pressure`` - ``engine_rpm`` multiplied by ``fuel_pressure``.
      * ``rpm_bins`` - 0 when ``engine_rpm`` is below 300, 1 when it is at most
        1500, and 2 otherwise.
      * ``oil_health_index`` - ``lub_oil_pressure / lub_oil_temp`` where the
        temperature is positive, and 0 elsewhere.

    The frame passed in is not modified.
    """
    enriched = df.copy()

    rpm = enriched["engine_rpm"]
    enriched["rpm_x_fuel_pressure"] = rpm * enriched["fuel_pressure"]

    # Three-way RPM bucket; the low-RPM test takes precedence over the mid one.
    is_low_rpm = rpm < 300
    is_at_most_mid_rpm = rpm <= 1500
    enriched["rpm_bins"] = np.where(
        is_low_rpm,
        0,
        np.where(is_at_most_mid_rpm, 1, 2),
    )

    # Rows with a non-positive oil temperature get 0 instead of the ratio, so a
    # zero temperature never leaks an infinite value into the features.
    oil_temp = enriched["lub_oil_temp"]
    pressure_per_degree = enriched["lub_oil_pressure"] / oil_temp
    enriched["oil_health_index"] = np.where(oil_temp > 0, pressure_per_degree, 0)

    return enriched
# ---------- Load Model (full pipeline with preprocessor) ----------
hf_username = os.getenv("HF_USERNAME")
if not hf_username:
    # Without a username the repo id would become "None/<model>", which only
    # fails later with an opaque Hub error; report the real cause instead.
    st.error("HF_USERNAME environment variable is not set; cannot locate the model repository.")
    st.stop()
hf_model_name = slugify(os.getenv("HF_MODEL_NAME", "predictive-maintenance-model"))


@st.cache_resource
def _load_model(repo_id: str, filename: str):
    """Download *filename* from the Hub repo *repo_id* and deserialize it.

    Streamlit re-executes this script on every widget interaction, so the
    download and deserialization are cached and performed once per distinct
    (repo_id, filename) pair instead of on every rerun.

    Returns:
        A ``(local_path, model)`` tuple: the path of the downloaded file and
        the object loaded from it.
    """
    path = hf_hub_download(repo_id=repo_id, filename=filename)
    # SECURITY: joblib.load unpickles the file, which can execute arbitrary
    # code. Only point HF_USERNAME / HF_MODEL_NAME at a repository you trust.
    return path, joblib.load(path)


model_path, model = _load_model(
    f"{hf_username}/{hf_model_name}",
    "best_engine_maintenance_model.joblib",
)
# ---------- App Header ----------
st.title("🔧 Engine Predictive Maintenance")
st.write("Predict whether an engine requires maintenance based on sensor readings.")
st.markdown("---")  # horizontal rule separating the header from the tabs

# ---------- Tabs ----------
# One tab per workflow: a single set of readings entered by hand, or a CSV upload.
_TAB_LABELS = ["Single Prediction", "Batch Prediction"]
tab_single, tab_batch = st.tabs(_TAB_LABELS)
# ===================== Single Prediction =====================
# NOTE(review): nesting was reconstructed from a paste that had lost its
# indentation; st.balloons() is assumed to belong to the "Normal Operation"
# branch - confirm against the deployed app.
with tab_single:
    st.subheader("📊 Sensor Readings")

    # Raw sensor inputs, grouped into three side-by-side columns.
    col1, col2, col3 = st.columns(3, gap="medium")
    with col1:
        st.markdown("**Engine Performance**")
        engine_rpm = st.number_input("Engine RPM", min_value=0, max_value=3000, value=1500)
        fuel_pressure = st.number_input("Fuel Pressure (bar)", min_value=0.0, max_value=25.0, value=6.5, step=0.1)
    with col2:
        st.markdown("**Lubrication System**")
        lub_oil_pressure = st.number_input("Lub Oil Pressure (bar)", min_value=0.0, max_value=10.0, value=3.5, step=0.1)
        lub_oil_temp = st.number_input("Lub Oil Temperature (°C)", min_value=0.0, max_value=150.0, value=85.0, step=1.0)
    with col3:
        st.markdown("**Cooling System**")
        coolant_pressure = st.number_input("Coolant Pressure (bar)", min_value=0.0, max_value=10.0, value=2.0, step=0.1)
        coolant_temp = st.number_input("Coolant Temperature (°C)", min_value=0.0, max_value=200.0, value=90.0, step=1.0)
    st.markdown("---")

    # One-row frame holding the raw readings plus the engineered features.
    input_data = add_derived_features(pd.DataFrame([{
        "engine_rpm": engine_rpm,
        "lub_oil_pressure": lub_oil_pressure,
        "fuel_pressure": fuel_pressure,
        "coolant_pressure": coolant_pressure,
        "lub_oil_temp": lub_oil_temp,
        "coolant_temp": coolant_temp,
    }]))

    st.subheader("📦 Feature Preview")
    with st.expander("Click to expand (includes derived features)", expanded=False):
        cols = st.columns(3)
        # Read each value from its own column instead of from
        # input_data.iloc[0]: a row Series spanning int and float columns is
        # upcast to float64, which made every metric (including Engine RPM and
        # rpm_bins) render as a four-decimal float and left the non-float
        # branch unreachable.
        for i, field in enumerate(input_data.columns):
            value = input_data[field].iloc[0]
            if isinstance(value, (float, np.floating)):
                display_value = f"{value:.4f}"
            else:
                # Convert NumPy scalars to plain Python values for st.metric.
                display_value = value.item() if hasattr(value, "item") else value
            with cols[i % 3]:
                st.metric(label=field, value=display_value)

    if st.button("Predict Maintenance Need"):
        # Column 1 of predict_proba is used as the failure probability.
        probability = model.predict_proba(input_data)[0, 1]
        prediction = 1 if probability >= THRESHOLD else 0
        st.markdown("---")
        st.subheader("Prediction Result")
        if prediction == 1:
            st.error(f"⚠️ **Maintenance Required** (Failure probability: {probability:.2%})")
            st.write("The engine shows signs of degradation. Schedule maintenance soon.")
        else:
            st.success(f"✅ **Normal Operation** (Failure probability: {probability:.2%})")
            st.write("The engine is operating within normal parameters.")
            st.balloons()
# ===================== Batch Prediction =====================
# NOTE(review): nesting was reconstructed from a paste that had lost its
# indentation; the template download button is assumed to sit inside the
# "Required CSV format" expander - confirm against the deployed app.
with tab_batch:
    st.subheader("📁 Batch Prediction")
    st.write("Upload a CSV with sensor readings to predict maintenance needs for multiple engines.")

    # Template download
    template_df = pd.DataFrame([{
        "engine_rpm": 1500, "lub_oil_pressure": 3.5, "fuel_pressure": 6.5,
        "coolant_pressure": 2.0, "lub_oil_temp": 85.0, "coolant_temp": 90.0,
    }])
    with st.expander("Required CSV format"):
        st.dataframe(template_df, use_container_width=True)
        st.download_button(
            "Download template CSV",
            template_df.to_csv(index=False),
            file_name="maintenance_batch_template.csv",
            mime="text/csv",
        )

    uploaded_file = st.file_uploader("Upload CSV", type=["csv"])
    if uploaded_file is not None:
        # Broad catch is deliberate: any parse failure is shown to the user
        # rather than surfacing as a traceback.
        try:
            batch_df = pd.read_csv(uploaded_file)
        except Exception as e:
            st.error(f"Failed to read CSV: {e}")
            st.stop()

        # Validate required columns
        missing = set(REQUIRED_COLS) - set(batch_df.columns)
        if missing:
            st.error(f"Missing required columns: {', '.join(sorted(missing))}")
            st.stop()

        # A header-only file would otherwise reach predict_proba with zero rows.
        if batch_df.empty:
            st.error("The uploaded CSV contains no data rows.")
            st.stop()

        # Text in a sensor column would otherwise fail inside the feature
        # arithmetic with a raw traceback; name the offending columns instead.
        non_numeric = [
            col for col in REQUIRED_COLS
            if not pd.api.types.is_numeric_dtype(batch_df[col])
        ]
        if non_numeric:
            st.error(f"Columns must contain only numeric values: {', '.join(non_numeric)}")
            st.stop()

        st.info(f"Loaded {len(batch_df)} rows")

        if st.button("Predict Batch"):
            # Only the required columns feed the model; extra columns in the
            # upload are carried through to the results untouched.
            features = add_derived_features(batch_df[REQUIRED_COLS])
            probabilities = model.predict_proba(features)[:, 1]
            predictions = (probabilities >= THRESHOLD).astype(int)

            results = batch_df.copy()
            results["failure_probability"] = probabilities.round(2)
            results["prediction"] = np.where(predictions == 1, "Maintenance", "Normal")

            total = len(predictions)
            maint_count = int((predictions == 1).sum())
            normal_count = int((predictions == 0).sum())
            st.markdown(
                f'<p style="font-size:1.4rem">'
                f"⚠️ <b>Maintenance Required: {maint_count} of {total}</b> &nbsp; | &nbsp; "
                f"✅ <b>Normal Operation: {normal_count} of {total}</b>"
                f"</p>",
                unsafe_allow_html=True,
            )

            def _highlight_row(row):
                """Tint a whole result row red for "Maintenance", green otherwise."""
                if row["prediction"] == "Maintenance":
                    cell_style = "background-color: #ffe0e0"
                else:
                    cell_style = "background-color: #e0ffe0"
                return [cell_style] * len(row)

            st.dataframe(
                results.style.apply(_highlight_row, axis=1),
                use_container_width=True,
            )
            st.download_button(
                "Download predictions CSV",
                results.to_csv(index=False),
                file_name="maintenance_predictions.csv",
                mime="text/csv",
            )