Tingxie commited on
Commit
de76273
·
1 Parent(s): 5622cd7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -42
app.py CHANGED
@@ -253,48 +253,40 @@ def retrieve_similarity_scores( table_name, target_mass,collision_energy, ms2_em
253
  filtered_smiles = cur.fetchall()
254
  similarity_scores = []
255
 
256
- query = f"""
257
- SELECT SMILES, low_energy_embedding, median_energy_embedding, high_energy_embedding
258
- FROM {table_name}
259
- WHERE SMILES IN ({','.join(['?']*len(filtered_smiles))})
260
- """
261
- cur.execute(query, tuple(s[0] for s in filtered_smiles))
262
- rows = cur.fetchall()
263
-
264
- def decode_row(row):
265
- return (
266
- row[0], # SMILES
267
- np.array(pickle.loads(row[1]), dtype=np.float32),
268
- np.array(pickle.loads(row[2]), dtype=np.float32),
269
- np.array(pickle.loads(row[3]), dtype=np.float32),
270
- )
271
-
272
- with concurrent.futures.ThreadPoolExecutor() as executor:
273
- results = list(executor.map(decode_row, rows))
274
-
275
- ms2_embedding_low_np = ms2_embedding_low.numpy()
276
- ms2_embedding_median_np = ms2_embedding_median.numpy()
277
- ms2_embedding_high_np = ms2_embedding_high.numpy()
278
-
279
- similarity_scores = [
280
- (
281
- smile,
282
- np.dot(ms2_embedding_low_np, low_embedding),
283
- np.dot(ms2_embedding_median_np, median_embedding),
284
- np.dot(ms2_embedding_high_np, high_embedding),
285
- )
286
- for smile, low_embedding, median_embedding, high_embedding in results
287
- ]
288
-
289
- collision_weights = np.array([
290
- [0.4, 0.3, 0.3] if collision_energy <= 15 else
291
- [0.3, 0.4, 0.3] if collision_energy <= 25 else
292
- [0.2, 0.3, 0.5]
293
- ])
294
-
295
- similarity_array = np.array([[low, median, high] for _, low, median, high in similarity_scores])
296
- weighted_similarities = similarity_array @ collision_weights.T
297
- weighted_similarity_scores = [(smile, weighted) for (smile, _low, _med, _high), weighted in zip(similarity_scores, weighted_similarities)]
298
 
299
  weighted_similarity_scores.sort(key=lambda x: x[1], reverse=True)
300
 
 
253
  filtered_smiles = cur.fetchall()
254
  similarity_scores = []
255
 
256
+ for smile in filtered_smiles:
257
+ query = f"""
258
+ SELECT low_energy_embedding, median_energy_embedding, high_energy_embedding
259
+ FROM {table_name}
260
+ WHERE SMILES = ?
261
+ """
262
+ cur.execute(query, (smile[0],))
263
+ row = cur.fetchone()
264
+ if row is None:
265
+ return None
266
+ low_energy_embedding_db = np.array(pickle.loads(row[0]), dtype=np.float64)
267
+ median_energy_embedding_db = np.array(pickle.loads(row[1]), dtype=np.float64)
268
+ high_energy_embedding_db = np.array(pickle.loads(row[2]), dtype=np.float64)
269
+ low_energy_embedding_db,median_energy_embedding_db,high_energy_embedding_db = torch.tensor(low_energy_embedding_db).float(),torch.tensor(median_energy_embedding_db).float(),torch.tensor(high_energy_embedding_db).float()
270
+ low_similarity =(ms2_embedding_low @ low_energy_embedding_db.t()).item()
271
+ median_similarity = (ms2_embedding_median @ median_energy_embedding_db.t()).item()
272
+ high_similarity = (ms2_embedding_high @ high_energy_embedding_db.t()).item()
273
+ '''
274
+ low_similarity = calculate_cosine_similarity(ms2_embedding_low, low_energy_embedding_db)
275
+ median_similarity = calculate_cosine_similarity(ms2_embedding_median, median_energy_embedding_db)
276
+ high_similarity = calculate_cosine_similarity(ms2_embedding_high, high_energy_embedding_db)'''
277
+ similarity_scores.append((smile, low_similarity, median_similarity, high_similarity))
278
+
279
+ weighted_similarity_scores = []
280
+ for smile, low_similarity, median_similarity, high_similarity in similarity_scores:
281
+ if collision_energy <=15:
282
+ weighted_similarity = 0.4 * low_similarity + 0.3 * median_similarity + 0.3 * high_similarity
283
+ weighted_similarity_scores.append((smile, weighted_similarity))
284
+ elif collision_energy >15 and collision_energy <= 25:
285
+ weighted_similarity = 0.3 * low_similarity + 0.4 * median_similarity + 0.3 * high_similarity
286
+ weighted_similarity_scores.append((smile, weighted_similarity))
287
+ elif collision_energy > 25:
288
+ weighted_similarity = 0.2 * low_similarity + 0.3 * median_similarity + 0.5 * high_similarity
289
+ weighted_similarity_scores.append((smile, weighted_similarity))
 
 
 
 
 
 
 
 
290
 
291
  weighted_similarity_scores.sort(key=lambda x: x[1], reverse=True)
292