Spaces:
Sleeping
Sleeping
"""
Screen a loaded article table against a fine-tuned model.

Run via exec_script() from the notebook. Reads the following keys from globals():
    selected_model          — a loaded ModelWrapper
    recall_target           — a Dropdown whose .value is an integer (95, 90, ...)
    articles_table          — a PredictionTable
    _screen_progress        — an IntProgress widget (or None)
    _screen_progress_label  — an HTML widget for batch labels (or None)

Writes to the namespace:
    data         — copy of articles_table.df with rows whose model input column
                   is missing dropped, and y_prob and y_pred columns added
    review_data  — rows where y_prob >= threshold, without y_prob/y_pred columns

Raises NameError if selected_model, recall_target or articles_table is missing
(or None) in the namespace.
"""
import numpy as np

# Fail fast with a readable message instead of an opaque
# "'NoneType' object has no attribute ..." error further down when the
# notebook has not populated one of the required keys yet.
_missing_keys = [
    _key
    for _key in ("selected_model", "recall_target", "articles_table")
    if globals().get(_key) is None
]
if _missing_keys:
    raise NameError(
        "screening script requires these names in the namespace: "
        + ", ".join(_missing_keys)
    )
del _missing_keys

selected_model = globals().get("selected_model")
recall_target_value = globals().get("recall_target").value
articles_table = globals().get("articles_table")
# The two progress widgets are optional; _on_batch checks them for None.
_progress = globals().get("_screen_progress")
_progress_label = globals().get("_screen_progress_label")

nlp = selected_model.model

# Rows without model input cannot be scored, so they are dropped up front.
data = articles_table.df.dropna(subset=[nlp.input_col])
x_test = data[nlp.input_col].astype(str).to_list()

# Round-trip each text through the tokenizer so it is truncated to at most
# nlp.max_length tokens; special tokens added by encode() are stripped again
# on decode (presumably a Hugging Face tokenizer — confirm).
x_test = [
    nlp.tokenizer.decode(
        nlp.tokenizer.encode(text, max_length=nlp.max_length, truncation=True),
        skip_special_tokens=True,
    )
    for text in x_test
]

# Probability cut-off associated with the chosen recall target
# (NOTE(review): how ModelWrapper.thresholds is calibrated is not visible here).
threshold = selected_model.thresholds[recall_target_value]
def _on_batch(done, total):
    """Progress callback handed to nlp.predict(on_batch=...).

    Mirrors ``done`` out of ``total`` (units defined by nlp.predict —
    presumably batches) onto the optional progress bar and text label.
    Either widget may be None, in which case it is simply skipped.
    """
    bar, label = _progress, _progress_label

    if bar is not None:
        # max is updated before value, presumably so the new value is never
        # clamped against a stale maximum.
        bar.max = total
        bar.value = done

    if label is not None:
        label.value = (
            "<span style='font-size:12px; color:gray; margin-left:8px; line-height:20px;'>"
            f"{done} / {total}</span>"
        )
# Score every (truncated) text; _on_batch reports progress to the widgets.
# y_prob presumably holds one probability per row of data (confirm against
# nlp.predict).
y_prob = nlp.predict(x_test, on_batch=_on_batch)

# A single boolean mask is the one source of truth for the decision rule, so
# y_pred and review_data cannot drift apart if the comparison ever changes.
_flagged = np.asarray(y_prob) >= threshold

data = data.copy()
data["y_prob"] = y_prob
data["y_pred"] = _flagged.astype(int)

# Rows at or above the threshold go to review, without the model columns.
review_data = data.loc[_flagged].copy()
review_data = review_data.drop(columns=["y_prob", "y_pred"], errors="ignore")