| |
|
|
| import argparse |
| import re |
| import string |
| from typing import List, Tuple |
|
|
| import torch |
| from datasets import load_dataset |
| from transformers import ( |
| AutoTokenizer, |
| AutoModelForTokenClassification, |
| AutoModelForCausalLM, |
| ) |
|
|
|
|
|
|
| |
|
|
|
|
def bio_to_spans(tags: List[str]) -> List[Tuple[int, int, str]]:
    """Convert a BIO tag sequence into ``(start, end, label)`` spans.

    ``end`` is exclusive.

    - A span is opened by ``B-X``, or by ``I-X`` when no span labelled ``X``
      is currently open.
    - A span is extended only by ``I-X`` carrying the same label.
    - Any other tag closes the open span: ``O``, any ``B-*``, an ``I-Y`` with
      a different label (which then opens its own span), or an unrecognized
      tag.

    Args:
        tags: BIO tags such as ``"O"``, ``"B-PER"``, ``"I-PER"``.

    Returns:
        The spans, ordered by start index.
    """
    spans: List[Tuple[int, int, str]] = []
    start = None
    label = None

    for i, tag in enumerate(tags):
        is_inside = tag.startswith("I-")
        if is_inside and start is not None and tag[2:] == label:
            # Continuation of the currently open entity.
            continue

        # Every other tag terminates the open span (if any) ...
        if start is not None:
            spans.append((start, i, label))
            start = label = None

        # ... and B-X, or an I-X without a matching predecessor, opens a new
        # one. An I-X after a different label starts a new chunk (conlleval
        # convention) instead of silently extending the previous entity.
        if is_inside or tag.startswith("B-"):
            start, label = i, tag[2:]

    if start is not None:
        spans.append((start, len(tags), label))
    return spans
|
|
|
|
def micro_f1(pred_spans_list, gold_spans_list):
    """Compute micro-averaged precision, recall and F1 over span collections.

    Each element of the two lists holds the spans of one example. Spans can be
    any hashable items, e.g. ``(start, end, label)`` or ``(text, label)``.
    Spans are compared as sets per example, so duplicates within one example
    count once. True/false positive and false negative counts are summed over
    all examples before the ratios are computed.

    Args:
        pred_spans_list: Predicted spans, one collection per example.
        gold_spans_list: Gold spans, aligned index-by-index with the
            predictions.

    Returns:
        Dict with float ``precision``, ``recall`` and ``f1`` (0.0 when the
        ratio is undefined) and integer ``tp``, ``fp`` and ``fn`` counts.

    Raises:
        ValueError: If the two lists have different lengths. Without this
            check, ``zip`` would silently drop the unmatched tail.
    """
    if len(pred_spans_list) != len(gold_spans_list):
        raise ValueError(
            f"prediction/gold length mismatch: {len(pred_spans_list)} vs "
            f"{len(gold_spans_list)}"
        )

    tp = fp = fn = 0
    for pred, gold in zip(pred_spans_list, gold_spans_list):
        pred_set, gold_set = set(pred), set(gold)
        tp += len(pred_set & gold_set)
        fp += len(pred_set - gold_set)
        fn += len(gold_set - pred_set)

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "tp": tp,
        "fp": fp,
        "fn": fn,
    }
|
|
|
|
|
|
| |
|
|
|
|
# Integer tag ids (as found in the "tags" column of the tner/conll2003
# dataset used below) mapped to BIO tag strings; the id is the list index.
CONLL_ID2TAG = dict(enumerate([
    "O", "B-ORG", "B-MISC", "B-PER", "I-PER",
    "B-LOC", "I-ORG", "I-MISC", "I-LOC",
]))


# Maps a token-classification model's BIO labels that spell out the entity
# type (e.g. "B-organization") onto the CoNLL short forms. Labels missing
# from this map are passed through unchanged by run_conll_neural_ie.
MODEL_LABEL_MAP = {
    "O": "O",
    **{
        f"{prefix}-{long_name}": f"{prefix}-{short_name}"
        for long_name, short_name in (
            ("organization", "ORG"),
            ("person", "PER"),
            ("location", "LOC"),
            ("other", "MISC"),
        )
        for prefix in ("B", "I")
    },
}


