"""Evaluate a saved `.pt` checkpoint on the validation split (optionally using a frozen protocol). |
|
|
|
|
|
Outputs: |
|
|
- Predictions CSV (for Streamlit Evaluation dashboard): columns `sample_id`, `class_0..`, `target_class_0..` |
|
|
- Metrics JSON (for model zoo + dashboards), including optional optimized global threshold. |
|
|
""" |

from __future__ import annotations

import argparse
import hashlib
import json
import logging
import sys
from pathlib import Path
from typing import Any

import pandas as pd
import torch

# Make project-local packages importable when this script is run directly.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

from data.data_loader import load_data, split_data
from data.transformer_dataset import TransformerNewsDataset
from models.transformer_model import RussianNewsClassifier
from utils.data_processing import create_target_encoding, process_tags
from utils.text_processing import normalise_text
from utils.tokenization import create_tokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _file_sha256(path: str | Path, chunk_size: int = 1024 * 1024) -> str:
    """Return the SHA-256 hex digest of `path`, read in chunks."""
    p = Path(path)
    h = hashlib.sha256()
    with p.open("rb") as f:
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                break
            h.update(chunk)
    return h.hexdigest()


def _pick_device() -> torch.device:
    """Prefer CUDA, then Apple MPS, then fall back to CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _metrics_from_binary(target: torch.Tensor, pred: torch.Tensor) -> dict[str, float]:
    """Compute the same family of metrics used in existing `experiments/results/*.json`.

    - precision/recall/f1 are averaged per-sample (like `evaluation.metrics`)
    - exact_match is elementwise accuracy across all labels
    - subset_accuracy is a strict per-sample set match
    - micro_* are computed globally across all labels
    """
    target = target.float()
    pred = pred.float()

    # Per-sample counts for sample-averaged precision/recall.
    tp_per = ((pred == 1) & (target == 1)).sum(dim=1).float()
    pred_pos_per = (pred == 1).sum(dim=1).float()
    true_pos_per = (target == 1).sum(dim=1).float()

    precision = (tp_per / (pred_pos_per + 1e-5)).mean().item()
    recall = (tp_per / (true_pos_per + 1e-5)).mean().item()
    f1 = (2 * precision * recall) / (precision + recall + 1e-5)

    exact_match = (pred == target).float().mean().item()
    subset_accuracy = (pred == target).all(dim=1).float().mean().item()

    # Global counts for micro-averaged metrics.
    tp = ((pred == 1) & (target == 1)).sum().float()
    fp = ((pred == 1) & (target == 0)).sum().float()
    fn = ((pred == 0) & (target == 1)).sum().float()

    micro_precision = (tp / (tp + fp + 1e-5)).item()
    micro_recall = (tp / (tp + fn + 1e-5)).item()
    micro_f1 = (2 * micro_precision * micro_recall) / (micro_precision + micro_recall + 1e-5)

    return {
        "precision": float(precision),
        "recall": float(recall),
        "f1": float(f1),
        "exact_match": float(exact_match),
        "subset_accuracy": float(subset_accuracy),
        "micro_precision": float(micro_precision),
        "micro_recall": float(micro_recall),
        "micro_f1": float(micro_f1),
    }
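
# Worked example (illustrative): with target = [[1, 0], [1, 1]] and
# pred = [[1, 1], [1, 0]], per-sample precision is mean(1/2, 1/1) = 0.75 and
# per-sample recall is mean(1/1, 1/2) = 0.75, while the global counts are
# tp=2, fp=1, fn=1, so micro_precision = micro_recall = 2/3.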


@torch.inference_mode()
def _predict_probs(
    *,
    model: RussianNewsClassifier,
    dataset: TransformerNewsDataset,
    batch_size: int,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, list[str]]:
    """Return (probs, targets, sample_ids)."""
    model.eval()
    model.to(device)

    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    probs_list: list[torch.Tensor] = []
    targets_list: list[torch.Tensor] = []
    sample_ids: list[str] = []

    # Prefer the article URL as a stable sample identifier; fall back to the index.
    if "href" in dataset.df.columns:
        ids = dataset.df["href"].astype(str).tolist()
    else:
        ids = dataset.df.index.astype(str).tolist()

    offset = 0
    for batch in loader:
        bsz = batch["labels"].shape[0]
        sample_ids.extend(ids[offset : offset + bsz])
        offset += bsz

        batch_device: dict[str, torch.Tensor] = {}
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch_device[k] = v.to(device)
        logits = model(
            title_input_ids=batch_device["title_input_ids"],
            title_attention_mask=batch_device["title_attention_mask"],
            snippet_input_ids=batch_device.get("snippet_input_ids"),
            snippet_attention_mask=batch_device.get("snippet_attention_mask"),
        )
        probs = torch.sigmoid(logits).detach().cpu()
        probs_list.append(probs)
        targets_list.append(batch["labels"].detach().cpu())

    probs_all = torch.cat(probs_list, dim=0) if probs_list else torch.empty((0, 0))
    targets_all = torch.cat(targets_list, dim=0) if targets_list else torch.empty((0, 0))
    return probs_all, targets_all, sample_ids


def _optimize_threshold(
    *,
    probs: torch.Tensor,
    target: torch.Tensor,
    metric: str,
    min_t: float = 0.01,
    max_t: float = 0.99,
    step: float = 0.01,
) -> tuple[float, dict[str, float]]:
    """Grid-search a single global threshold that maximizes `metric` on the given probs."""
    if probs.numel() == 0:
        return 0.5, _metrics_from_binary(target, probs)

    if metric not in {"precision", "recall", "f1"}:
        raise ValueError(f"Unknown optimize metric: {metric}")

    best_t = 0.5
    best_val = -1.0
    best_metrics: dict[str, float] = {}

    t = min_t
    while t <= max_t + 1e-9:
        pred = (probs >= t).float()
        m = _metrics_from_binary(target, pred)
        score = m[metric]
        if score > best_val:
            best_val = score
            best_t = float(round(t, 2))
            best_metrics = m
        t = round(t + step, 10)

    return best_t, best_metrics
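
# Usage sketch (synthetic tensors, for illustration only):
#   probs = torch.tensor([[0.9, 0.2], [0.4, 0.7]])
#   target = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
#   best_t, best_m = _optimize_threshold(probs=probs, target=target, metric="f1")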


def main() -> int:
    parser = argparse.ArgumentParser(description="Evaluate a trained model checkpoint")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to saved `.pt` checkpoint")
    parser.add_argument("--data-path", type=str, default="data/news_data/ria_news.tsv", help="Path to RIA TSV")
    parser.add_argument("--protocol-dir", type=str, default=None, help="Frozen protocol directory (splits.json + tag_to_idx.json)")
    parser.add_argument("--max-val-samples", type=int, default=None, help="Limit validation samples (ignored if --protocol-dir is set)")
    parser.add_argument("--threshold", type=float, default=0.5, help="Default global threshold for reporting `metrics`")

    parser.add_argument("--optimize-threshold", action="store_true", help="Search for the best global threshold on the val set")
    parser.add_argument(
        "--optimize-metric",
        type=str,
        default="f1",
        choices=["precision", "recall", "f1"],
        help="Metric to optimize when --optimize-threshold is set",
    )

    parser.add_argument("--batch-size", type=int, default=16, help="Eval batch size")
    parser.add_argument("--model-id", type=str, default=None, help="Optional model identifier (defaults to checkpoint stem)")
    parser.add_argument("--output-csv", type=str, default=None, help="Write predictions CSV to this path")
    parser.add_argument("--metrics-json", type=str, default=None, help="Write metrics JSON to this path")
    args = parser.parse_args()

    ckpt_path = Path(args.checkpoint)
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    checkpoint: dict[str, Any] = torch.load(ckpt_path, map_location="cpu")
    tag_to_idx = checkpoint.get("tag_to_idx") or {}
    num_labels = int(checkpoint.get("num_labels") or len(tag_to_idx))
    model_name = checkpoint.get("model_name") or "DeepPavlov/rubert-base-cased"
    use_snippet = bool(checkpoint.get("use_snippet", False))
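    # Keys the checkpoint is expected to carry, as read above and below
    # (illustrative list, derived from this script's own usage):
    #   state_dict, tag_to_idx, num_labels, model_name, use_snippet,
    #   dropout, freeze_backbone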

    model_id = args.model_id or ckpt_path.stem

    logger.info(f"Loading data from {args.data_path}...")
    df_ria, _, _ = load_data(args.data_path)

    logger.info("Processing text...")
    df_ria["title_clean"] = df_ria["title"].apply(normalise_text)
    if "snippet" in df_ria.columns:
        df_ria["snippet_clean"] = df_ria["snippet"].fillna("").apply(normalise_text)

    logger.info("Processing tags...")
    df_ria["tags"] = process_tags(df_ria["tags"])

    logger.info("Splitting data...")
    df_train, df_val, df_test = split_data(
        df_ria,
        train_date_end="2018-10-01",
        val_date_start="2018-10-01",
        val_date_end="2018-12-01",
        test_date_start="2018-12-01",
    )

    protocol_meta: dict[str, Any] | None = None
    if args.protocol_dir:
        protocol_path = Path(args.protocol_dir)
        splits_path = protocol_path / "splits.json"
        mapping_path = protocol_path / "tag_to_idx.json"
        if not splits_path.exists() or not mapping_path.exists():
            raise FileNotFoundError(f"protocol-dir must contain splits.json and tag_to_idx.json: {protocol_path}")

        splits = json.loads(splits_path.read_text(encoding="utf-8"))
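        # Expected splits.json shape (illustrative, matching the keys read below):
        #   {"id_column": "href", "train_ids": [...], "val_ids": [...], "test_ids": [...]}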
        id_col = splits.get("id_column", "href")
        if id_col == "href" and "href" in df_val.columns:
            df_train = df_train[df_train["href"].astype(str).isin(set(splits["train_ids"]))].copy()
            df_val = df_val[df_val["href"].astype(str).isin(set(splits["val_ids"]))].copy()
            df_test = df_test[df_test["href"].astype(str).isin(set(splits["test_ids"]))].copy()
        else:
            train_ids = set(splits["train_ids"])
            val_ids = set(splits["val_ids"])
            test_ids = set(splits["test_ids"])
            df_train = df_train[df_train.index.astype(str).isin(train_ids)].copy()
            df_val = df_val[df_val.index.astype(str).isin(val_ids)].copy()
            df_test = df_test[df_test.index.astype(str).isin(test_ids)].copy()

        tag_to_idx = json.loads(mapping_path.read_text(encoding="utf-8"))
        num_labels = len(tag_to_idx)
        logger.info(
            f"Loaded protocol bundle from {protocol_path} "
            f"(train={len(df_train)}, val={len(df_val)}, test={len(df_test)}, labels={num_labels})"
        )

        protocol_meta = {
            "data_path": args.data_path,
            "data_sha256": _file_sha256(args.data_path),
            "split": {
                "train_date_end": "2018-10-01",
                "val_date_start": "2018-10-01",
                "val_date_end": "2018-12-01",
                "test_date_start": "2018-12-01",
            },
            "limits": {
                "max_train_samples": len(df_train),
                "max_val_samples": len(df_val),
            },
            "label_space": {
                "min_tag_frequency": None,
                "num_labels": num_labels,
            },
        }
    else:
        if args.max_val_samples is not None:
            df_val = df_val.head(args.max_val_samples).copy()

    logger.info(f"Val samples: {len(df_val)}")

    df_val = df_val.copy()
    df_val["target_tags"] = create_target_encoding(df_val, tag_to_idx)

    tokenizer = create_tokenizer(model_name, max_length=128)
    val_dataset = TransformerNewsDataset(
        df=df_val,
        tokenizer=tokenizer,
        max_title_len=128,
        max_snippet_len=256 if use_snippet else None,
        label_to_idx=tag_to_idx,
    )

    model = RussianNewsClassifier(
        model_name=model_name,
        num_labels=num_labels,
        dropout=float(checkpoint.get("dropout", 0.3)),
        use_snippet=use_snippet,
        freeze_bert=bool(checkpoint.get("freeze_backbone", False)),
    )
    model.load_state_dict(checkpoint["state_dict"], strict=True)

    device = _pick_device()
    logger.info(f"Evaluating on device: {device}")

    probs, target, sample_ids = _predict_probs(model=model, dataset=val_dataset, batch_size=args.batch_size, device=device)

    if args.output_csv:
        out_csv = Path(args.output_csv)
        out_csv.parent.mkdir(parents=True, exist_ok=True)
        data: dict[str, Any] = {"sample_id": sample_ids}
        for j in range(probs.shape[1]):
            data[f"class_{j}"] = probs[:, j].numpy()
        for j in range(target.shape[1]):
            data[f"target_class_{j}"] = target[:, j].numpy()
        pd.DataFrame(data).to_csv(out_csv, index=False)
        logger.info(f"Wrote predictions CSV: {out_csv}")

    pred_default = (probs >= float(args.threshold)).float()
    metrics_default = _metrics_from_binary(target, pred_default)

    # Quick distribution checks to catch degenerate predictions early.
    sanity = {
        "avg_true_labels_per_sample": float(target.sum(dim=1).float().mean().item()),
        "avg_pred_labels_per_sample": float(pred_default.sum(dim=1).float().mean().item()),
        "pct_samples_with_any_true_label": float((target.sum(dim=1) > 0).float().mean().item()),
        "pct_samples_with_any_pred_label": float((pred_default.sum(dim=1) > 0).float().mean().item()),
        "prob_min": float(probs.min().item()) if probs.numel() else 0.0,
        "prob_mean": float(probs.mean().item()) if probs.numel() else 0.0,
        "prob_max": float(probs.max().item()) if probs.numel() else 0.0,
    }

    payload: dict[str, Any] = {
        "experiment_id": model_id,
        "checkpoint_path": str(args.checkpoint),
        "data_path": args.data_path,
        "protocol_dir": args.protocol_dir,
        "protocol": protocol_meta,
        "threshold": float(args.threshold),
        "max_val_samples": args.max_val_samples,
        "val_samples": int(target.shape[0]),
        "num_labels": int(target.shape[1]),
        "model_name": model_name,
        "use_snippet": bool(use_snippet),
        "metrics": metrics_default,
        "sanity": sanity,
    }
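    # Abridged shape of the metrics JSON written below (illustrative):
    #   {"experiment_id": ..., "threshold": 0.5, "metrics": {...}, "sanity": {...},
    #    "optimized_threshold": {...}}  # last key present only with --optimize-threshold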

    if args.optimize_threshold:
        best_t, best_metrics = _optimize_threshold(
            probs=probs,
            target=target,
            metric=args.optimize_metric,
            min_t=0.01,
            max_t=0.99,
            step=0.01,
        )
        payload["optimized_threshold"] = {
            "threshold": float(best_t),
            "metric": args.optimize_metric,
            "metric_value": float(best_metrics[args.optimize_metric]),
            **best_metrics,
        }

    if args.metrics_json:
        out_json = Path(args.metrics_json)
        out_json.parent.mkdir(parents=True, exist_ok=True)
        out_json.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
        logger.info(f"Wrote metrics JSON: {out_json}")

    logger.info(
        f"Metrics @ threshold={args.threshold}: f1={metrics_default['f1']:.4f}, "
        f"p={metrics_default['precision']:.4f}, r={metrics_default['recall']:.4f}"
    )
    if args.optimize_threshold:
        opt = payload["optimized_threshold"]
        logger.info(
            f"Optimized threshold={opt['threshold']:.2f} ({opt['metric']}={opt['metric_value']:.4f}) "
            f"f1={opt['f1']:.4f}, p={opt['precision']:.4f}, r={opt['recall']:.4f}"
        )

    return 0


if __name__ == "__main__":
    raise SystemExit(main())