PopulationHealthScreener / notebooks /src /screen_articles.py
fulviodeo's picture
Refactoring
7d05933
raw
history blame
1.94 kB
"""
Screen a loaded article table against a fine-tuned model.
Run via exec_script() from the notebook. Reads the following keys from globals():
selected_model — a loaded ModelWrapper
recall_target — a Dropdown whose .value is an integer (95, 90, ...)
articles_table — a PredictionTable
_screen_progress — an IntProgress widget (or None)
_screen_progress_label — an HTML widget for batch labels (or None)
Writes to the namespace:
data — copy of articles_table.df with y_prob and y_pred columns added
review_data — rows where y_prob >= threshold, without y_prob/y_pred columns
"""
import numpy as np
# Required inputs. Fail fast with a message that names the missing key(s)
# instead of an opaque "'NoneType' object has no attribute ..." further down.
_missing_inputs = [
    key
    for key in ("selected_model", "recall_target", "articles_table")
    if globals().get(key) is None
]
if _missing_inputs:
    raise NameError(
        "screen_articles.py requires these names in the calling namespace: "
        + ", ".join(_missing_inputs)
    )

selected_model = globals()["selected_model"]
recall_target_value = globals()["recall_target"].value
articles_table = globals()["articles_table"]

# Optional progress widgets; either may be missing or None.
_progress = globals().get("_screen_progress")
_progress_label = globals().get("_screen_progress_label")

nlp = selected_model.model

# Drop rows whose input column is missing (NaN/None).
data = articles_table.df.dropna(subset=[nlp.input_col])

# Encode each text with truncation at nlp.max_length, then decode it back
# without special tokens, so the model receives pre-truncated strings.
x_test = [
    nlp.tokenizer.decode(
        nlp.tokenizer.encode(text, max_length=nlp.max_length, truncation=True),
        skip_special_tokens=True,
    )
    for text in data[nlp.input_col].astype(str).to_list()
]

# Decision threshold registered on the model for the chosen recall target.
threshold = selected_model.thresholds[recall_target_value]
def _on_batch(done, total):
    """Mirror prediction progress onto the optional notebook widgets.

    Passed to ``nlp.predict`` as its ``on_batch`` callback.

    Args:
        done: progress count so far, written to the progress bar's ``value``.
        total: overall count, written to the progress bar's ``max``.

    Either widget may be ``None``, in which case it is skipped.
    """
    bar, label = _progress, _progress_label

    if bar is not None:
        # Update max before value, in that order.
        bar.max = total
        bar.value = done

    if label is None:
        return

    label.value = (
        "<span style='font-size:12px; color:gray; margin-left:8px; line-height:20px;'>"
        f"{done} / {total}</span>"
    )
# Run the model over the prepared texts; progress is reported via _on_batch.
y_prob = nlp.predict(x_test, on_batch=_on_batch)

# Copy before adding the score and the thresholded 0/1 label columns.
data = data.copy()
data["y_prob"] = y_prob
data["y_pred"] = (np.array(y_prob) >= threshold).astype(int)

# review_data holds the rows at or above the threshold (y_pred == 1),
# with the two helper columns removed again.
review_data = (
    data.loc[data["y_pred"].to_numpy().astype(bool)]
    .copy()
    .drop(columns=["y_prob", "y_pred"], errors="ignore")
)