Update app.py
Browse files
app.py
CHANGED
|
@@ -18,7 +18,7 @@ from rdkit import Chem
|
|
| 18 |
from rdkit.Chem import Draw
|
| 19 |
|
| 20 |
from evaluator import Evaluator
|
| 21 |
-
|
| 22 |
|
| 23 |
# Load the CSV data
|
| 24 |
known_labels = pd.read_csv('data/known_labels.csv')
|
|
@@ -55,18 +55,18 @@ def random_properties():
|
|
| 55 |
|
| 56 |
def load_model(model_choice):
|
| 57 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 58 |
-
|
| 59 |
#### test
|
| 60 |
-
from graph_decoder.diffusion_model import GraphDiT
|
| 61 |
|
| 62 |
-
model_config_path = f"model_labeled/config.yaml"
|
| 63 |
-
data_info_path = f"model_labeled/data.meta.json"
|
| 64 |
-
model = GraphDiT(
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
)
|
| 70 |
### test
|
| 71 |
return (model, device)
|
| 72 |
|
|
|
|
| 18 |
from rdkit.Chem import Draw
|
| 19 |
|
| 20 |
from evaluator import Evaluator
|
| 21 |
+
from loader import load_graph_decoder
|
| 22 |
|
| 23 |
# Load the CSV data
|
| 24 |
known_labels = pd.read_csv('data/known_labels.csv')
|
|
|
|
| 55 |
|
| 56 |
def load_model(model_choice):
|
| 57 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 58 |
+
model = load_graph_decoder(path=model_choice)
|
| 59 |
#### test
|
| 60 |
+
# from graph_decoder.diffusion_model import GraphDiT
|
| 61 |
|
| 62 |
+
# model_config_path = f"model_labeled/config.yaml"
|
| 63 |
+
# data_info_path = f"model_labeled/data.meta.json"
|
| 64 |
+
# model = GraphDiT(
|
| 65 |
+
# model_config_path=model_config_path,
|
| 66 |
+
# data_info_path=data_info_path,
|
| 67 |
+
# # model_dtype=torch.float16,
|
| 68 |
+
# model_dtype=torch.float32,
|
| 69 |
+
# )
|
| 70 |
### test
|
| 71 |
return (model, device)
|
| 72 |
|