# One extraction question per CoNLL entity type, used to build Cuckoo
# prompts; the key becomes the label of every span extracted for it.
ENTITY_TYPE_PROMPTS = dict(
    PER="What persons are mentioned?",
    ORG="What organizations are mentioned?",
    LOC="What locations are mentioned?",
    MISC="What miscellaneous entities (nationalities, events, etc.) are mentioned?",
)
|
|
|
|
def gold_text_spans_for(tokens, tag_ids):
    """Return the gold entities of one CoNLL sentence as ``(text, label)`` pairs.

    ``tag_ids`` are integer CoNLL tag ids, decoded through ``CONLL_ID2TAG``.
    Each BIO span is rendered as its tokens joined by single spaces, so the
    result can be compared against text-level predictions.
    """
    bio_tags = [CONLL_ID2TAG[tag_id] for tag_id in tag_ids]
    entities = set()
    for begin, end, entity_type in bio_to_spans(bio_tags):
        entities.add((" ".join(tokens[begin:end]), entity_type))
    return entities
|
|
|
|
|
|
| |
|
|
|
|
# Precomputed once: deletion table for ASCII punctuation and the article regex.
_PUNCT_TABLE = str.maketrans("", "", string.punctuation)
_ARTICLES_RE = re.compile(r"\b(a|an|the)\b")


def normalize_answer(s):
    """Normalize an answer string for SQuAD-style exact-match comparison.

    Steps, in order:

    1. Lower-case the text.
    2. Delete characters in ``string.punctuation``.
    3. Replace the whole words "a", "an" and "the" with a space.
    4. Collapse all whitespace runs to single spaces and strip the ends.

    This mirrors the normalization in the official SQuAD v1.1 evaluation
    script.
    """
    text = s.lower()
    # One C-level pass instead of rebuilding set(string.punctuation) per char.
    text = text.translate(_PUNCT_TABLE)
    text = _ARTICLES_RE.sub(" ", text)
    return " ".join(text.split())
|
|
|
|
def squad_exact_match(pred, golds):
    """Return True if *pred* equals any gold answer after normalize_answer.

    Returns False when *golds* is empty.
    """
    target = normalize_answer(pred)
    for gold in golds:
        if normalize_answer(gold) == target:
            return True
    return False
