Ana Sanchez commited on
Commit
fa8b5a5
·
1 Parent(s): d87e105

update app

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -305,7 +305,7 @@ def reshape_image(arr):
305
  ##### STREAMLIT FUNCTIONS ######
306
  st.title('CLOOME. Bioimage database retrieval from chemical structures (and viceversa)')
307
 
308
- def main_page(top_n):
309
  st.markdown(
310
  """
311
  Contrastive learning for self-supervised representation learning has brought a strong improvement to many application areas, such as computer vision and natural language processing.
@@ -358,7 +358,7 @@ def molecules_from_image(top_n, model_path):
358
  molpath = os.path.join(datapath, "mols.hdf")
359
  fps_fname = save_hdf(morgan, molnames, molpath)
360
  mol_imgs = draw_molecules(smiles)
361
- mol_features, mol_ids = main(mol_index, MODEL_PATH, model_type, mol_path=molpath, image_resolution=image_resolution)
362
  predefined_features = False
363
  else:
364
  mol_index = pd.read_csv(mol_index_file)
@@ -375,9 +375,11 @@ def molecules_from_image(top_n, model_path):
375
  print(img_features.shape)
376
  print(mol_features.shape)
377
 
 
 
378
  logits = img_features @ mol_features.T
379
  mol_probs = (30.0 * logits).softmax(dim=-1)
380
- top_probs, top_labels = mol_probs.cpu().topk(5, dim=-1)
381
 
382
  # Delete this if want to allow retrieval for multiple images
383
  top_probs = torch.flatten(top_probs)
 
305
  ##### STREAMLIT FUNCTIONS ######
306
  st.title('CLOOME. Bioimage database retrieval from chemical structures (and viceversa)')
307
 
308
+ def main_page(top_n, model_path):
309
  st.markdown(
310
  """
311
  Contrastive learning for self-supervised representation learning has brought a strong improvement to many application areas, such as computer vision and natural language processing.
 
358
  molpath = os.path.join(datapath, "mols.hdf")
359
  fps_fname = save_hdf(morgan, molnames, molpath)
360
  mol_imgs = draw_molecules(smiles)
361
+ mol_features, mol_ids = main(mol_index, model_path, model_type, mol_path=molpath, image_resolution=image_resolution)
362
  predefined_features = False
363
  else:
364
  mol_index = pd.read_csv(mol_index_file)
 
375
  print(img_features.shape)
376
  print(mol_features.shape)
377
 
378
+ top_n = int(top_n)
379
+
380
  logits = img_features @ mol_features.T
381
  mol_probs = (30.0 * logits).softmax(dim=-1)
382
+ top_probs, top_labels = mol_probs.cpu().topk(top_n, dim=-1)
383
 
384
  # Delete this if want to allow retrieval for multiple images
385
  top_probs = torch.flatten(top_probs)