ym59 commited on
Commit
31465a4
·
verified ·
1 Parent(s): a2b5e4f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -14
app.py CHANGED
@@ -99,7 +99,8 @@ st.markdown(base_css, unsafe_allow_html=True)
99
  MODEL_REPO = "ym59/velobind-models"
100
  MODEL_DIR = Path("output/models")
101
  PREP_DIR = Path("output/preprocessors")
102
- AD_EMB_PATH = Path("output/ad_train_embeddings.npy")
 
103
 
104
  _DESC_FNS: Optional[List[Any]] = None
105
  try:
@@ -176,12 +177,11 @@ def load_models() -> Tuple[Dict[str, Any], Optional[Any], Optional[Any], Optiona
176
  except Exception:
177
  pass
178
 
179
- if AD_EMB_PATH.exists():
180
  try:
181
- train_embs = np.load(str(AD_EMB_PATH))
182
- at = Path("output/ad_threshold.npy")
183
- if at.exists():
184
- ad_threshold = float(np.load(str(at)))
185
  except Exception:
186
  pass
187
 
@@ -375,15 +375,15 @@ def predict_pkd(X: np.ndarray, fold_models: Dict[str, Any], meta: Any, iso_cal:
375
 
376
  def check_ad(esm_mean: np.ndarray, train_embs: Optional[np.ndarray], ad_threshold: float) -> Tuple[bool, float]:
377
  if train_embs is None:
378
- return True, 0.0
379
  try:
380
- from sklearn.metrics.pairwise import cosine_distances
381
- q = esm_mean[-480:].reshape(1, -1)
382
- d = cosine_distances(q, train_embs[:2000])[0]
383
- k = float(np.sort(d)[:5].mean())
384
- return k <= ad_threshold, k
385
- except Exception:
386
- return True, 0.0
387
 
388
 
389
  def clean_fasta(s: str) -> str:
 
99
  MODEL_REPO = "ym59/velobind-models"
100
  MODEL_DIR = Path("output/models")
101
  PREP_DIR = Path("output/preprocessors")
102
+ AD_CENTROID_PATH = Path("output/models/deployment/ad_centroid.npy")
103
+ AD_THRESHOLD_PATH = Path("output/models/deployment/ad_threshold.npy")
104
 
105
  _DESC_FNS: Optional[List[Any]] = None
106
  try:
 
177
  except Exception:
178
  pass
179
 
180
+ if AD_CENTROID_PATH.exists():
181
  try:
182
+ train_embs = np.load(str(AD_CENTROID_PATH))
183
+ if AD_THRESHOLD_PATH.exists():
184
+ ad_threshold = float(np.load(str(AD_THRESHOLD_PATH)))
 
185
  except Exception:
186
  pass
187
 
 
375
 
376
  def check_ad(esm_mean: np.ndarray, train_embs: Optional[np.ndarray], ad_threshold: float) -> Tuple[bool, float]:
377
  if train_embs is None:
378
+ return False, 0.0 # Fail safely to OUT OF DOMAIN if files are missing
379
  try:
380
+ q = esm_mean[-480:]
381
+ # Calculate Euclidean distance to the centroid
382
+ dist = float(np.linalg.norm(q - train_embs))
383
+ return dist <= ad_threshold, dist
384
+ except Exception as e:
385
+ logger.debug("check_ad error: %s", e)
386
+ return False, 0.0
387
 
388
 
389
  def clean_fasta(s: str) -> str: