ManasSharma07 committed on
Commit
5c86435
·
verified ·
1 Parent(s): da79a56

add orbmol model and fix issues with orb models

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +150 -120
src/streamlit_app.py CHANGED
@@ -30,8 +30,11 @@ import sys
30
  import pkg_resources
31
  from ase.vibrations import Vibrations
32
  from mp_api.client import MPRester
 
 
33
  from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
34
  from pymatgen.io.ase import AseAtomsAdaptor
 
35
  import matplotlib.pyplot as plt
36
  mattersim_available = True
37
  if mattersim_available:
@@ -926,14 +929,16 @@ FAIRCHEM_MODELS = {
926
  }
927
  # Define the available ORB models
928
  ORB_MODELS = {
929
- "V3 OMAT Conservative (inf)": "orb-v3-conservative-inf-omat",
930
- "V3 OMAT Conservative (20)": "orb-v3-conservative-20-omat",
931
- "V3 OMAT Direct (inf)": "orb-v3-direct-inf-omat",
932
- "V3 OMAT Direct (20)": "orb-v3-direct-20-omat",
933
- "V3 MPA Conservative (inf)": "orb-v3-conservative-inf-mpa",
934
- "V3 MPA Conservative (20)": "orb-v3-conservative-20-mpa",
935
- "V3 MPA Direct (inf)": "orb-v3-direct-inf-mpa",
936
- "V3 MPA Direct (20)": "orb-v3-direct-20-mpa",
 
 
937
  }
938
  # Define the available MatterSim models
939
  MATTERSIM_MODELS = {
@@ -959,117 +964,11 @@ def get_fairchem_model(selected_model_name, model_path_or_name, device, selected
959
  calc = FAIRChemCalculator(predictor, task_name="omol")
960
  return calc
961
 
962
- # st.sidebar.markdown("## Input Options")
963
- # input_method = st.sidebar.radio("Choose Input Method:",
964
- # ["Select Example", "Upload File", "Paste Content", "Materials Project ID"])
965
- # atoms = None
966
-
967
- # if input_method == "Upload File":
968
- # uploaded_file = st.sidebar.file_uploader("Upload structure file", type=["xyz", "cif", "POSCAR", "mol", "tmol", "vasp", "sdf", "CONTCAR"])
969
- # if uploaded_file:
970
- # with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
971
- # tmp_file.write(uploaded_file.getvalue())
972
- # tmp_filepath = tmp_file.name
973
- # try:
974
- # atoms = read(tmp_filepath)
975
- # st.sidebar.success(f"Successfully loaded structure with {len(atoms)} atoms!")
976
- # except Exception as e:
977
- # st.sidebar.error(f"Error loading file: {str(e)}")
978
- # finally:
979
- # if 'tmp_filepath' in locals() and os.path.exists(tmp_filepath):
980
- # os.unlink(tmp_filepath)
981
-
982
- # elif input_method == "Select Example":
983
- # example_name = st.sidebar.selectbox("Select Example Structure:", list(SAMPLE_STRUCTURES.keys()))
984
- # if example_name:
985
- # file_path = os.path.join(SAMPLE_DIR, SAMPLE_STRUCTURES[example_name])
986
- # try:
987
- # atoms = read(file_path)
988
- # st.sidebar.success(f"Loaded {example_name} with {len(atoms)} atoms!")
989
- # except Exception as e:
990
- # st.sidebar.error(f"Error loading example: {str(e)}")
991
-
992
- # elif input_method == "Paste Content":
993
- # file_format = st.sidebar.selectbox("File Format:", ["XYZ", "CIF", "extXYZ", "POSCAR (VASP)", "Turbomole", "MOL"])
994
- # content = st.sidebar.text_area("Paste file content here:", height=200)
995
- # if content:
996
- # try:
997
- # suffix_map = {"XYZ": ".xyz", "CIF": ".cif", "extXYZ": ".extxyz", "POSCAR (VASP)": ".vasp", "Turbomole": ".tmol", "MOL": ".mol"}
998
- # suffix = suffix_map.get(file_format, ".xyz")
999
- # with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
1000
- # tmp_file.write(content.encode())
1001
- # tmp_filepath = tmp_file.name
1002
- # atoms = read(tmp_filepath)
1003
- # st.sidebar.success(f"Successfully parsed structure with {len(atoms)} atoms!")
1004
- # except Exception as e:
1005
- # st.sidebar.error(f"Error parsing content: {str(e)}")
1006
- # finally:
1007
- # if 'tmp_filepath' in locals() and os.path.exists(tmp_filepath):
1008
- # os.unlink(tmp_filepath)
1009
- # elif input_method == "Materials Project ID":
1010
- # # st.sidebar.info("Requires an API Key from [Materials Project](https://next-gen.materialsproject.org/api)")
1011
-
1012
- # # 1. Get API Key (You can also set this as an env variable)
1013
- # # mp_api_key = st.sidebar.text_input("MP API Key", type="password")
1014
- # mp_api_key = os.getenv("MP_API_KEY")
1015
- # # 2. Get Material ID
1016
- # material_id = st.sidebar.text_input("Enter Material ID:", value="mp-149")
1017
-
1018
- # # 3. Choose Cell Type
1019
- # cell_type = st.sidebar.radio("Unit Cell Type:", ['Primitive Cell', 'Conventional Unit Cell'])
1020
-
1021
- # # 1. Initialize the session state variable if it doesn't exist
1022
- # if "atoms" not in st.session_state:
1023
- # st.session_state.atoms = None
1024
-
1025
- # if st.sidebar.button("Fetch Structure"):
1026
- # if not mp_api_key:
1027
- # st.sidebar.error("Please enter your Materials Project API Key.")
1028
- # elif not material_id:
1029
- # st.sidebar.error("Please enter a Material ID.")
1030
- # else:
1031
- # try:
1032
- # with st.spinner(f"Fetching {material_id}..."):
1033
- # # Connect to MP
1034
- # with MPRester(mp_api_key) as mpr:
1035
- # # Get Pymatgen Structure
1036
- # pmg_structure = mpr.get_structure_by_material_id(material_id)
1037
-
1038
- # # Handle Symmetry (Primitive vs Conventional)
1039
- # analyzer = SpacegroupAnalyzer(pmg_structure)
1040
-
1041
- # if cell_type == 'Conventional Unit Cell':
1042
- # final_structure = analyzer.get_conventional_standard_structure()
1043
- # else:
1044
- # final_structure = analyzer.get_primitive_standard_structure()
1045
-
1046
- # # Convert Pymatgen Structure -> ASE Atoms
1047
- # # 2. Store the result in session_state instead of a local variable
1048
- # st.session_state.atoms = AseAtomsAdaptor.get_atoms(final_structure)
1049
-
1050
- # st.sidebar.success(f"Loaded {material_id} ({cell_type}) with {len(st.session_state.atoms)} atoms.")
1051
-
1052
- # except Exception as e:
1053
- # st.sidebar.error(f"Error fetching data: {str(e)}")
1054
- # st.session_state.atoms = None
1055
-
1056
- # # 3. Retrieve the data from session_state for use in the rest of your app
1057
- # atoms = st.session_state.atoms
1058
-
1059
- # # --- Rest of your processing code ---
1060
- # # if atoms is not None:
1061
- # # st.write(f"System is ready. Formula: {atoms.get_chemical_formula()}")
1062
- # # else:
1063
- # # st.info("Please fetch a structure to begin.")
1064
- # if atoms is not None:
1065
- # if not hasattr(atoms, 'info'):
1066
- # atoms.info = {}
1067
- # atoms.info["charge"] = atoms.info.get("charge", 0) # Default charge
1068
- # atoms.info["spin"] = atoms.info.get("spin", 1) # Default spin (usually 2S for ASE, model might want 2S+1)
1069
-
1070
  # --- INITIALIZATION (Must be run first) ---
1071
  if "atoms" not in st.session_state:
1072
  st.session_state.atoms = None
 
 
1073
 
1074
  # Reset atoms state if input method changes, to prevent using old data
1075
  # Use a key to track the currently active input method
@@ -1078,7 +977,7 @@ if 'current_input_method' not in st.session_state:
1078
 
1079
  st.sidebar.markdown("## Input Options")
1080
  input_method = st.sidebar.radio("Choose Input Method:",
1081
- ["Select Example", "Upload File", "Paste Content", "Materials Project ID"])
1082
 
1083
  # If the input method changes, clear the loaded structure
1084
  if input_method != st.session_state.current_input_method:
@@ -1168,6 +1067,92 @@ elif input_method == "Paste Content":
1168
  # Clear structure if text area is empty
1169
  st.session_state.atoms = None
1170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1171
 
1172
  # --- MATERIALS PROJECT ID ---
1173
  elif input_method == "Materials Project ID":
@@ -1210,6 +1195,52 @@ elif input_method == "Materials Project ID":
1210
  elif not material_id:
1211
  st.sidebar.error("Please enter a Material ID.")
1212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1213
  # ----------------------------------------------------
1214
  # --- FINAL STRUCTURE RETRIEVAL (The persistent structure) ---
1215
  # ----------------------------------------------------
@@ -1224,8 +1255,6 @@ if atoms is not None:
1224
 
1225
  # Display confirmation in the main area (optional, helps the user confirm what's loaded)
1226
  st.markdown(f"**Loaded Structure:** {atoms.get_chemical_formula()} ({len(atoms)} atoms)")
1227
-
1228
- # You can now add your processing buttons here, which will consistently use the 'atoms' variable.
1229
 
1230
  st.sidebar.markdown("## Model Selection")
1231
  if mattersim_available:
@@ -1409,7 +1438,8 @@ if atoms is not None:
1409
  calc = get_fairchem_model(selected_model, model_path, device, selected_task_type)
1410
  elif model_type == "ORB":
1411
  # st.write("Setting up ORB calculator...")
1412
- orbff = pretrained.orb_v3_conservative_inf_omat(device=device, precision=selected_default_dtype)
 
1413
  calc = ORBCalculator(orbff, device=device)
1414
  elif model_type == "MatterSim":
1415
  # st.write("Setting up MatterSim calculator...")
 
30
  import pkg_resources
31
  from ase.vibrations import Vibrations
32
  from mp_api.client import MPRester
33
+ import pubchempy as pcp
34
+ from io import StringIO
35
  from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
36
  from pymatgen.io.ase import AseAtomsAdaptor
37
+ from pymatgen.core.structure import Molecule
38
  import matplotlib.pyplot as plt
39
  mattersim_available = True
40
  if mattersim_available:
 
929
  }
930
  # Define the available ORB models
931
  ORB_MODELS = {
932
+ "V3 OMOL Conservative": pretrained.orb_v3_conservative_omol,
933
+ "V3 OMOL Direct": pretrained.orb_v3_direct_omol,
934
+ "V3 OMAT Conservative (inf)": pretrained.orb_v3_conservative_inf_omat,
935
+ "V3 OMAT Conservative (20)": pretrained.orb_v3_conservative_20_omat,
936
+ "V3 OMAT Direct (inf)": pretrained.orb_v3_direct_inf_omat,
937
+ "V3 OMAT Direct (20)": pretrained.orb_v3_direct_20_omat,
938
+ "V3 MPA Conservative (inf)": pretrained.orb_v3_conservative_inf_mpa,
939
+ "V3 MPA Conservative (20)": pretrained.orb_v3_conservative_20_mpa,
940
+ "V3 MPA Direct (inf)": pretrained.orb_v3_direct_inf_mpa,
941
+ "V3 MPA Direct (20)": pretrained.orb_v3_direct_20_mpa,
942
  }
943
  # Define the available MatterSim models
944
  MATTERSIM_MODELS = {
 
964
  calc = FAIRChemCalculator(predictor, task_name="omol")
965
  return calc
966
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
967
  # --- INITIALIZATION (Must be run first) ---
968
  if "atoms" not in st.session_state:
969
  st.session_state.atoms = None
970
+ if "atoms_list" not in st.session_state:
971
+ st.session_state.atoms_list = []
972
 
973
  # Reset atoms state if input method changes, to prevent using old data
974
  # Use a key to track the currently active input method
 
977
 
978
  st.sidebar.markdown("## Input Options")
979
  input_method = st.sidebar.radio("Choose Input Method:",
980
+ ["Select Example", "Upload File", "Paste Content", "Materials Project ID", "PubChem", "Batch Upload"])
981
 
982
  # If the input method changes, clear the loaded structure
983
  if input_method != st.session_state.current_input_method:
 
1067
  # Clear structure if text area is empty
1068
  st.session_state.atoms = None
1069
 
1070
+ # --- PUBCHEM SEARCH MODE ---
1071
+ elif input_method == "PubChem":
1072
+
1073
+
1074
+ st.sidebar.markdown("### Search PubChem")
1075
+
1076
+ query = st.sidebar.text_input("Enter name or formula (e.g., H2O, water, methane):",
1077
+ key="pubchem_query", value="water")
1078
+
1079
+ # Reset atoms if no query
1080
+ if query.strip() == "":
1081
+ st.session_state.atoms = None
1082
+
1083
+ # Step 1: Search PubChem
1084
+ if query and query.strip():
1085
+ # Avoid re-searching if query is unchanged
1086
+ if "pubchem_last_query" not in st.session_state or st.session_state.pubchem_last_query != query:
1087
+ try:
1088
+ with st.spinner("Searching PubChem..."):
1089
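+ results = pcp.get_compounds(query, "name") # searched in the "name" namespace; a formula presumably matches only when PubChem lists it as a synonym (e.g. H2O) - TODO confirm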
1090
+ st.session_state.pubchem_results = results
1091
+ st.session_state.pubchem_last_query = query
1092
+ except Exception as e:
1093
+ st.sidebar.error(f"Error searching PubChem: {str(e)}")
1094
+ st.session_state.pubchem_results = None
1095
+
1096
+ results = st.session_state.get("pubchem_results", [])
1097
+ if results:
1098
+ # Convert to displayable table
1099
+ df = pd.DataFrame(
1100
+ [(c.cid, c.iupac_name, c.molecular_formula, c.molecular_weight, c.isomeric_smiles)
1101
+ for c in results],
1102
+ columns=["CID", "Name", "Formula", "Weight", "SMILES"]
1103
+ )
1104
+ st.sidebar.success(f"Found {len(df)} result(s).")
1105
+ st.sidebar.dataframe(df)
1106
+
1107
+ # Choose a CID
1108
+ cid = st.sidebar.selectbox("Select CID", df["CID"], key="pubchem_cid")
1109
+
1110
+ # Step 2: Retrieve 3D structure for selected CID
1111
+ if cid:
1112
+ if "pubchem_last_cid" not in st.session_state or st.session_state.pubchem_last_cid != cid:
1113
+ try:
1114
+ with st.spinner("Fetching 3D coordinates..."):
1115
+ # Function to format floating-point numbers with alignment
1116
+ def format_number(num, width=10, precision=5):
1117
+ # Handles positive/negative numbers while maintaining alignment
1118
+ return f"{num: {width}.{precision}f}"
1119
+ # CID to XYZ
1120
+ def generate_xyz_coordinates(cid):
1121
+ compound = pcp.Compound.from_cid(cid, record_type='3d')
1122
+ atoms = compound.atoms
1123
+ coords = [(atom.x, atom.y, atom.z) for atom in atoms]
1124
+
1125
+ num_atoms = len(atoms)
1126
+ xyz_text = f"{num_atoms}\n{compound.cid}\n"
1127
+
1128
+ for atom, coord in zip(atoms, coords):
1129
+ atom_symbol = atom.element
1130
+ x, y, z = coord
1131
+ xyz_text += f"{atom_symbol} {format_number(x, precision=8)} {format_number(y, precision=8)} {format_number(z, precision=8)}\n"
1132
+
1133
+ return xyz_text
1134
+ def get_molecule(cid):
1135
+ xyz_str = generate_xyz_coordinates(cid)
1136
+ return Molecule.from_str(xyz_str, fmt='xyz'), xyz_str
1137
+ # Build the 3D structure as XYZ text from the PubChem record (the SDF call below is disabled)
1138
+ # sdf_str = pcp.Compound.from_cid(int(cid)).to_sdf()
1139
+ selected_molecule, xyz_str = get_molecule(cid)
1140
+
1141
+ # Convert the XYZ text → ASE Atoms using an in-memory text buffer
1142
+ atoms_to_store = read(StringIO(xyz_str), format="xyz")
1143
+
1144
+ atoms_to_store.info["source_name"] = f"PubChem CID {cid}"
1145
+ st.session_state.atoms = atoms_to_store
1146
+ st.session_state.pubchem_last_cid = cid
1147
+
1148
+ st.sidebar.success(f"Loaded PubChem structure with {len(atoms_to_store)} atoms!")
1149
+
1150
+ except Exception as e:
1151
+ st.sidebar.error(f"Unable to retrieve 3D structure: {str(e)}")
1152
+ st.session_state.atoms = None
1153
+ st.session_state.pubchem_last_cid = None
1154
+ else:
1155
+ st.sidebar.info("No PubChem results found.")
1156
 
1157
  # --- MATERIALS PROJECT ID ---
1158
  elif input_method == "Materials Project ID":
 
1195
  elif not material_id:
1196
  st.sidebar.error("Please enter a Material ID.")
1197
 
1198
+ # --- BATCH UPLOAD MULTIPLE FILES ---
1199
+ elif input_method == "Batch Upload":
1200
+
1201
+ uploaded_files = st.sidebar.file_uploader(
1202
+ "Upload multiple structure files",
1203
+ type=["xyz", "cif", "POSCAR", "vasp", "CONTCAR", "mol", "sdf", "tmol", "extxyz"],
1204
+ accept_multiple_files=True
1205
+ )
1206
+
1207
+ # Clear state if no files present
1208
+ if not uploaded_files:
1209
+ st.session_state.atoms_list = []
1210
+ st.session_state.atoms = None
1211
+
1212
+ else:
1213
+ atoms_list = []
1214
+ errors = []
1215
+
1216
+ for file in uploaded_files:
1217
+ try:
1218
+ with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.name)[1]) as tmp:
1219
+ tmp.write(file.getvalue())
1220
+ tmp_path = tmp.name
1221
+
1222
+ atoms_obj = read(tmp_path)
1223
+ atoms_obj.info["source_name"] = file.name
1224
+ atoms_list.append(atoms_obj)
1225
+
1226
+ except Exception as e:
1227
+ errors.append(f"{file.name}: {str(e)}")
1228
+
1229
+ finally:
1230
+ if "tmp_path" in locals() and os.path.exists(tmp_path):
1231
+ os.unlink(tmp_path)
1232
+
1233
+ # Store everything only if at least one success
1234
+ if atoms_list:
1235
+ st.session_state.atoms_list = atoms_list
1236
+ st.session_state.atoms = atoms_list[0] # default: first item
1237
+ st.sidebar.success(f"Loaded {len(atoms_list)} structures successfully!")
1238
+
1239
+ if len(atoms_list) > 1:
1240
+ st.sidebar.info("You can now process them as a batch.")
1241
+
1242
+ if errors:
1243
+ st.sidebar.error("Some files could not be loaded:\n" + "\n".join(errors))
1244
  # ----------------------------------------------------
1245
  # --- FINAL STRUCTURE RETRIEVAL (The persistent structure) ---
1246
  # ----------------------------------------------------
 
1255
 
1256
  # Display confirmation in the main area (optional, helps the user confirm what's loaded)
1257
  st.markdown(f"**Loaded Structure:** {atoms.get_chemical_formula()} ({len(atoms)} atoms)")
 
 
1258
 
1259
  st.sidebar.markdown("## Model Selection")
1260
  if mattersim_available:
 
1438
  calc = get_fairchem_model(selected_model, model_path, device, selected_task_type)
1439
  elif model_type == "ORB":
1440
  # st.write("Setting up ORB calculator...")
1441
+ # orbff = pretrained.orb_v3_conservative_inf_omat(device=device, precision=selected_default_dtype)
1442
+ orbff = model_path(device=device, precision=selected_default_dtype)
1443
  calc = ORBCalculator(orbff, device=device)
1444
  elif model_type == "MatterSim":
1445
  # st.write("Setting up MatterSim calculator...")