|
|
|
|
|
|
| |
|
|
|
|
def run_conll_neural_ie(model_name, limit, batch_size):
    """Evaluate a token-classification model on the CoNLL-2003 test split.

    Each word's tag is the prediction for its first sub-token. Tags are
    mapped to CoNLL names through MODEL_LABEL_MAP, unknown labels pass
    through unchanged, and the tags are compared to gold as
    ``(start, end, label)`` spans.

    Args:
        model_name: Hugging Face model id or path of the token classifier.
        limit: Evaluate only the first ``limit`` sentences when > 0;
            otherwise the whole split.
        batch_size: Number of sentences encoded per forward pass. Must be
            positive.

    Returns:
        The micro_f1 result dict.

    Raises:
        ValueError: If ``batch_size`` is not positive. Otherwise a negative
            step would silently skip every batch.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    ds = load_dataset("tner/conll2003", split="test")
    if limit > 0:
        ds = ds.select(range(min(limit, len(ds))))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # A fast tokenizer is required for BatchEncoding.word_ids().
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModelForTokenClassification.from_pretrained(model_name).to(device)
    model.eval()
    id2label = model.config.id2label

    all_pred, all_gold = [], []

    for i in range(0, len(ds), batch_size):
        batch = ds[i:i + batch_size]
        tokens_batch = batch["tokens"]
        gold_tags_batch = batch["tags"]

        enc = tokenizer(
            tokens_batch, is_split_into_words=True,
            padding=True, truncation=True, return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            # Copy predictions to host once per batch. Calling .item() per
            # token would force a device synchronization for every token.
            pred_ids = model(**enc).logits.argmax(dim=-1).cpu().tolist()

        for b, (row_ids, gold_ids) in enumerate(zip(pred_ids, gold_tags_batch)):
            pred_tags = []
            prev_wid = None
            for wid, label_id in zip(enc.word_ids(batch_index=b), row_ids):
                # Skip special/padding positions (wid is None) and
                # non-initial sub-tokens of the same word.
                if wid is None or wid == prev_wid:
                    continue
                prev_wid = wid
                raw = id2label[label_id]
                pred_tags.append(MODEL_LABEL_MAP.get(raw, raw))

            gold_tags = [CONLL_ID2TAG[t] for t in gold_ids]
            all_pred.append(bio_to_spans(pred_tags))
            all_gold.append(bio_to_spans(gold_tags))

    return micro_f1(all_pred, all_gold)
|
|
|
|
|
|
| |
|
|
|
|
def cuckoo_extract(prompt, tokenizer, model, device):
    """Return the text spans the tagger marks in *prompt*.

    Steps:

    1. Tokenize the prompt, truncated to 512 tokens.
    2. Tag every token with its argmax label.
    3. Decode each run of one token tagged 0 followed by zero or more tokens
       tagged 1, and strip the result.

    A token tagged 1 without such a predecessor is ignored. Tags 0 and 1
    presumably correspond to B and I of the model's tagging scheme; confirm
    against the model card.

    NOTE(review): with the tokenizer's default (right-side) truncation, very
    long prompts lose their tail, i.e. the trailing "Question: ..." and
    "Assistant:" cue.
    """
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True,
                        max_length=512).to(device)
    with torch.no_grad():
        token_tags = model(**encoded).logits[0].argmax(-1).cpu().tolist()

    # Single left-to-right pass collecting [begin, end) token ranges.
    ranges = []
    begin = None
    for pos, tag in enumerate(token_tags):
        if tag == 0:
            if begin is not None:
                ranges.append((begin, pos))
            begin = pos
        elif tag == 1:
            # Extends an open span; a stray 1 with nothing open is ignored.
            pass
        elif begin is not None:
            ranges.append((begin, pos))
            begin = None
    if begin is not None:
        ranges.append((begin, len(token_tags)))

    input_ids = encoded.input_ids
    return [tokenizer.decode(input_ids[0, lo:hi]).strip() for lo, hi in ranges]
|
|
|
|
def run_conll_cuckoo(model_name, limit):
    """Evaluate a Cuckoo tagger on CoNLL-2003 NER with text-level micro-F1.

    For every test sentence, one prompt per entry of ENTITY_TYPE_PROMPTS is
    issued. Every non-empty extracted span becomes a ``(text, type)``
    prediction, compared against gold_text_spans_for. ``limit > 0`` caps the
    number of sentences. Progress is printed every 50 sentences.
    """
    dataset = load_dataset("tner/conll2003", split="test")
    if limit > 0:
        dataset = dataset.select(range(min(limit, len(dataset))))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTokenClassification.from_pretrained(model_name).to(device)
    model.eval()

    predictions, references = [], []

    for done, example in enumerate(dataset, start=1):
        words = example["tokens"]
        sentence = " ".join(words)
        gold_entities = gold_text_spans_for(words, example["tags"])

        found = set()
        for entity_type, question in ENTITY_TYPE_PROMPTS.items():
            prompt = f"User:\n\n{sentence}\n\nQuestion: {question}\n\nAssistant:"
            extracted = cuckoo_extract(prompt, tokenizer, model, device)
            found.update((span, entity_type) for span in extracted if span)

        predictions.append(found)
        references.append(gold_entities)

        if done % 50 == 0:
            print(f" Cuckoo CoNLL: {done}/{len(dataset)}")

    return micro_f1(predictions, references)
|
|
|
|
def run_squad_cuckoo(model_name, limit):
    """Evaluate a Cuckoo tagger on SQuAD v1 validation by exact match.

    The first extracted span, or "" when nothing is extracted, is the
    prediction. ``limit > 0`` caps the number of questions. Progress and
    running accuracy are printed every 50 questions.

    Returns:
        ``{"accuracy": float, "n": int, "correct": int}``.
    """
    dataset = load_dataset("rajpurkar/squad", split="validation")
    if limit > 0:
        dataset = dataset.select(range(min(limit, len(dataset))))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTokenClassification.from_pretrained(model_name).to(device)
    model.eval()

    total = len(dataset)
    correct = 0
    for seen, example in enumerate(dataset, start=1):
        prompt = (
            f"User:\n\n{example['context']}\n\n"
            f"Question: {example['question']}\n\nAssistant:"
        )
        spans = cuckoo_extract(prompt, tokenizer, model, device)
        prediction = spans[0] if spans else ""
        correct += int(squad_exact_match(prediction, example["answers"]["text"]))

        if seen % 50 == 0:
            print(f" Cuckoo SQuAD: {seen}/{total}, acc so far={correct / seen:.4f}")

    return {"accuracy": correct / total, "n": total, "correct": correct}
|
|
|
|
|
|
| |
|
|
|
|
def load_llm(model_name, device):
    """Load a causal LM and its tokenizer onto *device* in eval mode.

    The model is loaded in float16 on CUDA devices and in float32 elsewhere.
    If the tokenizer has no pad token, its EOS token is reused as the pad
    token.

    Returns:
        ``(tokenizer, model)``.
    """
    # Half precision only pays off on GPU. On CPU it is very slow, and on
    # older PyTorch releases some CPU kernels do not implement it at all.
    on_cuda = str(device).startswith("cuda")
    dtype = torch.float16 if on_cuda else torch.float32

    # NOTE(review): trust_remote_code=True executes Python code shipped with
    # the model repository; only pass trusted model names.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, trust_remote_code=True,
    ).to(device)
    model.eval()
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer, model
|
|
|
|
def llm_generate(prompt, tokenizer, model, device, max_new=128):
    """Greedily generate a completion for *prompt* and return only the new text.

    The prompt is wrapped as a single user turn with the tokenizer's chat
    template when one can be applied; otherwise the raw prompt is used.

    Args:
        prompt: User message text.
        tokenizer: Tokenizer matching *model*.
        model: Causal LM exposing ``generate``.
        device: Device the encoded inputs are moved to.
        max_new: Maximum number of new tokens to generate.

    Returns:
        The decoded continuation with special tokens removed and whitespace
        stripped.
    """
    messages = [{"role": "user", "content": prompt}]
    try:
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
    except Exception:
        # Deliberate best-effort fallback: tokenizers without a usable chat
        # template get the plain prompt.
        text = prompt
        add_special = True
    else:
        # The rendered template already contains the model's special tokens
        # (e.g. BOS). Adding them again while tokenizing would duplicate them.
        add_special = False

    # NOTE(review): right-side truncation at 2048 tokens would cut off the
    # generation prompt for extremely long inputs.
    inputs = tokenizer(text, return_tensors="pt", truncation=True,
                       max_length=2048,
                       add_special_tokens=add_special).to(device)
    with torch.no_grad():
        out = model.generate(
            **inputs, max_new_tokens=max_new, do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )
    # Drop the echoed prompt tokens; keep only what was generated.
    new_tokens = out[0][inputs.input_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
|
|
|
|
# One-shot NER prompt for the prompted-LLM (UIE) method. run_conll_uie fills
# the single {sentence} placeholder with str.format, so any literal braces
# added to this text must be doubled. The requested "entity_text | TYPE"
# line format is what parse_ner_output parses.
NER_PROMPT_TEMPLATE = """\
Extract all named entities from the text below.
For each entity, output one line in the format: entity_text | TYPE
TYPE must be exactly one of: PER, ORG, LOC, MISC
If there are no entities, output: NONE

