Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -32,6 +32,7 @@ from gradio.themes.base import Base
|
|
| 32 |
from gradio.themes.utils import colors, fonts, sizes
|
| 33 |
from huggingface_hub import hf_hub_download
|
| 34 |
import time
|
|
|
|
| 35 |
|
| 36 |
class Seafoam(Base):
|
| 37 |
def __init__(
|
|
@@ -260,40 +261,48 @@ def retrieve_similarity_scores( table_name, target_mass,collision_energy, ms2_em
|
|
| 260 |
filtered_smiles = cur.fetchall()
|
| 261 |
similarity_scores = []
|
| 262 |
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
|
| 298 |
weighted_similarity_scores.sort(key=lambda x: x[1], reverse=True)
|
| 299 |
|
|
|
|
| 32 |
from gradio.themes.utils import colors, fonts, sizes
|
| 33 |
from huggingface_hub import hf_hub_download
|
| 34 |
import time
|
| 35 |
+
import concurrent.futures
|
| 36 |
|
| 37 |
class Seafoam(Base):
|
| 38 |
def __init__(
|
|
|
|
| 261 |
filtered_smiles = cur.fetchall()
|
| 262 |
similarity_scores = []
|
| 263 |
|
| 264 |
+
query = f"""
|
| 265 |
+
SELECT SMILES, low_energy_embedding, median_energy_embedding, high_energy_embedding
|
| 266 |
+
FROM {table_name}
|
| 267 |
+
WHERE SMILES IN ({','.join(['?']*len(filtered_smiles))})
|
| 268 |
+
"""
|
| 269 |
+
cur.execute(query, tuple(s[0] for s in filtered_smiles))
|
| 270 |
+
rows = cur.fetchall()
|
| 271 |
+
|
| 272 |
+
def decode_row(row):
|
| 273 |
+
return (
|
| 274 |
+
row[0], # SMILES
|
| 275 |
+
np.array(pickle.loads(row[1]), dtype=np.float32),
|
| 276 |
+
np.array(pickle.loads(row[2]), dtype=np.float32),
|
| 277 |
+
np.array(pickle.loads(row[3]), dtype=np.float32),
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
with concurrent.futures.ThreadPoolExecutor() as executor:
|
| 281 |
+
results = list(executor.map(decode_row, rows))
|
| 282 |
+
|
| 283 |
+
ms2_embedding_low_np = ms2_embedding_low.numpy()
|
| 284 |
+
ms2_embedding_median_np = ms2_embedding_median.numpy()
|
| 285 |
+
ms2_embedding_high_np = ms2_embedding_high.numpy()
|
| 286 |
+
|
| 287 |
+
similarity_scores = [
|
| 288 |
+
(
|
| 289 |
+
smile,
|
| 290 |
+
np.dot(ms2_embedding_low_np, low_embedding),
|
| 291 |
+
np.dot(ms2_embedding_median_np, median_embedding),
|
| 292 |
+
np.dot(ms2_embedding_high_np, high_embedding),
|
| 293 |
+
)
|
| 294 |
+
for smile, low_embedding, median_embedding, high_embedding in results
|
| 295 |
+
]
|
| 296 |
+
|
| 297 |
+
collision_weights = np.array([
|
| 298 |
+
[0.4, 0.3, 0.3] if collision_energy <= 15 else
|
| 299 |
+
[0.3, 0.4, 0.3] if collision_energy <= 25 else
|
| 300 |
+
[0.2, 0.3, 0.5]
|
| 301 |
+
])
|
| 302 |
+
|
| 303 |
+
similarity_array = np.array([[low, median, high] for _, low, median, high in similarity_scores])
|
| 304 |
+
weighted_similarities = similarity_array @ collision_weights.T # matrix multiplication
|
| 305 |
+
weighted_similarity_scores = [(smile, weighted) for (smile, _low, _med, _high), weighted in zip(similarity_scores, weighted_similarities)]
|
| 306 |
|
| 307 |
weighted_similarity_scores.sort(key=lambda x: x[1], reverse=True)
|
| 308 |
|