Tingxie committed on
Commit
8c09b4a
·
1 Parent(s): b1c263e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -34
app.py CHANGED
@@ -32,6 +32,7 @@ from gradio.themes.base import Base
32
  from gradio.themes.utils import colors, fonts, sizes
33
  from huggingface_hub import hf_hub_download
34
  import time
 
35
 
36
  class Seafoam(Base):
37
  def __init__(
@@ -260,40 +261,48 @@ def retrieve_similarity_scores( table_name, target_mass,collision_energy, ms2_em
260
  filtered_smiles = cur.fetchall()
261
  similarity_scores = []
262
 
263
- for smile in filtered_smiles:
264
- query = f"""
265
- SELECT low_energy_embedding, median_energy_embedding, high_energy_embedding
266
- FROM {table_name}
267
- WHERE SMILES = ?
268
- """
269
- cur.execute(query, (smile[0],))
270
- row = cur.fetchone()
271
- if row is None:
272
- return None
273
- low_energy_embedding_db = np.array(pickle.loads(row[0]), dtype=np.float64)
274
- median_energy_embedding_db = np.array(pickle.loads(row[1]), dtype=np.float64)
275
- high_energy_embedding_db = np.array(pickle.loads(row[2]), dtype=np.float64)
276
- low_energy_embedding_db,median_energy_embedding_db,high_energy_embedding_db = torch.tensor(low_energy_embedding_db).float(),torch.tensor(median_energy_embedding_db).float(),torch.tensor(high_energy_embedding_db).float()
277
- low_similarity =(ms2_embedding_low @ low_energy_embedding_db.t()).item()
278
- median_similarity = (ms2_embedding_median @ median_energy_embedding_db.t()).item()
279
- high_similarity = (ms2_embedding_high @ high_energy_embedding_db.t()).item()
280
- '''
281
- low_similarity = calculate_cosine_similarity(ms2_embedding_low, low_energy_embedding_db)
282
- median_similarity = calculate_cosine_similarity(ms2_embedding_median, median_energy_embedding_db)
283
- high_similarity = calculate_cosine_similarity(ms2_embedding_high, high_energy_embedding_db)'''
284
- similarity_scores.append((smile, low_similarity, median_similarity, high_similarity))
285
-
286
- weighted_similarity_scores = []
287
- for smile, low_similarity, median_similarity, high_similarity in similarity_scores:
288
- if collision_energy <=15:
289
- weighted_similarity = 0.4 * low_similarity + 0.3 * median_similarity + 0.3 * high_similarity
290
- weighted_similarity_scores.append((smile, weighted_similarity))
291
- elif collision_energy >15 and collision_energy <= 25:
292
- weighted_similarity = 0.3 * low_similarity + 0.4 * median_similarity + 0.3 * high_similarity
293
- weighted_similarity_scores.append((smile, weighted_similarity))
294
- elif collision_energy > 25:
295
- weighted_similarity = 0.2 * low_similarity + 0.3 * median_similarity + 0.5 * high_similarity
296
- weighted_similarity_scores.append((smile, weighted_similarity))
 
 
 
 
 
 
 
 
297
 
298
  weighted_similarity_scores.sort(key=lambda x: x[1], reverse=True)
299
 
 
32
  from gradio.themes.utils import colors, fonts, sizes
33
  from huggingface_hub import hf_hub_download
34
  import time
35
+ import concurrent.futures
36
 
37
  class Seafoam(Base):
38
  def __init__(
 
261
  filtered_smiles = cur.fetchall()
262
  similarity_scores = []
263
 
264
+ query = f"""
265
+ SELECT SMILES, low_energy_embedding, median_energy_embedding, high_energy_embedding
266
+ FROM {table_name}
267
+ WHERE SMILES IN ({','.join(['?']*len(filtered_smiles))})
268
+ """
269
+ cur.execute(query, tuple(s[0] for s in filtered_smiles))
270
+ rows = cur.fetchall()
271
+
272
+ def decode_row(row):
273
+ return (
274
+ row[0], # SMILES
275
+ np.array(pickle.loads(row[1]), dtype=np.float32),
276
+ np.array(pickle.loads(row[2]), dtype=np.float32),
277
+ np.array(pickle.loads(row[3]), dtype=np.float32),
278
+ )
279
+
280
+ with concurrent.futures.ThreadPoolExecutor() as executor:
281
+ results = list(executor.map(decode_row, rows))
282
+
283
+ ms2_embedding_low_np = ms2_embedding_low.numpy()
284
+ ms2_embedding_median_np = ms2_embedding_median.numpy()
285
+ ms2_embedding_high_np = ms2_embedding_high.numpy()
286
+
287
+ similarity_scores = [
288
+ (
289
+ smile,
290
+ np.dot(ms2_embedding_low_np, low_embedding),
291
+ np.dot(ms2_embedding_median_np, median_embedding),
292
+ np.dot(ms2_embedding_high_np, high_embedding),
293
+ )
294
+ for smile, low_embedding, median_embedding, high_embedding in results
295
+ ]
296
+
297
+ collision_weights = np.array([
298
+ [0.4, 0.3, 0.3] if collision_energy <= 15 else
299
+ [0.3, 0.4, 0.3] if collision_energy <= 25 else
300
+ [0.2, 0.3, 0.5]
301
+ ])
302
+
303
+ similarity_array = np.array([[low, median, high] for _, low, median, high in similarity_scores])
304
+ weighted_similarities = similarity_array @ collision_weights.T # 矩阵乘法
305
+ weighted_similarity_scores = [(smile, weighted) for (smile, _low, _med, _high), weighted in zip(similarity_scores, weighted_similarities)]
306
 
307
  weighted_similarity_scores.sort(key=lambda x: x[1], reverse=True)
308