Update src/streamlit_app.py
Browse files- src/streamlit_app.py +64 -21
src/streamlit_app.py
CHANGED
|
@@ -10,6 +10,8 @@ import streamlit as st
|
|
| 10 |
import matplotlib.pyplot as plt
|
| 11 |
import seaborn as sns
|
| 12 |
import joblib
|
|
|
|
|
|
|
| 13 |
|
| 14 |
# ML imports
|
| 15 |
from sklearn.model_selection import train_test_split
|
|
@@ -29,37 +31,36 @@ import shap
|
|
| 29 |
# -------------------------
|
| 30 |
st.set_page_config(page_title="Steel Authority of India Limited (MODEX)", layout="wide")
|
| 31 |
|
| 32 |
-
# Base
|
| 33 |
BASE_DIR = "./"
|
| 34 |
-
|
| 35 |
LOG_DIR = os.path.join(BASE_DIR, "logs")
|
| 36 |
-
DATA_DIR = os.path.join(LOG_DIR, "data_ephemeral")
|
| 37 |
-
os.makedirs(DATA_DIR, exist_ok=True)
|
| 38 |
os.makedirs(LOG_DIR, exist_ok=True)
|
| 39 |
|
| 40 |
-
# Timestamped
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
def log(msg: str):
|
| 45 |
-
"""Log message with timestamp to /logs/ for ephemeral HF runs."""
|
| 46 |
with open(LOG_PATH, "a", encoding="utf-8") as f:
|
| 47 |
-
f.write(f"[{datetime.now()
|
| 48 |
print(msg)
|
| 49 |
|
| 50 |
-
log(" Streamlit session started
|
| 51 |
-
log(f"
|
| 52 |
-
log(f"Data Dir = {DATA_DIR} | Log Dir = {LOG_DIR}")
|
| 53 |
-
CSV_PATH = os.path.join(DATA_DIR, "flatfile_universe_advanced.csv")
|
| 54 |
-
META_PATH = os.path.join(DATA_DIR, "feature_metadata_advanced.json")
|
| 55 |
-
ENSEMBLE_ARTIFACT = os.path.join(DATA_DIR, "ensemble_models.joblib")
|
| 56 |
|
| 57 |
|
| 58 |
# Confirm storage mount
|
| 59 |
if os.path.exists("/data"):
|
| 60 |
-
st.sidebar.success(f" Using persistent storage: {
|
| 61 |
else:
|
| 62 |
-
st.sidebar.warning(f" Using ephemeral storage: {
|
| 63 |
|
| 64 |
|
| 65 |
# -------------------------
|
|
@@ -85,7 +86,7 @@ def generate_advanced_flatfile(
|
|
| 85 |
variance_overrides: dict mapping feature name or substring → stddev multiplier
|
| 86 |
"""
|
| 87 |
np.random.seed(random_seed)
|
| 88 |
-
os.makedirs(
|
| 89 |
if variance_overrides is None:
|
| 90 |
variance_overrides = {}
|
| 91 |
|
|
@@ -771,7 +772,7 @@ with tabs[4]:
|
|
| 771 |
st.pyplot(fig)
|
| 772 |
|
| 773 |
# Save trained stack artifacts
|
| 774 |
-
stack_artifact = os.path.join(
|
| 775 |
to_save = {
|
| 776 |
"base_models": {bm["family"]: bm["model"] for bm in base_models if bm["family"] in selected},
|
| 777 |
"meta": meta,
|
|
@@ -904,10 +905,52 @@ in metallurgical AI modeling. Click any title to open the official paper.
|
|
| 904 |
# -------------------------
|
| 905 |
st.markdown("---")
|
| 906 |
st.markdown("**Notes:** This dataset is synthetic and for demo/prototyping. Real plant integration requires NDA, data on-boarding, sensor mapping, and plant safety checks before any control actions.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 907 |
# ----- Logs tab
|
| 908 |
tabs.append("View Logs")
|
| 909 |
with tabs[-1]:
|
| 910 |
-
st.subheader("
|
| 911 |
st.markdown("Each run creates a timestamped log file in `/logs/` inside this Space. Use this panel to review run progress and debug output.")
|
| 912 |
|
| 913 |
log_files = sorted(
|
|
@@ -923,4 +966,4 @@ with tabs[-1]:
|
|
| 923 |
with open(path, "r", encoding="utf-8") as f:
|
| 924 |
content = f.read()
|
| 925 |
st.text_area("Log Output", content, height=400)
|
| 926 |
-
st.download_button("
|
|
|
|
| 10 |
import matplotlib.pyplot as plt
|
| 11 |
import seaborn as sns
|
| 12 |
import joblib
|
| 13 |
+
import zipfile
|
| 14 |
+
import io
|
| 15 |
|
| 16 |
# ML imports
|
| 17 |
from sklearn.model_selection import train_test_split
|
|
|
|
| 31 |
# -------------------------
|
| 32 |
st.set_page_config(page_title="Steel Authority of India Limited (MODEX)", layout="wide")
|
| 33 |
|
| 34 |
+
# Base directory and persistent logs.
BASE_DIR = "./"
LOG_DIR = os.path.join(BASE_DIR, "logs")
os.makedirs(LOG_DIR, exist_ok=True)

# One timestamped subfolder per session keeps each run's artifacts together.
run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
RUN_DIR = os.path.join(LOG_DIR, f"run_{run_id}")
os.makedirs(RUN_DIR, exist_ok=True)


def _run_file(filename: str) -> str:
    """Return the path of *filename* inside this session's run folder."""
    return os.path.join(RUN_DIR, filename)


# Artifacts produced by this run all live inside RUN_DIR.
CSV_PATH = _run_file("flatfile_universe_advanced.csv")
META_PATH = _run_file("feature_metadata_advanced.json")
ENSEMBLE_ARTIFACT = _run_file("ensemble_models.joblib")
LOG_PATH = _run_file("run.log")
|
| 49 |
|
| 50 |
def log(msg: str, path=None):
    """Append *msg* with a timestamp to the run's log file and echo it to stdout.

    Args:
        msg: Message to record.
        path: Log file to append to; defaults to this run's ``LOG_PATH``.
            Keeping it overridable makes the helper reusable and testable.
    """
    target = LOG_PATH if path is None else path
    with open(target, "a", encoding="utf-8") as f:
        f.write(f"[{datetime.now():%Y-%m-%d %H:%M:%S}] {msg}\n")
    print(msg)
|
| 54 |
|
| 55 |
+
# Record session start in this run's log file.
for startup_msg in (
    f" Streamlit session started | run_id={run_id}",
    f"Run directory: {RUN_DIR}",
):
    log(startup_msg)


# Confirm storage mount: a /data directory means a persistent disk is available.
if os.path.exists("/data"):
    notify = st.sidebar.success
    storage_note = f" Using persistent storage | Run directory: {RUN_DIR}"
else:
    notify = st.sidebar.warning
    storage_note = f" Using ephemeral storage | Run directory: {RUN_DIR}. Data will be lost on rebuild."
notify(storage_note)
|
| 64 |
|
| 65 |
|
| 66 |
# -------------------------
|
|
|
|
| 86 |
variance_overrides: dict mapping feature name or substring → stddev multiplier
|
| 87 |
"""
|
| 88 |
np.random.seed(random_seed)
|
| 89 |
+
os.makedirs(RUN_DIR, exist_ok=True)
|
| 90 |
if variance_overrides is None:
|
| 91 |
variance_overrides = {}
|
| 92 |
|
|
|
|
| 772 |
st.pyplot(fig)
|
| 773 |
|
| 774 |
# Save trained stack artifacts
|
| 775 |
+
stack_artifact = os.path.join(RUN_DIR, f"stacked_{use_case.replace(' ', '_')}.joblib")
|
| 776 |
to_save = {
|
| 777 |
"base_models": {bm["family"]: bm["model"] for bm in base_models if bm["family"] in selected},
|
| 778 |
"meta": meta,
|
|
|
|
| 905 |
# -------------------------
|
| 906 |
st.markdown("---")
|
| 907 |
st.markdown("**Notes:** This dataset is synthetic and for demo/prototyping. Real plant integration requires NDA, data on-boarding, sensor mapping, and plant safety checks before any control actions.")
|
| 908 |
+
|
| 909 |
+
# ----- Download tab
tabs.append("Download Saved Runs")
with tabs[-1]:
    st.subheader("Reproducibility & Run Exports")

    # Only real run_* directories count; a stray *file* named run_* in LOG_DIR
    # would otherwise crash os.listdir(selected_path) below.
    run_folders = sorted(
        [
            f for f in os.listdir(LOG_DIR)
            if f.startswith("run_") and os.path.isdir(os.path.join(LOG_DIR, f))
        ],
        reverse=True,  # timestamped names sort lexically, so newest run first
    )

    if not run_folders:
        st.info("No completed runs found yet.")
    else:
        selected_run = st.selectbox("Select run folder", run_folders, index=0)
        selected_path = os.path.join(LOG_DIR, selected_run)

        # Show contained files so the user knows what the zip will hold.
        files = [
            f for f in os.listdir(selected_path)
            if os.path.isfile(os.path.join(selected_path, f))
        ]
        st.write(f"Files in `{selected_run}`:")
        st.write(", ".join(files))

        # Zip the folder in-memory for download; arcnames are relative to the
        # run folder so the archive unpacks flat.
        zip_buffer = io.BytesIO()
        with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zipf:
            for root, _, filenames in os.walk(selected_path):
                for fname in filenames:
                    file_path = os.path.join(root, fname)
                    zipf.write(file_path, arcname=os.path.relpath(file_path, selected_path))
        zip_buffer.seek(0)  # rewind so download_button streams from the start

        st.download_button(
            label=f"Download full run ({selected_run}.zip)",
            data=zip_buffer,
            file_name=f"{selected_run}.zip",
            mime="application/zip"
        )
|
| 948 |
+
|
| 949 |
+
|
| 950 |
# ----- Logs tab
|
| 951 |
tabs.append("View Logs")
|
| 952 |
with tabs[-1]:
|
| 953 |
+
st.subheader(" Session & Model Logs")
|
| 954 |
st.markdown("Each run creates a timestamped log file in `/logs/` inside this Space. Use this panel to review run progress and debug output.")
|
| 955 |
|
| 956 |
log_files = sorted(
|
|
|
|
| 966 |
with open(path, "r", encoding="utf-8") as f:
|
| 967 |
content = f.read()
|
| 968 |
st.text_area("Log Output", content, height=400)
|
| 969 |
+
st.download_button(" Download Log", content, file_name=latest)
|