# Uploaded by ShanRaja via huggingface_hub ("Upload folder using huggingface_hub"), commit 2bc2558.
import streamlit as st
import pandas as pd
import joblib
from huggingface_hub import hf_hub_download
# Configure the browser tab title; this is the first Streamlit call in the script.
st.set_page_config(
    page_title="Engine Condition Prediction",
)
@st.cache_resource
def load_model(repo_id="ShanRaja/engine-fault-xgboost", filename="best_model.joblib"):
    """Download the trained classifier from the Hugging Face Hub and deserialize it.

    Wrapped in ``st.cache_resource`` so the download and ``joblib.load`` run
    once per distinct ``(repo_id, filename)``. The loaded object is then reused
    on later script reruns instead of being fetched again.

    Args:
        repo_id: Hugging Face Hub repository that hosts the model artifact.
        filename: Name of the joblib file inside that repository.

    Returns:
        The deserialized model object. The rest of this app calls its
        ``predict`` method on a pandas DataFrame.
    """
    model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    # SECURITY NOTE: joblib.load unpickles the file, which can execute arbitrary
    # code. Only point repo_id/filename at repositories you trust.
    return joblib.load(model_path)
# Resolve the model up front; st.cache_resource makes repeat calls cheap on reruns.
model = load_model()

# Page heading shown above the input widgets.
st.title("Engine Condition Prediction")
# --- Single-row prediction inputs -------------------------------------------
# One (label, number_input keyword arguments) pair per feature. The widgets
# are rendered in this order, and their values are unpacked below in the same
# order.
_SINGLE_ROW_WIDGETS = (
    ("Engine RPM", dict(min_value=50, max_value=2250, value=791, step=50)),
    (
        "Lub Oil Pressure (bar)",
        dict(min_value=0.0, max_value=8.0, value=3.3, step=0.1, format="%.3f"),
    ),
    (
        "Fuel Pressure (bar)",
        dict(min_value=0.0, max_value=22.0, value=6.66, step=0.1, format="%.3f"),
    ),
    (
        "Coolant Pressure (bar)",
        dict(min_value=0.0, max_value=8.0, value=2.33, step=0.1, format="%.3f"),
    ),
    (
        "Lub Oil Temp (°C)",
        dict(min_value=70.0, max_value=90.0, value=77.6, step=0.5, format="%.2f"),
    ),
    (
        "Coolant Temp (°C)",
        dict(min_value=60.0, max_value=200.0, value=78.4, step=0.5, format="%.2f"),
    ),
)

# The list comprehension evaluates left to right, so the widgets render in
# declaration order.
(
    engine_rpm,
    lub_oil_pressure,
    fuel_pressure,
    coolant_pressure,
    lub_oil_temp,
    coolant_temp,
) = [st.number_input(label, **options) for label, options in _SINGLE_ROW_WIDGETS]
if st.button("Predict"):
    # Column names and order must match the features the model was fitted on.
    # The same list is repeated in the batch-prediction section below.
    input_df = pd.DataFrame(
        [[
            int(engine_rpm),
            lub_oil_pressure,
            fuel_pressure,
            coolant_pressure,
            lub_oil_temp,
            coolant_temp,
        ]],
        columns=[
            "Engine rpm",
            "Lub oil pressure",
            "Fuel pressure",
            "Coolant pressure",
            "lub oil temp",
            "Coolant temp",
        ],
    )
    prediction = model.predict(input_df)[0]
    label = "FAULTY" if prediction == 1 else "NORMAL"
    # A fault must not be reported in a green "success" box; use an error
    # callout for FAULTY so the result is visually unambiguous.
    if label == "FAULTY":
        st.error(f"Engine Condition: {label}")
    else:
        st.success(f"Engine Condition: {label}")
# Batch prediction
st.header("Batch Prediction from CSV")
uploaded_file = st.file_uploader("Upload CSV file for batch prediction", type="csv")
if uploaded_file is not None:
    # Feature columns in the order the model expects (same as the single-row
    # DataFrame above). Only these columns, in this order, are sent to the model.
    required_cols = [
        "Engine rpm",
        "Lub oil pressure",
        "Fuel pressure",
        "Coolant pressure",
        "lub oil temp",
        "Coolant temp",
    ]
    try:
        input_data = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        # Surface unreadable uploads in the UI instead of crashing the script.
        st.error(f"Could not read the uploaded CSV: {err}")
    else:
        missing_cols = [col for col in required_cols if col not in input_data.columns]
        if missing_cols:
            st.error(
                f"CSV must contain columns: {', '.join(required_cols)}"
                f" (missing: {', '.join(missing_cols)})"
            )
        elif input_data.empty:
            st.warning("The uploaded CSV has no data rows to predict on.")
        else:
            # Select the feature columns explicitly. Passing the whole frame would
            # include extra columns (e.g. an existing label or index column) and
            # the CSV's own column order, which models that validate feature
            # names reject.
            predictions = model.predict(input_data[required_cols])
            input_data["Engine Condition"] = [
                "FAULTY" if p == 1 else "NORMAL" for p in predictions
            ]
            st.success("Predictions completed!")
            st.dataframe(input_data)
            csv = input_data.to_csv(index=False).encode("utf-8")
            st.download_button(
                label="Download Predictions as CSV",
                data=csv,
                file_name="engine_predictions.csv",
                mime="text/csv",
            )