Example:
Text: EU rejects German call to boycott British lamb .
Entities:
EU | ORG
German | MISC
British | MISC

Text: {sentence}
Entities:"""
|
|
|
|
def parse_ner_output(text):
    """Parse LLM NER output into a set of ``(entity_text, label)`` tuples.

    Each line is expected to look like ``entity_text | TYPE``, optionally
    prefixed by ``-`` bullets and spaces. The entity is the text before the
    first ``|``; the type is the text after the last ``|``, upper-cased.

    A line is ignored when it:
    - contains no ``|``,
    - has an empty entity or the entity ``NONE``, or
    - has a type other than PER/ORG/LOC/MISC.
    """
    valid_types = ("PER", "ORG", "LOC", "MISC")
    entities = set()
    for raw_line in text.strip().split("\n"):
        candidate = raw_line.strip().lstrip("- ").strip()
        if "|" not in candidate:
            continue
        entity = candidate.partition("|")[0].strip()
        entity_type = candidate.rpartition("|")[2].strip().upper()
        if not entity or entity == "NONE" or entity_type not in valid_types:
            continue
        entities.add((entity, entity_type))
    return entities
|
|
|
|
def run_conll_uie(model_name, limit):
    """Evaluate a prompted causal LLM on CoNLL-2003 NER with text-level micro-F1.

    Each test sentence (tokens joined by spaces) is inserted into
    NER_PROMPT_TEMPLATE. The completion (at most 128 new tokens) is parsed by
    parse_ner_output and compared with gold_text_spans_for. ``limit > 0``
    caps the number of sentences. Progress is printed every 50 sentences.
    """
    dataset = load_dataset("tner/conll2003", split="test")
    if limit > 0:
        dataset = dataset.select(range(min(limit, len(dataset))))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer, model = load_llm(model_name, device)

    predictions, references = [], []

    for done, example in enumerate(dataset, start=1):
        words = example["tokens"]
        gold_entities = gold_text_spans_for(words, example["tags"])

        completion = llm_generate(
            NER_PROMPT_TEMPLATE.format(sentence=" ".join(words)),
            tokenizer, model, device, max_new=128,
        )

        predictions.append(parse_ner_output(completion))
        references.append(gold_entities)

        if done % 50 == 0:
            print(f" UIE CoNLL: {done}/{len(dataset)}")

    return micro_f1(predictions, references)
|
|
|
|
# Extractive-QA prompt for the prompted-LLM (UIE) method. run_squad_uie fills
# it via str.format(context=..., question=...), so literal braces must be
# doubled, and keeps only the first line of the completion as the answer.
QA_PROMPT_TEMPLATE = """\
Answer the question using only a short span from the context. Output ONLY the answer span, nothing else.

