EricBoi commited on
Commit
949746a
·
1 Parent(s): 52c9ab4
Files changed (1) hide show
  1. app.py +20 -5
app.py CHANGED
@@ -50,6 +50,15 @@ DCM1 = MessagePassingModel(
50
  n_dcm=1,
51
  )
52
 
 
 
 
 
 
 
 
 
 
53
 
54
 
55
  from rdkit import Chem
@@ -80,7 +89,8 @@ def get_grid_points(coordinates):
80
  return grid_points
81
 
82
 
83
- test_weights = pd.read_pickle("wbs/best_0.0_params.pkl")
 
84
 
85
  smiles = 'C1NCCCC1'
86
 
@@ -137,19 +147,24 @@ errors_train = []
137
  batch = psi4_test_batches[batchID]
138
 
139
  #mono, dipo = apply_model(DCM1, test_weights, batch, batch_size)
140
- dcm1results = plot_model(DCM1, test_weights, batch, batch_size, 1, plot=False)
 
141
 
142
- dipo = dcm1results["dipo"]
143
- mono = dcm1results["mono"]
144
  atoms = dcm1results["atoms"]
145
  dcmol = dcm1results["dcmol"]
 
146
 
147
 
148
 
149
-
150
  output = StringIO()
151
  (atoms+dcmol).write(output, format="html")
152
  data = output.getvalue()
 
153
 
 
 
 
154
  components.html(data, width=1000, height=1000)
155
 
 
 
50
  n_dcm=1,
51
  )
52
 
53
+ # Create models
54
+ DCM2 = MessagePassingModel(
55
+ features=features,
56
+ max_degree=max_degree,
57
+ num_iterations=num_iterations,
58
+ num_basis_functions=num_basis_functions,
59
+ cutoff=cutoff,
60
+ n_dcm=2,
61
+ )
62
 
63
 
64
  from rdkit import Chem
 
89
  return grid_points
90
 
91
 
92
+ dcm1_weights = pd.read_pickle("wbs/best_0.0_params.pkl")
93
+ dcm2_weights = pd.read_pickle("wbs/dcm2-best_1000.0_params.pkl")
94
 
95
  smiles = 'C1NCCCC1'
96
 
 
147
  batch = psi4_test_batches[batchID]
148
 
149
  #mono, dipo = apply_model(DCM1, test_weights, batch, batch_size)
150
+ dcm1results = plot_model(DCM1, dcm1_weights, batch, batch_size, 1, plot=False)
151
+ dcm2results = plot_model(DCM2, dcm2_weights, batch, batch_size, 1, plot=False)
152
 
 
 
153
  atoms = dcm1results["atoms"]
154
  dcmol = dcm1results["dcmol"]
155
+ dcmol2 = dcm2results["dcmol"]
156
 
157
 
158
 
159
+ st.write("Click M to see the distributed charges")
160
  output = StringIO()
161
  (atoms+dcmol).write(output, format="html")
162
  data = output.getvalue()
163
+ components.html(data, width=1000, height=1000)
164
 
165
+ output = StringIO()
166
+ (atoms+dcmol2).write(output, format="html")
167
+ data = output.getvalue()
168
  components.html(data, width=1000, height=1000)
169
 
170
+