Tingxie commited on
Commit
5622cd7
·
1 Parent(s): 92e8603

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -34
app.py CHANGED
@@ -25,6 +25,7 @@ from gradio.themes.base import Base
25
  from gradio.themes.utils import colors, fonts, sizes
26
  from huggingface_hub import hf_hub_download
27
  import time
 
28
 
29
  class Seafoam(Base):
30
  def __init__(
@@ -252,40 +253,48 @@ def retrieve_similarity_scores( table_name, target_mass,collision_energy, ms2_em
252
  filtered_smiles = cur.fetchall()
253
  similarity_scores = []
254
 
255
- for smile in filtered_smiles:
256
- query = f"""
257
- SELECT low_energy_embedding, median_energy_embedding, high_energy_embedding
258
- FROM {table_name}
259
- WHERE SMILES = ?
260
- """
261
- cur.execute(query, (smile[0],))
262
- row = cur.fetchone()
263
- if row is None:
264
- return None
265
- low_energy_embedding_db = np.array(pickle.loads(row[0]), dtype=np.float64)
266
- median_energy_embedding_db = np.array(pickle.loads(row[1]), dtype=np.float64)
267
- high_energy_embedding_db = np.array(pickle.loads(row[2]), dtype=np.float64)
268
- low_energy_embedding_db,median_energy_embedding_db,high_energy_embedding_db = torch.tensor(low_energy_embedding_db).float(),torch.tensor(median_energy_embedding_db).float(),torch.tensor(high_energy_embedding_db).float()
269
- low_similarity =(ms2_embedding_low @ low_energy_embedding_db.t()).item()
270
- median_similarity = (ms2_embedding_median @ median_energy_embedding_db.t()).item()
271
- high_similarity = (ms2_embedding_high @ high_energy_embedding_db.t()).item()
272
- '''
273
- low_similarity = calculate_cosine_similarity(ms2_embedding_low, low_energy_embedding_db)
274
- median_similarity = calculate_cosine_similarity(ms2_embedding_median, median_energy_embedding_db)
275
- high_similarity = calculate_cosine_similarity(ms2_embedding_high, high_energy_embedding_db)'''
276
- similarity_scores.append((smile, low_similarity, median_similarity, high_similarity))
277
-
278
- weighted_similarity_scores = []
279
- for smile, low_similarity, median_similarity, high_similarity in similarity_scores:
280
- if collision_energy <=15:
281
- weighted_similarity = 0.4 * low_similarity + 0.3 * median_similarity + 0.3 * high_similarity
282
- weighted_similarity_scores.append((smile, weighted_similarity))
283
- elif collision_energy >15 and collision_energy <= 25:
284
- weighted_similarity = 0.3 * low_similarity + 0.4 * median_similarity + 0.3 * high_similarity
285
- weighted_similarity_scores.append((smile, weighted_similarity))
286
- elif collision_energy > 25:
287
- weighted_similarity = 0.2 * low_similarity + 0.3 * median_similarity + 0.5 * high_similarity
288
- weighted_similarity_scores.append((smile, weighted_similarity))
 
 
 
 
 
 
 
 
289
 
290
  weighted_similarity_scores.sort(key=lambda x: x[1], reverse=True)
291
 
 
25
  from gradio.themes.utils import colors, fonts, sizes
26
  from huggingface_hub import hf_hub_download
27
  import time
28
+ import concurrent.futures
29
 
30
  class Seafoam(Base):
31
  def __init__(
 
253
  filtered_smiles = cur.fetchall()
254
  similarity_scores = []
255
 
256
+ query = f"""
257
+ SELECT SMILES, low_energy_embedding, median_energy_embedding, high_energy_embedding
258
+ FROM {table_name}
259
+ WHERE SMILES IN ({','.join(['?']*len(filtered_smiles))})
260
+ """
261
+ cur.execute(query, tuple(s[0] for s in filtered_smiles))
262
+ rows = cur.fetchall()
263
+
264
+ def decode_row(row):
265
+ return (
266
+ row[0], # SMILES
267
+ np.array(pickle.loads(row[1]), dtype=np.float32),
268
+ np.array(pickle.loads(row[2]), dtype=np.float32),
269
+ np.array(pickle.loads(row[3]), dtype=np.float32),
270
+ )
271
+
272
+ with concurrent.futures.ThreadPoolExecutor() as executor:
273
+ results = list(executor.map(decode_row, rows))
274
+
275
+ ms2_embedding_low_np = ms2_embedding_low.numpy()
276
+ ms2_embedding_median_np = ms2_embedding_median.numpy()
277
+ ms2_embedding_high_np = ms2_embedding_high.numpy()
278
+
279
+ similarity_scores = [
280
+ (
281
+ smile,
282
+ np.dot(ms2_embedding_low_np, low_embedding),
283
+ np.dot(ms2_embedding_median_np, median_embedding),
284
+ np.dot(ms2_embedding_high_np, high_embedding),
285
+ )
286
+ for smile, low_embedding, median_embedding, high_embedding in results
287
+ ]
288
+
289
+ collision_weights = np.array([
290
+ [0.4, 0.3, 0.3] if collision_energy <= 15 else
291
+ [0.3, 0.4, 0.3] if collision_energy <= 25 else
292
+ [0.2, 0.3, 0.5]
293
+ ])
294
+
295
+ similarity_array = np.array([[low, median, high] for _, low, median, high in similarity_scores])
296
+ weighted_similarities = similarity_array @ collision_weights.T
297
+ weighted_similarity_scores = [(smile, weighted) for (smile, _low, _med, _high), weighted in zip(similarity_scores, weighted_similarities)]
298
 
299
  weighted_similarity_scores.sort(key=lambda x: x[1], reverse=True)
300