Update app.py
Browse files
app.py
CHANGED
|
@@ -99,7 +99,8 @@ st.markdown(base_css, unsafe_allow_html=True)
|
|
| 99 |
MODEL_REPO = "ym59/velobind-models"
|
| 100 |
MODEL_DIR = Path("output/models")
|
| 101 |
PREP_DIR = Path("output/preprocessors")
|
| 102 |
-
|
|
|
|
| 103 |
|
| 104 |
_DESC_FNS: Optional[List[Any]] = None
|
| 105 |
try:
|
|
@@ -176,12 +177,11 @@ def load_models() -> Tuple[Dict[str, Any], Optional[Any], Optional[Any], Optiona
|
|
| 176 |
except Exception:
|
| 177 |
pass
|
| 178 |
|
| 179 |
-
if
|
| 180 |
try:
|
| 181 |
-
train_embs = np.load(str(
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
ad_threshold = float(np.load(str(at)))
|
| 185 |
except Exception:
|
| 186 |
pass
|
| 187 |
|
|
@@ -375,15 +375,15 @@ def predict_pkd(X: np.ndarray, fold_models: Dict[str, Any], meta: Any, iso_cal:
|
|
| 375 |
|
| 376 |
def check_ad(esm_mean: np.ndarray, train_embs: Optional[np.ndarray], ad_threshold: float) -> Tuple[bool, float]:
|
| 377 |
if train_embs is None:
|
| 378 |
-
return
|
| 379 |
try:
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
return
|
| 387 |
|
| 388 |
|
| 389 |
def clean_fasta(s: str) -> str:
|
|
|
|
| 99 |
MODEL_REPO = "ym59/velobind-models"
|
| 100 |
MODEL_DIR = Path("output/models")
|
| 101 |
PREP_DIR = Path("output/preprocessors")
|
| 102 |
+
AD_CENTROID_PATH = Path("output/models/deployment/ad_centroid.npy")
|
| 103 |
+
AD_THRESHOLD_PATH = Path("output/models/deployment/ad_threshold.npy")
|
| 104 |
|
| 105 |
_DESC_FNS: Optional[List[Any]] = None
|
| 106 |
try:
|
|
|
|
| 177 |
except Exception:
|
| 178 |
pass
|
| 179 |
|
| 180 |
+
if AD_CENTROID_PATH.exists():
|
| 181 |
try:
|
| 182 |
+
train_embs = np.load(str(AD_CENTROID_PATH))
|
| 183 |
+
if AD_THRESHOLD_PATH.exists():
|
| 184 |
+
ad_threshold = float(np.load(str(AD_THRESHOLD_PATH)))
|
|
|
|
| 185 |
except Exception:
|
| 186 |
pass
|
| 187 |
|
|
|
|
| 375 |
|
| 376 |
def check_ad(esm_mean: np.ndarray, train_embs: Optional[np.ndarray], ad_threshold: float) -> Tuple[bool, float]:
|
| 377 |
if train_embs is None:
|
| 378 |
+
return False, 0.0 # Fail safely to OUT OF DOMAIN if files are missing
|
| 379 |
try:
|
| 380 |
+
q = esm_mean[-480:]
|
| 381 |
+
# Calculate Euclidean distance to the centroid
|
| 382 |
+
dist = float(np.linalg.norm(q - train_embs))
|
| 383 |
+
return dist <= ad_threshold, dist
|
| 384 |
+
except Exception as e:
|
| 385 |
+
logger.debug("check_ad error: %s", e)
|
| 386 |
+
return False, 0.0
|
| 387 |
|
| 388 |
|
| 389 |
def clean_fasta(s: str) -> str:
|