# Provenance (pasted GitHub commit metadata, kept as a comment so the file parses):
# commit dbd79bd by alverciito — "upload safetensors and refactor research files"
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# #
# This file was created by: Alberto Palomo Alonso #
# Universidad de Alcalá - Escuela Politécnica Superior #
# #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import os
import json
import numpy as np
import torch
from tqdm import tqdm
from .metrics import precision_recall_f1_wd
from .load_dataset import load_dataset
from research_files.inference import load_model
def evaluate_proposed(
        model_path: str,
        tokenizer_path: str,
        max_sentences: int,
        data_path: str,
        batch_size: int = 32,
        logit_th: float = 1.9,
        device: torch.device = torch.device('cpu'),
        total_samples: int = 6557
) -> None:
    """
    Evaluates the proposed sentence-level segmentation model on a dataset
    and reports boundary-based segmentation metrics.

    The evaluation pipeline:
        1. Loads the dataset and generates sentence-level batches with gold
           boundary labels and masks.
        2. Loads a pretrained segmentation model and its tokenizer.
        3. Runs batched inference (gradients disabled) to predict sentence
           boundaries.
        4. Computes precision, recall, F1-score, and WindowDiff for each batch.
        5. Saves per-batch metric values to a JSON file named after the model.

    Args:
        model_path (str):
            Path to the trained model checkpoint.
        tokenizer_path (str):
            Path or identifier of the tokenizer associated with the model.
        max_sentences (int):
            Maximum number of sentences per document used during evaluation.
            Longer documents are truncated or split consistently with training.
        data_path (str):
            Path to the dataset directory (Hugging Face `load_from_disk` format).
        batch_size (int, optional):
            Number of samples evaluated per batch. Defaults to 32.
        logit_th (float, optional):
            Logit threshold above which a sentence is predicted to start a
            new segment. Defaults to 1.9.
        device (torch.device, optional):
            Device to run the model on. Defaults to torch.device('cpu').
        total_samples (int, optional):
            Total number of samples in the test split; used only to size the
            progress bar. Defaults to 6557 (the original test-set size).

    Returns:
        None

        Evaluation results are written to a JSON file containing lists of
        precision, recall, F1-score, and WindowDiff values for each batch.
    """
    # Load data, model and tokenizer:
    dataset = load_dataset(data_path, max_sentences=max_sentences, batch_size=batch_size)
    model, tokenizer, _ = load_model(model_path, tokenizer_path)
    model.to(device).eval()
    # Per-batch metric accumulators:
    p_list: list = []
    r_list: list = []
    f1_list: list = []
    wd_list: list = []
    # Progress-bar total derived from the (now configurable) test-set size:
    n_batches = int(np.ceil(total_samples / batch_size))
    with tqdm(total=n_batches, desc='Evaluating on test...') as pbar:
        # Evaluation only: disable autograd to avoid building graphs
        # and wasting memory on every forward pass.
        with torch.no_grad():
            # Iterate over batches:
            for batch_x, batch_y, y_mask in dataset:
                # Tokenize text:
                batch_in = [tokenizer(_x) for _x in batch_x]
                x_batched = torch.stack([_['input_ids'] for _ in batch_in]).to(device)
                batch_mask = torch.stack([_['attention_mask'] for _ in batch_in]).to(device)
                y_batched = torch.from_numpy(np.array(batch_y)).to(device)
                y_mask = torch.from_numpy(np.array(y_mask)).to(device)
                # Forward pass, then binarize logits into hard boundary decisions:
                y_hat = model(x_batched, batch_mask)
                y_hat = (y_hat >= logit_th).int()
                # The first sentence of a document always opens a segment.
                y_hat[..., 0] = 1
                p, r, f1, wd = precision_recall_f1_wd(y_hat, y_batched, y_mask)
                p_list.append(p)
                r_list.append(r)
                f1_list.append(f1)
                wd_list.append(wd)
                # Show running means in the progress bar:
                pbar.update(1)
                pbar.set_postfix({
                    'f1': np.mean(f1_list),
                    'wd': np.mean(wd_list)
                })
    # Save per-batch results to '<model_name>.json' in the working directory:
    model_name = os.path.basename(model_path)
    with open(f'{model_name}.json', 'w') as f:
        json.dump({
            'precision': p_list,
            'recall': r_list,
            'f1': f1_list,
            'window_diff': wd_list,
            'logit_th': logit_th,
            'max_sentences': max_sentences
        }, f, indent=4)
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# END OF FILE #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #