Spaces:
Runtime error
Runtime error
+Faster search -Higher memory
Browse files
app.py
CHANGED
|
@@ -3,6 +3,7 @@ import os
|
|
| 3 |
import gradio as gr
|
| 4 |
import numpy as np
|
| 5 |
import pandas as pd
|
|
|
|
| 6 |
import plotly.graph_objs as go
|
| 7 |
import polars as pl
|
| 8 |
from datasets import concatenate_datasets, load_dataset
|
|
@@ -51,27 +52,29 @@ for subset in subsets:
|
|
| 51 |
"functional",
|
| 52 |
],
|
| 53 |
)
|
| 54 |
-
subsets_ds[subset] = dataset["train"]
|
| 55 |
|
| 56 |
-
|
| 57 |
-
# df = pd.concat([x.to_pandas() for x in datasets])
|
| 58 |
-
# train_df = dataset.to_pandas()
|
| 59 |
-
# del dataset
|
| 60 |
|
| 61 |
-
# dataset_element_combination_dict = {}
|
| 62 |
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
|
| 77 |
def create_phase_diagram(
|
|
@@ -85,36 +88,31 @@ def create_phase_diagram(
|
|
| 85 |
# Split elements and remove any whitespace
|
| 86 |
element_list = [el.strip() for el in elements.split("-")]
|
| 87 |
|
| 88 |
-
|
| 89 |
-
if functional == "PBE":
|
| 90 |
-
entries_df = subsets_ds["compatible_pbe"].to_pandas()
|
| 91 |
-
# entries_df = train_df[train_df["functional"] == "pbe"]
|
| 92 |
-
elif functional == "PBESol":
|
| 93 |
-
entries_df = subsets_ds["compatible_pbesol"].to_pandas()
|
| 94 |
-
# entries_df = train_df[train_df["functional"] == "pbesol"]
|
| 95 |
-
elif functional == "SCAN":
|
| 96 |
-
entries_df = subsets_ds["compatible_scan"].to_pandas()
|
| 97 |
-
# entries_df = train_df[train_df["functional"] == "scan"]
|
| 98 |
|
| 99 |
-
|
|
|
|
|
|
|
| 100 |
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
isintersection = lambda x: len(set(x).intersection(element_list)) > 0
|
| 105 |
-
entries_df = entries_df[
|
| 106 |
-
[isintersection(l) and isubset(l) for l in entries_df.elements.values.tolist()]
|
| 107 |
]
|
| 108 |
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
-
|
| 118 |
|
| 119 |
# Fetch all entries from the Materials Project database
|
| 120 |
def get_energy_correction(energy_correction, row):
|
|
@@ -153,6 +151,7 @@ def create_phase_diagram(
|
|
| 153 |
try:
|
| 154 |
phase_diagram = PhaseDiagram(entries)
|
| 155 |
except ValueError as e:
|
|
|
|
| 156 |
return go.Figure().add_annotation(text=str(e))
|
| 157 |
|
| 158 |
# Generate plotly figure
|
|
@@ -188,7 +187,10 @@ elements_input = gr.Textbox(
|
|
| 188 |
# minimum=0, maximum=1, value=0.1, label="Maximum Energy Above Hull (eV)"
|
| 189 |
# )
|
| 190 |
energy_correction_dropdown = gr.Dropdown(
|
| 191 |
-
choices=[
|
|
|
|
|
|
|
|
|
|
| 192 |
label="Energy correction",
|
| 193 |
)
|
| 194 |
plot_style_dropdown = gr.Dropdown(choices=["2D", "3D"], label="Plot Style")
|
|
@@ -210,9 +212,7 @@ warning_message += " from <a href='https://chemrxiv.org/engage/api-gateway/chemr
|
|
| 210 |
message = '<div class="alert"><span class="closebtn" onclick="this.parentElement.style.display="none";">×</span>{}</div>Generate a phase diagram for a set of elements using LeMat-Bulk data.'.format(
|
| 211 |
warning_message
|
| 212 |
)
|
| 213 |
-
message +=
|
| 214 |
-
"<br>Built with <a href='https://pymatgen.org/' target='_blank'>Pymatgen</a> and <a href='https://docs.crystaltoolkit.org/' target='_blank'>Crystal Toolkit</a>.<br>"
|
| 215 |
-
)
|
| 216 |
|
| 217 |
# Create Gradio interface
|
| 218 |
iface = gr.Interface(
|
|
|
|
| 3 |
import gradio as gr
|
| 4 |
import numpy as np
|
| 5 |
import pandas as pd
|
| 6 |
+
import periodictable
|
| 7 |
import plotly.graph_objs as go
|
| 8 |
import polars as pl
|
| 9 |
from datasets import concatenate_datasets, load_dataset
|
|
|
|
| 52 |
"functional",
|
| 53 |
],
|
| 54 |
)
|
| 55 |
+
subsets_ds[subset] = dataset["train"].to_pandas()
|
| 56 |
|
| 57 |
+
elements_df = {k: subset["elements"] for k, subset in subsets_ds.items()}
|
|
|
|
|
|
|
|
|
|
| 58 |
|
|
|
|
| 59 |
|
| 60 |
+
all_elements = {str(el): i for i, el in enumerate(periodictable.elements)}
|
| 61 |
+
elements_indices = {}
|
| 62 |
+
for subset, df in elements_df.items():
|
| 63 |
+
print("Processing subset: ", subset)
|
| 64 |
+
elements_indices[subset] = np.zeros((len(df), len(all_elements)))
|
| 65 |
+
|
| 66 |
+
def map_elements(row):
|
| 67 |
+
index, xs = row["index"], row["elements"]
|
| 68 |
+
for x in xs:
|
| 69 |
+
elements_indices[subset][index, all_elements[x]] = 1
|
| 70 |
+
|
| 71 |
+
df = df.reset_index().apply(map_elements, axis=1)
|
| 72 |
+
|
| 73 |
+
map_functional = {
|
| 74 |
+
"PBE": "compatible_pbe",
|
| 75 |
+
"PBESol": "compatible_pbesol",
|
| 76 |
+
"SCAN": "compatible_scan",
|
| 77 |
+
}
|
| 78 |
|
| 79 |
|
| 80 |
def create_phase_diagram(
|
|
|
|
| 88 |
# Split elements and remove any whitespace
|
| 89 |
element_list = [el.strip() for el in elements.split("-")]
|
| 90 |
|
| 91 |
+
subset_name = map_functional[functional]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
+
element_list_vector = np.zeros(len(all_elements))
|
| 94 |
+
for el in element_list:
|
| 95 |
+
element_list_vector[all_elements[el]] = 1
|
| 96 |
|
| 97 |
+
n_elements = elements_indices[subset_name].sum(axis=1)
|
| 98 |
+
n_elements_query = elements_indices[subset_name][
|
| 99 |
+
:, element_list_vector.astype(bool)
|
|
|
|
|
|
|
|
|
|
| 100 |
]
|
| 101 |
|
| 102 |
+
if n_elements_query.shape[1] == 0:
|
| 103 |
+
indices_with_only_elements = []
|
| 104 |
+
else:
|
| 105 |
+
indices_with_only_elements = np.where(
|
| 106 |
+
n_elements_query.sum(axis=1) == n_elements
|
| 107 |
+
)[0]
|
| 108 |
+
|
| 109 |
+
print(indices_with_only_elements)
|
| 110 |
+
|
| 111 |
+
entries_df = subsets_ds[subset_name].loc[indices_with_only_elements]
|
| 112 |
+
|
| 113 |
+
entries_df = entries_df[~entries_df["immutable_id"].isna()]
|
| 114 |
|
| 115 |
+
print(entries_df)
|
| 116 |
|
| 117 |
# Fetch all entries from the Materials Project database
|
| 118 |
def get_energy_correction(energy_correction, row):
|
|
|
|
| 151 |
try:
|
| 152 |
phase_diagram = PhaseDiagram(entries)
|
| 153 |
except ValueError as e:
|
| 154 |
+
print(e)
|
| 155 |
return go.Figure().add_annotation(text=str(e))
|
| 156 |
|
| 157 |
# Generate plotly figure
|
|
|
|
| 187 |
# minimum=0, maximum=1, value=0.1, label="Maximum Energy Above Hull (eV)"
|
| 188 |
# )
|
| 189 |
energy_correction_dropdown = gr.Dropdown(
|
| 190 |
+
choices=[
|
| 191 |
+
"The 110 PBE Method",
|
| 192 |
+
"Database specific, or MP2020",
|
| 193 |
+
],
|
| 194 |
label="Energy correction",
|
| 195 |
)
|
| 196 |
plot_style_dropdown = gr.Dropdown(choices=["2D", "3D"], label="Plot Style")
|
|
|
|
| 212 |
message = '<div class="alert"><span class="closebtn" onclick="this.parentElement.style.display="none";">×</span>{}</div>Generate a phase diagram for a set of elements using LeMat-Bulk data.'.format(
|
| 213 |
warning_message
|
| 214 |
)
|
| 215 |
+
message += "<br>Built with <a href='https://pymatgen.org/' target='_blank'>Pymatgen</a> and <a href='https://docs.crystaltoolkit.org/' target='_blank'>Crystal Toolkit</a>.<br>"
|
|
|
|
|
|
|
| 216 |
|
| 217 |
# Create Gradio interface
|
| 218 |
iface = gr.Interface(
|