LukeFP commited on
Commit
9876e16
·
1 Parent(s): 90caab9

hot fix, excluding a disicpline

Browse files
Files changed (1) hide show
  1. app.py +11 -3
app.py CHANGED
@@ -72,6 +72,8 @@ DISCIPLINE_MODEL_PATH = MODELS_DIR / "discipline_classifier_gemma_20260130_14084
72
  CONCEPT_MODEL_PATH = MODELS_DIR / "concept_conditioned_gemma_20260130_140842.pt"
73
  EMBEDDING_MODEL_NAME = "google/embeddinggemma-300m"
74
 
 
 
75
  # ---------------------------------------------------------------------------
76
  # Globals (loaded once at startup)
77
  # ---------------------------------------------------------------------------
@@ -164,12 +166,18 @@ def predict(title: str, abstract: str, threshold: float, top_k: int):
164
  conc_logits = concept_model(emb_tensor, disc_probs_tensor)
165
  conc_probs = torch.sigmoid(conc_logits).cpu().numpy()[0]
166
 
167
- # Format discipline results
168
  disc_order = np.argsort(disc_probs)[::-1]
169
  disc_lines = []
170
- for rank, idx in enumerate(disc_order[:top_k], 1):
171
- prob = disc_probs[idx]
172
  label = discipline_labels[idx].get("label", f"Discipline_{idx}")
 
 
 
 
 
 
173
  marker = "**" if prob >= threshold else ""
174
  disc_lines.append(f"{rank}. {marker}{label}{marker} — {prob:.1%}")
175
 
 
72
  CONCEPT_MODEL_PATH = MODELS_DIR / "concept_conditioned_gemma_20260130_140842.pt"
73
  EMBEDDING_MODEL_NAME = "google/embeddinggemma-300m"
74
 
75
+ EXCLUDED_DISCIPLINES = {"Quantum Physics"}
76
+
77
  # ---------------------------------------------------------------------------
78
  # Globals (loaded once at startup)
79
  # ---------------------------------------------------------------------------
 
166
  conc_logits = concept_model(emb_tensor, disc_probs_tensor)
167
  conc_probs = torch.sigmoid(conc_logits).cpu().numpy()[0]
168
 
169
+ # Format discipline results (skip excluded labels)
170
  disc_order = np.argsort(disc_probs)[::-1]
171
  disc_lines = []
172
+ rank = 0
173
+ for idx in disc_order:
174
  label = discipline_labels[idx].get("label", f"Discipline_{idx}")
175
+ if label in EXCLUDED_DISCIPLINES:
176
+ continue
177
+ rank += 1
178
+ if rank > top_k:
179
+ break
180
+ prob = disc_probs[idx]
181
  marker = "**" if prob >= threshold else ""
182
  disc_lines.append(f"{rank}. {marker}{label}{marker} — {prob:.1%}")
183