Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
|
@@ -782,7 +782,7 @@ with tab2:
|
|
| 782 |
with tab3:
|
| 783 |
st.header("π Load Pre-trained Models")
|
| 784 |
|
| 785 |
-
st.markdown("Upload Bayesian Network models (.pkl files) to view their structures.")
|
| 786 |
st.markdown("**Maximum: 2 models**")
|
| 787 |
st.markdown("---")
|
| 788 |
|
|
@@ -792,16 +792,17 @@ with tab3:
|
|
| 792 |
uploaded_model = None
|
| 793 |
else:
|
| 794 |
uploaded_model = st.file_uploader(
|
| 795 |
-
"Upload model file (.pkl)",
|
| 796 |
-
type=['pkl'],
|
| 797 |
-
help="Upload a .pkl file containing a Bayesian Network model"
|
| 798 |
)
|
| 799 |
|
| 800 |
if uploaded_model:
|
| 801 |
col_load1, col_load2 = st.columns([3, 1])
|
| 802 |
|
| 803 |
with col_load1:
|
| 804 |
-
|
|
|
|
| 805 |
|
| 806 |
with col_load2:
|
| 807 |
if st.button("π Load Model", type="primary", use_container_width=True):
|
|
@@ -810,28 +811,68 @@ with tab3:
|
|
| 810 |
import pickle
|
| 811 |
import tempfile
|
| 812 |
import os
|
|
|
|
|
|
|
|
|
|
| 813 |
|
| 814 |
-
|
| 815 |
-
with tempfile.NamedTemporaryFile(delete=False, suffix='.pkl') as tmp_file:
|
| 816 |
-
tmp_file.write(uploaded_model.read())
|
| 817 |
-
tmp_path = tmp_file.name
|
| 818 |
|
| 819 |
-
#
|
| 820 |
-
|
| 821 |
-
|
|
|
|
|
|
|
|
|
|
| 822 |
|
| 823 |
-
|
|
|
|
| 824 |
|
| 825 |
-
|
| 826 |
-
from pgmpy.models import BayesianNetwork
|
| 827 |
|
| 828 |
-
|
| 829 |
-
|
| 830 |
-
|
| 831 |
-
|
| 832 |
-
|
| 833 |
-
|
| 834 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 835 |
|
| 836 |
if model:
|
| 837 |
# Store model
|
|
|
|
| 782 |
with tab3:
|
| 783 |
st.header("π Load Pre-trained Models")
|
| 784 |
|
| 785 |
+
st.markdown("Upload Bayesian Network models (.pkl or .json files) to view their structures.")
|
| 786 |
st.markdown("**Maximum: 2 models**")
|
| 787 |
st.markdown("---")
|
| 788 |
|
|
|
|
| 792 |
uploaded_model = None
|
| 793 |
else:
|
| 794 |
uploaded_model = st.file_uploader(
|
| 795 |
+
"Upload model file (.pkl or .json)",
|
| 796 |
+
type=['pkl', 'json'],
|
| 797 |
+
help="Upload a .pkl or .json file containing a Bayesian Network model"
|
| 798 |
)
|
| 799 |
|
| 800 |
if uploaded_model:
|
| 801 |
col_load1, col_load2 = st.columns([3, 1])
|
| 802 |
|
| 803 |
with col_load1:
|
| 804 |
+
file_type = "PKL" if uploaded_model.name.endswith('.pkl') else "JSON"
|
| 805 |
+
st.info(f"π File: **{uploaded_model.name}** ({file_type})")
|
| 806 |
|
| 807 |
with col_load2:
|
| 808 |
if st.button("π Load Model", type="primary", use_container_width=True):
|
|
|
|
| 811 |
import pickle
|
| 812 |
import tempfile
|
| 813 |
import os
|
| 814 |
+
import json
|
| 815 |
+
from pgmpy.models import BayesianNetwork
|
| 816 |
+
from pgmpy.factors.discrete import TabularCPD
|
| 817 |
|
| 818 |
+
model = None
|
|
|
|
|
|
|
|
|
|
| 819 |
|
| 820 |
+
# Check file type
|
| 821 |
+
if uploaded_model.name.endswith('.pkl'):
|
| 822 |
+
# Load PKL file
|
| 823 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix='.pkl') as tmp_file:
|
| 824 |
+
tmp_file.write(uploaded_model.read())
|
| 825 |
+
tmp_path = tmp_file.name
|
| 826 |
|
| 827 |
+
with open(tmp_path, 'rb') as f:
|
| 828 |
+
model_data = pickle.load(f)
|
| 829 |
|
| 830 |
+
os.unlink(tmp_path)
|
|
|
|
| 831 |
|
| 832 |
+
if isinstance(model_data, BayesianNetwork):
|
| 833 |
+
model = model_data
|
| 834 |
+
elif isinstance(model_data, dict) and 'model' in model_data:
|
| 835 |
+
model = model_data['model']
|
| 836 |
+
else:
|
| 837 |
+
st.error("β Invalid PKL model format")
|
| 838 |
+
|
| 839 |
+
elif uploaded_model.name.endswith('.json'):
|
| 840 |
+
# Load JSON file
|
| 841 |
+
json_data = json.loads(uploaded_model.read().decode('utf-8'))
|
| 842 |
+
|
| 843 |
+
# Reconstruct BayesianNetwork from JSON
|
| 844 |
+
if 'model' in json_data:
|
| 845 |
+
model_data = json_data['model']
|
| 846 |
+
|
| 847 |
+
# Create network with edges
|
| 848 |
+
model = BayesianNetwork(model_data['edges'])
|
| 849 |
+
|
| 850 |
+
# Add CPDs
|
| 851 |
+
for node, cpd_data in model_data['cpds'].items():
|
| 852 |
+
# Ensure values is 2D array
|
| 853 |
+
import numpy as np
|
| 854 |
+
values = cpd_data['values']
|
| 855 |
+
if isinstance(values, list) and len(values) > 0:
|
| 856 |
+
if not isinstance(values[0], list):
|
| 857 |
+
# 1D array -> convert to 2D
|
| 858 |
+
values = np.array(values).reshape(-1, 1)
|
| 859 |
+
|
| 860 |
+
cpd = TabularCPD(
|
| 861 |
+
variable=cpd_data['variable'],
|
| 862 |
+
variable_card=cpd_data['cardinality'][0],
|
| 863 |
+
values=values,
|
| 864 |
+
evidence=cpd_data['parents'] or None,
|
| 865 |
+
evidence_card=cpd_data['cardinality'][1:] or None,
|
| 866 |
+
state_names=cpd_data.get('state_names', None)
|
| 867 |
+
)
|
| 868 |
+
model.add_cpds(cpd)
|
| 869 |
+
|
| 870 |
+
# Validate model
|
| 871 |
+
if not model.check_model():
|
| 872 |
+
st.error("β Invalid model structure")
|
| 873 |
+
model = None
|
| 874 |
+
else:
|
| 875 |
+
st.error("β Invalid JSON format: missing 'model' key")
|
| 876 |
|
| 877 |
if model:
|
| 878 |
# Store model
|