Donlagon007 commited on
Commit
d446331
Β·
verified Β·
1 Parent(s): a513f8a

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -22
app.py CHANGED
@@ -782,7 +782,7 @@ with tab2:
782
  with tab3:
783
  st.header("πŸ“‚ Load Pre-trained Models")
784
 
785
- st.markdown("Upload Bayesian Network models (.pkl files) to view their structures.")
786
  st.markdown("**Maximum: 2 models**")
787
  st.markdown("---")
788
 
@@ -792,16 +792,17 @@ with tab3:
792
  uploaded_model = None
793
  else:
794
  uploaded_model = st.file_uploader(
795
- "Upload model file (.pkl)",
796
- type=['pkl'],
797
- help="Upload a .pkl file containing a Bayesian Network model"
798
  )
799
 
800
  if uploaded_model:
801
  col_load1, col_load2 = st.columns([3, 1])
802
 
803
  with col_load1:
804
- st.info(f"πŸ“„ File: **{uploaded_model.name}**")
 
805
 
806
  with col_load2:
807
  if st.button("πŸ”„ Load Model", type="primary", use_container_width=True):
@@ -810,28 +811,68 @@ with tab3:
810
  import pickle
811
  import tempfile
812
  import os
 
 
 
813
 
814
- # Save to temp file
815
- with tempfile.NamedTemporaryFile(delete=False, suffix='.pkl') as tmp_file:
816
- tmp_file.write(uploaded_model.read())
817
- tmp_path = tmp_file.name
818
 
819
- # Load model
820
- with open(tmp_path, 'rb') as f:
821
- model_data = pickle.load(f)
 
 
 
822
 
823
- os.unlink(tmp_path)
 
824
 
825
- # Extract model
826
- from pgmpy.models import BayesianNetwork
827
 
828
- if isinstance(model_data, BayesianNetwork):
829
- model = model_data
830
- elif isinstance(model_data, dict) and 'model' in model_data:
831
- model = model_data['model']
832
- else:
833
- st.error("❌ Invalid model format")
834
- model = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
835
 
836
  if model:
837
  # Store model
 
782
  with tab3:
783
  st.header("πŸ“‚ Load Pre-trained Models")
784
 
785
+ st.markdown("Upload Bayesian Network models (.pkl or .json files) to view their structures.")
786
  st.markdown("**Maximum: 2 models**")
787
  st.markdown("---")
788
 
 
792
  uploaded_model = None
793
  else:
794
  uploaded_model = st.file_uploader(
795
+ "Upload model file (.pkl or .json)",
796
+ type=['pkl', 'json'],
797
+ help="Upload a .pkl or .json file containing a Bayesian Network model"
798
  )
799
 
800
  if uploaded_model:
801
  col_load1, col_load2 = st.columns([3, 1])
802
 
803
  with col_load1:
804
+ file_type = "PKL" if uploaded_model.name.endswith('.pkl') else "JSON"
805
+ st.info(f"πŸ“„ File: **{uploaded_model.name}** ({file_type})")
806
 
807
  with col_load2:
808
  if st.button("πŸ”„ Load Model", type="primary", use_container_width=True):
 
811
  import pickle
812
  import tempfile
813
  import os
814
+ import json
815
+ from pgmpy.models import BayesianNetwork
816
+ from pgmpy.factors.discrete import TabularCPD
817
 
818
+ model = None
 
 
 
819
 
820
+ # Check file type
821
+ if uploaded_model.name.endswith('.pkl'):
822
+ # Load PKL file
823
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.pkl') as tmp_file:
824
+ tmp_file.write(uploaded_model.read())
825
+ tmp_path = tmp_file.name
826
 
827
+ with open(tmp_path, 'rb') as f:
828
+ model_data = pickle.load(f)
829
 
830
+ os.unlink(tmp_path)
 
831
 
832
+ if isinstance(model_data, BayesianNetwork):
833
+ model = model_data
834
+ elif isinstance(model_data, dict) and 'model' in model_data:
835
+ model = model_data['model']
836
+ else:
837
+ st.error("❌ Invalid PKL model format")
838
+
839
+ elif uploaded_model.name.endswith('.json'):
840
+ # Load JSON file
841
+ json_data = json.loads(uploaded_model.read().decode('utf-8'))
842
+
843
+ # Reconstruct BayesianNetwork from JSON
844
+ if 'model' in json_data:
845
+ model_data = json_data['model']
846
+
847
+ # Create network with edges
848
+ model = BayesianNetwork(model_data['edges'])
849
+
850
+ # Add CPDs
851
+ for node, cpd_data in model_data['cpds'].items():
852
+ # Ensure values is 2D array
853
+ import numpy as np
854
+ values = cpd_data['values']
855
+ if isinstance(values, list) and len(values) > 0:
856
+ if not isinstance(values[0], list):
857
+ # 1D array -> convert to 2D
858
+ values = np.array(values).reshape(-1, 1)
859
+
860
+ cpd = TabularCPD(
861
+ variable=cpd_data['variable'],
862
+ variable_card=cpd_data['cardinality'][0],
863
+ values=values,
864
+ evidence=cpd_data['parents'] or None,
865
+ evidence_card=cpd_data['cardinality'][1:] or None,
866
+ state_names=cpd_data.get('state_names', None)
867
+ )
868
+ model.add_cpds(cpd)
869
+
870
+ # Validate model
871
+ if not model.check_model():
872
+ st.error("❌ Invalid model structure")
873
+ model = None
874
+ else:
875
+ st.error("❌ Invalid JSON format: missing 'model' key")
876
 
877
  if model:
878
  # Store model