Spaces:
Sleeping
Sleeping
hot fix, excluding a discipline
Browse files
app.py
CHANGED
|
@@ -72,6 +72,8 @@ DISCIPLINE_MODEL_PATH = MODELS_DIR / "discipline_classifier_gemma_20260130_14084
|
|
| 72 |
CONCEPT_MODEL_PATH = MODELS_DIR / "concept_conditioned_gemma_20260130_140842.pt"
|
| 73 |
EMBEDDING_MODEL_NAME = "google/embeddinggemma-300m"
|
| 74 |
|
|
|
|
|
|
|
| 75 |
# ---------------------------------------------------------------------------
|
| 76 |
# Globals (loaded once at startup)
|
| 77 |
# ---------------------------------------------------------------------------
|
|
@@ -164,12 +166,18 @@ def predict(title: str, abstract: str, threshold: float, top_k: int):
|
|
| 164 |
conc_logits = concept_model(emb_tensor, disc_probs_tensor)
|
| 165 |
conc_probs = torch.sigmoid(conc_logits).cpu().numpy()[0]
|
| 166 |
|
| 167 |
-
# Format discipline results
|
| 168 |
disc_order = np.argsort(disc_probs)[::-1]
|
| 169 |
disc_lines = []
|
| 170 |
-
|
| 171 |
-
|
| 172 |
label = discipline_labels[idx].get("label", f"Discipline_{idx}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
marker = "**" if prob >= threshold else ""
|
| 174 |
disc_lines.append(f"{rank}. {marker}{label}{marker} — {prob:.1%}")
|
| 175 |
|
|
|
|
| 72 |
CONCEPT_MODEL_PATH = MODELS_DIR / "concept_conditioned_gemma_20260130_140842.pt"
|
| 73 |
EMBEDDING_MODEL_NAME = "google/embeddinggemma-300m"
|
| 74 |
|
| 75 |
+
EXCLUDED_DISCIPLINES = {"Quantum Physics"}
|
| 76 |
+
|
| 77 |
# ---------------------------------------------------------------------------
|
| 78 |
# Globals (loaded once at startup)
|
| 79 |
# ---------------------------------------------------------------------------
|
|
|
|
| 166 |
conc_logits = concept_model(emb_tensor, disc_probs_tensor)
|
| 167 |
conc_probs = torch.sigmoid(conc_logits).cpu().numpy()[0]
|
| 168 |
|
| 169 |
+
# Format discipline results (skip excluded labels)
|
| 170 |
disc_order = np.argsort(disc_probs)[::-1]
|
| 171 |
disc_lines = []
|
| 172 |
+
rank = 0
|
| 173 |
+
for idx in disc_order:
|
| 174 |
label = discipline_labels[idx].get("label", f"Discipline_{idx}")
|
| 175 |
+
if label in EXCLUDED_DISCIPLINES:
|
| 176 |
+
continue
|
| 177 |
+
rank += 1
|
| 178 |
+
if rank > top_k:
|
| 179 |
+
break
|
| 180 |
+
prob = disc_probs[idx]
|
| 181 |
marker = "**" if prob >= threshold else ""
|
| 182 |
disc_lines.append(f"{rank}. {marker}{label}{marker} — {prob:.1%}")
|
| 183 |
|