| import argparse |
| import os |
| from typing import List, Tuple |
|
|
| import torch |
| import numpy as np |
| from datasets import Dataset |
| from transformers import AutoTokenizer, AutoModelForTokenClassification |
| from seqeval.metrics import precision_score, recall_score, f1_score, classification_report |
|
|
|
|
def read_conll_2col(path: str) -> Tuple[List[List[str]], List[List[str]]]:
    """Read a 2-column CoNLL file (TOKEN TAG) into parallel token/tag lists.

    Sentences are separated by blank lines. A line consisting only of
    whitespace (spaces, tabs) is treated as a blank line too, so stray
    trailing whitespace on a separator line does not merge two sentences.

    Lines with fewer than two whitespace-separated fields are skipped.
    When a line has more than two fields, the first field is taken as the
    token and the last field as the tag.

    Args:
        path: Path to the UTF-8 encoded CoNLL file.

    Returns:
        A pair ``(all_tokens, all_tags)``: two lists of equal length with one
        entry per sentence, where ``all_tokens[i]`` and ``all_tags[i]`` are
        the tokens and tags of sentence ``i`` (also of equal length).
    """
    all_toks: List[List[str]] = []
    all_labs: List[List[str]] = []
    toks: List[str] = []
    labs: List[str] = []

    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.split()
            if not parts:
                # Blank or whitespace-only line: close the current sentence.
                if toks:
                    all_toks.append(toks)
                    all_labs.append(labs)
                    toks, labs = [], []
                continue
            if len(parts) < 2:
                # Token without a tag: skip it (best effort, as before).
                continue
            toks.append(parts[0])
            labs.append(parts[-1])

    # Flush the last sentence when the file does not end with a blank line.
    if toks:
        all_toks.append(toks)
        all_labs.append(labs)
    return all_toks, all_labs
|
|
|
|
def _first_subtoken_labels(word_ids, pred_ids, id2label) -> dict:
    """Map each word index to the label predicted for its first sub-token."""
    labels_by_word = {}
    for tok_idx, wid in enumerate(word_ids):
        # ``None`` marks special/padding tokens; later sub-tokens of a word
        # that already has a label are ignored.
        if wid is None or wid in labels_by_word:
            continue
        labels_by_word[wid] = id2label[int(pred_ids[tok_idx])]
    return labels_by_word


def main():
    """Evaluate a fine-tuned token-classification model on a CoNLL test file.

    Command-line arguments:
        --model_dir:  directory of the fine-tuned model and tokenizer.
        --test_path:  2-column CoNLL test file (TOKEN TAG).
        --batch_size: number of sentences per forward pass.
        --max_length: maximum sub-token sequence length (longer inputs are
                      truncated by the tokenizer).

    The model is run batch by batch; each word receives the label predicted
    for its first sub-token. Entity-level precision, recall and F1 plus the
    seqeval classification report are printed to stdout.

    Exits through ``parser.error`` (status 2) when a path does not exist,
    ``--batch_size`` is below 1, or the test file contains no sentences.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", type=str, default="outputs/bert-base-cased-timeNER",
                        help="Path to the fine-tuned model directory (with config.json, tokenizer files, weights).")
    parser.add_argument("--test_path", type=str, default="data/test.conll",
                        help="Path to 2-column CoNLL test file.")
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--max_length", type=int, default=256)
    args = parser.parse_args()

    # Explicit checks instead of ``assert``: asserts vanish under ``python -O``.
    if not os.path.exists(args.model_dir):
        parser.error(f"Model dir not found: {args.model_dir}")
    if not os.path.exists(args.test_path):
        parser.error(f"Test file not found: {args.test_path}")
    if args.batch_size < 1:
        parser.error(f"--batch_size must be >= 1, got {args.batch_size}")

    print(f"Loading model from: {args.model_dir}")
    # A fast tokenizer is required: word_ids() is only available on them.
    tokenizer = AutoTokenizer.from_pretrained(args.model_dir, use_fast=True)
    model = AutoModelForTokenClassification.from_pretrained(args.model_dir)
    model.eval()

    id2label = model.config.id2label
    label2id = model.config.label2id
    labels_sorted = [id2label[i] for i in range(len(id2label))]
    print(f"Model labels: {labels_sorted}")

    print(f"Reading test set: {args.test_path}")
    tokens_list, tags_list = read_conll_2col(args.test_path)
    num_sents = len(tokens_list)
    num_tokens = sum(len(s) for s in tokens_list)
    print(f"Loaded {num_sents} sentences / {num_tokens} tokens")
    if num_sents == 0:
        parser.error(f"No sentences found in test file: {args.test_path}")

    # Gold tags unknown to the model can never be predicted; warn, don't fail.
    uniq_test_labels = sorted({t for seq in tags_list for t in seq})
    missing = [t for t in uniq_test_labels if t not in label2id]
    if missing:
        print(f"⚠️ Warning: test labels not in model: {missing}")

    ds = Dataset.from_dict({"tokens": tokens_list, "ner_tags": tags_list})

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    all_preds: List[List[str]] = []
    all_refs: List[List[str]] = []
    num_unscored = 0  # words with no sub-token in the encoded input

    for start in range(0, len(ds), args.batch_size):
        batch = ds[start : start + args.batch_size]
        batch_tokens = batch["tokens"]
        batch_refs = batch["ner_tags"]

        encodings = tokenizer(
            batch_tokens,
            is_split_into_words=True,
            truncation=True,
            max_length=args.max_length,
            return_tensors="pt",
            padding=True,
        )

        # Forward exactly the tensors the tokenizer produced. Passing
        # token_type_ids unconditionally breaks models whose forward() has
        # no such parameter.
        model_inputs = {name: tensor.to(device) for name, tensor in encodings.items()}
        with torch.no_grad():
            logits = model(**model_inputs).logits

        pred_ids = logits.argmax(dim=-1).cpu().numpy()

        for i, word_labels in enumerate(batch_refs):
            labels_by_word = _first_subtoken_labels(
                encodings.word_ids(batch_index=i), pred_ids[i], id2label
            )

            # Align on the word index, not on position: a word that was
            # truncated away or produced no sub-token is left out of both
            # sequences instead of shifting every label after it.
            sent_refs: List[str] = []
            sent_preds: List[str] = []
            for word_idx, ref in enumerate(word_labels):
                pred = labels_by_word.get(word_idx)
                if pred is None:
                    num_unscored += 1
                    continue
                sent_refs.append(ref)
                sent_preds.append(pred)

            all_refs.append(sent_refs)
            all_preds.append(sent_preds)

    if num_unscored:
        print(
            f"Warning: {num_unscored} words were excluded from scoring "
            f"(truncated at max_length={args.max_length} or dropped by the tokenizer)"
        )

    p = precision_score(all_refs, all_preds)
    r = recall_score(all_refs, all_preds)
    f1 = f1_score(all_refs, all_preds)

    print("\n Results on test set")
    print(f"Precision: {p:.4f}")
    print(f"Recall : {r:.4f}")
    print(f"F1 : {f1:.4f}")

    print("\nSeqeval classification report")
    print(classification_report(all_refs, all_preds, digits=4))
|
|
|
|
# Run the evaluation only when this file is executed as a script, so that
# importing the module (e.g. to reuse read_conll_2col) has no side effects.
if __name__ == "__main__":
    main()