#!/usr/bin/env python3 import argparse import re import string from typing import List, Tuple import torch from datasets import load_dataset from transformers import ( AutoTokenizer, AutoModelForTokenClassification, AutoModelForCausalLM, ) # BIO utilities def bio_to_spans(tags: List[str]) -> List[Tuple[int, int, str]]: spans = [] start = None label = None def close(i): nonlocal start, label if start is not None: spans.append((start, i, label)) start = None label = None for i, t in enumerate(tags): if t == "O": close(i) elif t.startswith("B-"): close(i) start = i label = t[2:] elif t.startswith("I-"): if start is None: start = i label = t[2:] else: close(i) close(len(tags)) return spans def micro_f1(pred_spans_list, gold_spans_list): tp = fp = fn = 0 for p, g in zip(pred_spans_list, gold_spans_list): p_set = set(p) g_set = set(g) tp += len(p_set & g_set) fp += len(p_set - g_set) fn += len(g_set - p_set) precision = tp / (tp + fp) if tp + fp > 0 else 0 recall = tp / (tp + fn) if tp + fn > 0 else 0 f1 = (2 * precision * recall / (precision + recall) if precision + recall > 0 else 0) return { "precision": precision, "recall": recall, "f1": f1, "tp": tp, "fp": fp, "fn": fn, } # CoNLL constants CONLL_ID2TAG = { 0: "O", 1: "B-ORG", 2: "B-MISC", 3: "B-PER", 4: "I-PER", 5: "B-LOC", 6: "I-ORG", 7: "I-MISC", 8: "I-LOC", } # tner model uses verbose label names MODEL_LABEL_MAP = { "O": "O", "B-organization": "B-ORG", "I-organization": "I-ORG", "B-person": "B-PER", "I-person": "I-PER", "B-location": "B-LOC", "I-location": "I-LOC", "B-other": "B-MISC", "I-other": "I-MISC", } ENTITY_TYPE_PROMPTS = { "PER": "What persons are mentioned?", "ORG": "What organizations are mentioned?", "LOC": "What locations are mentioned?", "MISC": "What miscellaneous entities (nationalities, events, etc.) are mentioned?", } def gold_text_spans_for(tokens, tag_ids): """Convert CoNLL integer tags to a set of (text, label) tuples.""" tags = [CONLL_ID2TAG[t] for t in tag_ids] spans = bio_to_spans(tags) return set((" ".join(tokens[s:e]), lbl) for s, e, lbl in spans) # SQuAD utilities def normalize_answer(s): def lower(text): return text.lower() def remove_punc(text): return "".join( ch for ch in text if ch not in set(string.punctuation)) def remove_articles(text): return re.sub(r"\b(a|an|the)\b", " ", text) def white_space_fix(text): return " ".join(text.split()) return white_space_fix(remove_articles(remove_punc(lower(s)))) def squad_exact_match(pred, golds): pred_norm = normalize_answer(pred) return any(pred_norm == normalize_answer(g) for g in golds) # Neural IE (token classification) def run_conll_neural_ie(model_name, limit, batch_size): ds = load_dataset("tner/conll2003", split="test") if limit > 0: ds = ds.select(range(min(limit, len(ds)))) device = "cuda" if torch.cuda.is_available() else "cpu" tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) model = AutoModelForTokenClassification.from_pretrained(model_name).to(device) model.eval() id2label = model.config.id2label all_pred, all_gold = [], [] for i in range(0, len(ds), batch_size): batch = ds[i:i + batch_size] tokens_batch = batch["tokens"] gold_tags_batch = batch["tags"] enc = tokenizer( tokens_batch, is_split_into_words=True, padding=True, truncation=True, return_tensors="pt", ).to(device) with torch.no_grad(): preds = torch.argmax(model(**enc).logits, dim=-1) for b in range(len(tokens_batch)): word_ids = enc.word_ids(batch_index=b) pred_tags, cur = [], None for idx, wid in enumerate(word_ids): if wid is None: continue if wid != cur: cur = wid raw = id2label[preds[b][idx].item()] pred_tags.append(MODEL_LABEL_MAP.get(raw, raw)) gold_tags = [CONLL_ID2TAG[t] for t in gold_tags_batch[b]] all_pred.append(bio_to_spans(pred_tags)) all_gold.append(bio_to_spans(gold_tags)) return micro_f1(all_pred, all_gold) # Cuckoo (next-token extraction) def cuckoo_extract(prompt, tokenizer, model, device): """Return list of extracted text spans from prompt.""" inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device) with torch.no_grad(): tags = model(**inputs).logits[0].argmax(-1).cpu().tolist() # 0 = begin-of-span, 1 = continuation seqs = [] i = 0 while i < len(tags): if tags[i] == 0: s = i e = i i += 1 while i < len(tags) and tags[i] == 1: e = i i += 1 seqs.append((s, e + 1)) else: i += 1 return [ tokenizer.decode(inputs.input_ids[0, s:e]).strip() for s, e in seqs ] def run_conll_cuckoo(model_name, limit): ds = load_dataset("tner/conll2003", split="test") if limit > 0: ds = ds.select(range(min(limit, len(ds)))) device = "cuda" if torch.cuda.is_available() else "cpu" tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForTokenClassification.from_pretrained(model_name).to(device) model.eval() all_pred, all_gold = [], [] for idx, ex in enumerate(ds): tokens = ex["tokens"] sentence = " ".join(tokens) gold = gold_text_spans_for(tokens, ex["tags"]) pred = set() for label, question in ENTITY_TYPE_PROMPTS.items(): prompt = f"User:\n\n{sentence}\n\nQuestion: {question}\n\nAssistant:" for ent in cuckoo_extract(prompt, tokenizer, model, device): if ent: pred.add((ent, label)) all_pred.append(pred) all_gold.append(gold) if (idx + 1) % 50 == 0: print(f" Cuckoo CoNLL: {idx+1}/{len(ds)}") return micro_f1(all_pred, all_gold) def run_squad_cuckoo(model_name, limit): ds = load_dataset("rajpurkar/squad", split="validation") if limit > 0: ds = ds.select(range(min(limit, len(ds)))) device = "cuda" if torch.cuda.is_available() else "cpu" tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForTokenClassification.from_pretrained(model_name).to(device) model.eval() correct = 0 for idx, ex in enumerate(ds): prompt = ( f"User:\n\n{ex['context']}\n\n" f"Question: {ex['question']}\n\nAssistant:" ) answers = cuckoo_extract(prompt, tokenizer, model, device) pred = answers[0] if answers else "" if squad_exact_match(pred, ex["answers"]["text"]): correct += 1 if (idx + 1) % 50 == 0: print(f" Cuckoo SQuAD: {idx+1}/{len(ds)}, acc so far={correct/(idx+1):.4f}") return {"accuracy": correct / len(ds), "n": len(ds), "correct": correct} # UIE (prompting a generative LM) def load_llm(model_name, device): tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.float16, trust_remote_code=True, ).to(device) model.eval() if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token return tokenizer, model def llm_generate(prompt, tokenizer, model, device, max_new=128): """Generate text from a prompt using the LLM.""" # Try chat template first, fall back to raw prompt try: messages = [{"role": "user", "content": prompt}] text = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True) except Exception: text = prompt inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=2048).to(device) with torch.no_grad(): out = model.generate( **inputs, max_new_tokens=max_new, do_sample=False, pad_token_id=tokenizer.pad_token_id, ) new_tokens = out[0][inputs.input_ids.shape[1]:] return tokenizer.decode(new_tokens, skip_special_tokens=True).strip() NER_PROMPT_TEMPLATE = """\ Extract all named entities from the text below. For each entity, output one line in the format: entity_text | TYPE TYPE must be exactly one of: PER, ORG, LOC, MISC If there are no entities, output: NONE Example: Text: EU rejects German call to boycott British lamb . Entities: EU | ORG German | MISC British | MISC Text: {sentence} Entities:""" def parse_ner_output(text): """Parse LLM NER output into a set of (entity_text, label) tuples.""" spans = set() for line in text.strip().split("\n"): line = line.strip().lstrip("- ").strip() if "|" not in line: continue parts = line.split("|") entity = parts[0].strip() label = parts[-1].strip().upper() if label in ("PER", "ORG", "LOC", "MISC") and entity and entity != "NONE": spans.add((entity, label)) return spans def run_conll_uie(model_name, limit): ds = load_dataset("tner/conll2003", split="test") if limit > 0: ds = ds.select(range(min(limit, len(ds)))) device = "cuda" if torch.cuda.is_available() else "cpu" tokenizer, model = load_llm(model_name, device) all_pred, all_gold = [], [] for idx, ex in enumerate(ds): tokens = ex["tokens"] sentence = " ".join(tokens) gold = gold_text_spans_for(tokens, ex["tags"]) prompt = NER_PROMPT_TEMPLATE.format(sentence=sentence) output = llm_generate(prompt, tokenizer, model, device, max_new=128) pred = parse_ner_output(output) all_pred.append(pred) all_gold.append(gold) if (idx + 1) % 50 == 0: print(f" UIE CoNLL: {idx+1}/{len(ds)}") return micro_f1(all_pred, all_gold) QA_PROMPT_TEMPLATE = """\ Answer the question using only a short span from the context. Output ONLY the answer span, nothing else. Context: {context} Question: {question} Answer:""" def run_squad_uie(model_name, limit): ds = load_dataset("rajpurkar/squad", split="validation") if limit > 0: ds = ds.select(range(min(limit, len(ds)))) device = "cuda" if torch.cuda.is_available() else "cpu" tokenizer, model = load_llm(model_name, device) correct = 0 for idx, ex in enumerate(ds): prompt = QA_PROMPT_TEMPLATE.format( context=ex["context"], question=ex["question"]) answer = llm_generate(prompt, tokenizer, model, device, max_new=32) # Take only first line answer = answer.split("\n")[0].strip() if squad_exact_match(answer, ex["answers"]["text"]): correct += 1 if (idx + 1) % 50 == 0: print(f" UIE SQuAD: {idx+1}/{len(ds)}, acc so far={correct/(idx+1):.4f}") return {"accuracy": correct / len(ds), "n": len(ds), "correct": correct} # ============================== # Main # ============================== def main(): parser = argparse.ArgumentParser() parser.add_argument("--method", choices=["neural_ie", "uie", "cuckoo"], required=True) parser.add_argument("--task", choices=["conll", "squad"], required=True) parser.add_argument("--model", required=True) parser.add_argument("--limit", type=int, default=200) parser.add_argument("--batch_size", type=int, default=32) args = parser.parse_args() tag = f"[{args.method.upper()} - {args.task.upper()}]" if args.method == "neural_ie": if args.task == "squad": print(f"{tag} Skipped (Neural IE does not support SQuAD)") return result = run_conll_neural_ie(args.model, args.limit, args.batch_size) elif args.method == "cuckoo": if args.task == "conll": result = run_conll_cuckoo(args.model, args.limit) else: result = run_squad_cuckoo(args.model, args.limit) elif args.method == "uie": if args.task == "conll": result = run_conll_uie(args.model, args.limit) else: result = run_squad_uie(args.model, args.limit) print(f"{tag} {result}") if __name__ == "__main__": main()