Context: {context}
Question: {question}
Answer:"""
|
|
|
|
def run_squad_uie(model_name, limit):
    """Evaluate a prompted causal LLM on SQuAD v1 validation by exact match.

    The first line of the completion (at most 32 new tokens), stripped, is
    the prediction. ``limit > 0`` caps the number of questions. Progress and
    running accuracy are printed every 50 questions.

    Returns:
        ``{"accuracy": float, "n": int, "correct": int}``.
    """
    dataset = load_dataset("rajpurkar/squad", split="validation")
    if limit > 0:
        dataset = dataset.select(range(min(limit, len(dataset))))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer, model = load_llm(model_name, device)

    total = len(dataset)
    correct = 0
    for seen, example in enumerate(dataset, start=1):
        prompt = QA_PROMPT_TEMPLATE.format(
            context=example["context"], question=example["question"])
        completion = llm_generate(prompt, tokenizer, model, device, max_new=32)
        # The model may keep talking after the span; only the first line counts.
        first_line = completion.split("\n")[0].strip()

        if squad_exact_match(first_line, example["answers"]["text"]):
            correct += 1

        if seen % 50 == 0:
            print(f" UIE SQuAD: {seen}/{total}, acc so far={correct / seen:.4f}")

    return {"accuracy": correct / total, "n": total, "correct": correct}
|
|
|
|
| |
| |
| |
|
|
def main():
    """Command-line entry point: run one (method, task) evaluation and print it.

    ``--limit`` caps the number of evaluated examples; values <= 0 mean the
    whole split. ``--batch_size`` is used only by the neural_ie method. Neural
    IE has no SQuAD runner, so that combination is reported as skipped.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--method", choices=["neural_ie", "uie", "cuckoo"], required=True)
    parser.add_argument("--task", choices=["conll", "squad"], required=True)
    parser.add_argument("--model", required=True)
    parser.add_argument("--limit", type=int, default=200)
    parser.add_argument("--batch_size", type=int, default=32)
    args = parser.parse_args()

    tag = f"[{args.method.upper()} - {args.task.upper()}]"

    if (args.method, args.task) == ("neural_ie", "squad"):
        print(f"{tag} Skipped (Neural IE does not support SQuAD)")
        return

    # Dispatch on the (method, task) pair. The argparse choices guarantee
    # that every reachable pair has an entry.
    runners = {
        ("neural_ie", "conll"):
            lambda: run_conll_neural_ie(args.model, args.limit, args.batch_size),
        ("cuckoo", "conll"): lambda: run_conll_cuckoo(args.model, args.limit),
        ("cuckoo", "squad"): lambda: run_squad_cuckoo(args.model, args.limit),
        ("uie", "conll"): lambda: run_conll_uie(args.model, args.limit),
        ("uie", "squad"): lambda: run_squad_uie(args.model, args.limit),
    }
    result = runners[(args.method, args.task)]()

    print(f"{tag} {result}")


if __name__ == "__main__":
    main()
|
|