Ana Sanchez commited on
Commit
fd0260e
·
1 Parent(s): a402a05

Update app

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -411,8 +411,8 @@ def molecules_from_image():
411
  print((top_probs, top_labels))
412
 
413
  def images_from_molecule():
414
- #st.markdown("Enter a query molecule in [SMILES](https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system) format",)
415
- smiles = st.text_input("Enter a query molecule in [SMILES](https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system) format", value="CC(=O)OC1=CC=CC=C1C(=O)O", placeholder="CC(=O)OC1=CC=CC=C1C(=O)O")
416
  if smiles:
417
  smiles = [smiles]
418
  morgan = [morgan_from_smiles(s) for s in smiles]
@@ -436,6 +436,10 @@ def images_from_molecule():
436
  with col3:
437
  st.write("")
438
 
 
 
 
 
439
 
440
  img_features_torch = torch.load(image_features, map_location=device)
441
  img_features = img_features_torch["img_features"]
@@ -443,7 +447,7 @@ def images_from_molecule():
443
 
444
  logits = mol_features @ img_features.T
445
  img_probs = (30.0 * logits).softmax(dim=-1)
446
- top_probs, top_labels = img_probs.cpu().topk(5, dim=-1)
447
 
448
  top_probs = torch.flatten(top_probs)
449
  top_labels = torch.flatten(top_labels)
@@ -454,6 +458,7 @@ def images_from_molecule():
454
 
455
  images_dict = np.load(images_arr, allow_pickle = True)
456
 
 
457
  st.write("Retrieved images from the Cell Painting database")
458
  with st.container():
459
  columns = st.columns(len(top_probs))
 
411
  print((top_probs, top_labels))
412
 
413
  def images_from_molecule():
414
+ st.markdown("Enter a query molecule in [SMILES](https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system) format",)
415
+ smiles = st.text_input("Enter a query molecule in SMILES format", value="CC(=O)OC1=CC=CC=C1C(=O)O", placeholder="CC(=O)OC1=CC=CC=C1C(=O)O", label_visibility="collapsed")
416
  if smiles:
417
  smiles = [smiles]
418
  morgan = [morgan_from_smiles(s) for s in smiles]
 
436
  with col3:
437
  st.write("")
438
 
439
+ top_n = st.selectbox(
440
+ "How many images would you like to be retrieve?",
441
+ ("5", "10", "20")
442
+ )
443
 
444
  img_features_torch = torch.load(image_features, map_location=device)
445
  img_features = img_features_torch["img_features"]
 
447
 
448
  logits = mol_features @ img_features.T
449
  img_probs = (30.0 * logits).softmax(dim=-1)
450
+ top_probs, top_labels = img_probs.cpu().topk(top_n, dim=-1)
451
 
452
  top_probs = torch.flatten(top_probs)
453
  top_labels = torch.flatten(top_labels)
 
458
 
459
  images_dict = np.load(images_arr, allow_pickle = True)
460
 
461
+
462
  st.write("Retrieved images from the Cell Painting database")
463
  with st.container():
464
  columns = st.columns(len(top_probs))