CSE_261 / hw2_run.py
nancyH's picture
Upload folder using huggingface_hub
04594d9 verified
#!/usr/bin/env python3
import argparse
import re
import string
from typing import List, Tuple
import torch
from datasets import load_dataset
from transformers import (
AutoTokenizer,
AutoModelForTokenClassification,
AutoModelForCausalLM,
)
# BIO utilities
def bio_to_spans(tags: List[str]) -> List[Tuple[int, int, str]]:
    """Convert a BIO tag sequence into entity spans.

    Args:
        tags: Per-token tags such as ``"O"``, ``"B-PER"`` or ``"I-PER"``.

    Returns:
        ``(start, end, label)`` tuples in order of appearance; ``start`` is
        inclusive and ``end`` is exclusive.

    An ``I-`` tag with no open span starts a new span. An ``I-`` tag whose
    type differs from the open span closes that span and starts a new one,
    so ``["B-PER", "I-ORG"]`` yields two spans instead of one PER span that
    swallows the ORG token.
    """
    spans: List[Tuple[int, int, str]] = []
    start = None
    label = None

    def close(end):
        # Flush the currently open span (if any) and reset the state.
        nonlocal start, label
        if start is not None:
            spans.append((start, end, label))
        start = None
        label = None

    for i, tag in enumerate(tags):
        if tag.startswith("B-"):
            close(i)
            start, label = i, tag[2:]
        elif tag.startswith("I-"):
            # Continue only when a span of the same type is open.
            if start is None or tag[2:] != label:
                close(i)
                start, label = i, tag[2:]
        else:
            # "O" and any unrecognised tag both terminate the open span.
            close(i)
    close(len(tags))
    return spans
def micro_f1(pred_spans_list, gold_spans_list):
    """Compute micro-averaged precision, recall and F1 over paired span sets.

    Each element of the two lists is a collection of hashable spans for one
    example; the lists are walked in lockstep. A predicted span counts as a
    true positive only if an identical span is in the gold collection.

    Returns:
        A dict with ``precision``, ``recall``, ``f1`` and the raw ``tp``,
        ``fp`` and ``fn`` counts. A ratio whose denominator is zero is
        reported as 0.
    """
    tp = fp = fn = 0
    for predicted, reference in zip(pred_spans_list, gold_spans_list):
        predicted, reference = set(predicted), set(reference)
        hits = len(predicted & reference)
        tp += hits
        fp += len(predicted) - hits
        fn += len(reference) - hits

    precision = 0 if tp + fp == 0 else tp / (tp + fp)
    recall = 0 if tp + fn == 0 else tp / (tp + fn)
    if precision + recall > 0:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "tp": tp,
        "fp": fp,
        "fn": fn,
    }
# CoNLL constants
# Tag strings listed in integer-id order; position in the tuple is the id.
CONLL_ID2TAG = dict(enumerate((
    "O",
    "B-ORG",
    "B-MISC",
    "B-PER",
    "I-PER",
    "B-LOC",
    "I-ORG",
    "I-MISC",
    "I-LOC",
)))

# tner model uses verbose label names; map them onto the short CoNLL tags.
MODEL_LABEL_MAP = {
    "O": "O",
    "B-organization": "B-ORG",
    "I-organization": "I-ORG",
    "B-person": "B-PER",
    "I-person": "I-PER",
    "B-location": "B-LOC",
    "I-location": "I-LOC",
    "B-other": "B-MISC",
    "I-other": "I-MISC",
}

# One natural-language question per entity type, keyed by the short label.
ENTITY_TYPE_PROMPTS = {
    "PER": "What persons are mentioned?",
    "ORG": "What organizations are mentioned?",
    "LOC": "What locations are mentioned?",
    "MISC": "What miscellaneous entities (nationalities, events, etc.) are mentioned?",
}
def gold_text_spans_for(tokens, tag_ids):
    """Build the gold entity set of one sentence as (surface text, label) pairs.

    ``tag_ids`` are integer tags looked up in ``CONLL_ID2TAG``; the tokens of
    each resulting span are joined with single spaces to form its text.
    """
    bio_tags = list(map(CONLL_ID2TAG.__getitem__, tag_ids))
    return {
        (" ".join(tokens[begin:end]), entity_type)
        for begin, end, entity_type in bio_to_spans(bio_tags)
    }
# SQuAD utilities
def normalize_answer(s):
    """Normalize an answer string for SQuAD-style exact-match comparison.

    Applied in order: lowercase, delete every character found in
    ``string.punctuation``, replace the whole-word articles ``a``/``an``/
    ``the`` with a space, then collapse runs of whitespace.

    Args:
        s: Raw answer text.

    Returns:
        The normalized string (possibly empty).
    """
    text = s.lower()
    # Single pass deletion; avoids rebuilding the punctuation set per char.
    text = text.translate(str.maketrans("", "", string.punctuation))
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())
def squad_exact_match(pred, golds):
    """Return True if ``pred`` equals any gold answer after normalization."""
    target = normalize_answer(pred)
    for gold in golds:
        if normalize_answer(gold) == target:
            return True
    return False
# Neural IE (token classification)
def run_conll_neural_ie(model_name, limit, batch_size):
    """Evaluate a token-classification model on the tner/conll2003 test split.

    Args:
        model_name: Checkpoint passed to the tokenizer and model loaders.
        limit: If greater than 0, evaluate only the first ``limit`` examples.
        batch_size: Number of sentences tokenized and scored per forward pass.

    Returns:
        The ``micro_f1`` result dict over predicted vs. gold spans.
    """
    dataset = load_dataset("tner/conll2003", split="test")
    if limit > 0:
        dataset = dataset.select(range(min(limit, len(dataset))))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModelForTokenClassification.from_pretrained(model_name).to(device)
    model.eval()
    id2label = model.config.id2label

    predicted_spans, gold_spans = [], []
    for offset in range(0, len(dataset), batch_size):
        batch = dataset[offset:offset + batch_size]
        sentences, gold_ids = batch["tokens"], batch["tags"]
        encoded = tokenizer(
            sentences,
            is_split_into_words=True,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            label_ids = torch.argmax(model(**encoded).logits, dim=-1)

        for row in range(len(sentences)):
            tags = []
            previous_word = None
            for position, word_id in enumerate(encoded.word_ids(batch_index=row)):
                # Skip positions without a word id, and every sub-token after
                # the first of a word: one tag per word.
                if word_id is None or word_id == previous_word:
                    continue
                previous_word = word_id
                raw = id2label[label_ids[row][position].item()]
                tags.append(MODEL_LABEL_MAP.get(raw, raw))

            predicted_spans.append(bio_to_spans(tags))
            gold_spans.append(
                bio_to_spans([CONLL_ID2TAG[t] for t in gold_ids[row]])
            )

    return micro_f1(predicted_spans, gold_spans)
# Cuckoo (next-token extraction)
def cuckoo_extract(prompt, tokenizer, model, device):
    """Return the list of text spans the tagger extracts from ``prompt``.

    The prompt is tokenized (truncated to 512 tokens), the model's per-token
    argmax tag is taken, and each run "one tag 0 followed by any number of
    tag 1" is decoded back to text and stripped.
    """
    inputs = tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=512
    ).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    tags = logits[0].argmax(-1).cpu().tolist()

    # 0 = begin-of-span, 1 = continuation; anything else ends an open span.
    boundaries = []
    begin = None
    for position, tag in enumerate(tags):
        if tag == 1:
            # Extends the open span, or is ignored when none is open.
            continue
        if begin is not None:
            boundaries.append((begin, position))
        begin = position if tag == 0 else None
    if begin is not None:
        boundaries.append((begin, len(tags)))

    extracted = []
    for first, stop in boundaries:
        extracted.append(tokenizer.decode(inputs.input_ids[0, first:stop]).strip())
    return extracted
def run_conll_cuckoo(model_name, limit):
    """Evaluate prompt-based extraction on the tner/conll2003 test split.

    For every sentence, one prompt per entry of ``ENTITY_TYPE_PROMPTS`` is
    sent through ``cuckoo_extract``; each non-empty extraction is recorded as
    ``(text, label)`` and compared with the gold ``(text, label)`` set.

    Args:
        model_name: Checkpoint passed to the tokenizer and model loaders.
        limit: If greater than 0, evaluate only the first ``limit`` examples.

    Returns:
        The ``micro_f1`` result dict.
    """
    dataset = load_dataset("tner/conll2003", split="test")
    if limit > 0:
        dataset = dataset.select(range(min(limit, len(dataset))))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTokenClassification.from_pretrained(model_name).to(device)
    model.eval()

    predictions, references = [], []
    total = len(dataset)
    for done, example in enumerate(dataset, start=1):
        words = example["tokens"]
        sentence = " ".join(words)
        references.append(gold_text_spans_for(words, example["tags"]))

        found = set()
        for label, question in ENTITY_TYPE_PROMPTS.items():
            prompt = f"User:\n\n{sentence}\n\nQuestion: {question}\n\nAssistant:"
            extracted = cuckoo_extract(prompt, tokenizer, model, device)
            found.update((entity, label) for entity in extracted if entity)
        predictions.append(found)

        if done % 50 == 0:
            print(f" Cuckoo CoNLL: {done}/{total}")

    return micro_f1(predictions, references)
def run_squad_cuckoo(model_name, limit):
    """Evaluate prompt-based extraction on the SQuAD validation split.

    The first span returned by ``cuckoo_extract`` (or the empty string when
    nothing is extracted) is scored with ``squad_exact_match``.

    Args:
        model_name: Checkpoint passed to the tokenizer and model loaders.
        limit: If greater than 0, evaluate only the first ``limit`` examples.

    Returns:
        A dict with ``accuracy``, ``n`` (examples evaluated) and ``correct``.
    """
    dataset = load_dataset("rajpurkar/squad", split="validation")
    if limit > 0:
        dataset = dataset.select(range(min(limit, len(dataset))))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTokenClassification.from_pretrained(model_name).to(device)
    model.eval()

    total = len(dataset)
    correct = 0
    for done, example in enumerate(dataset, start=1):
        prompt = (
            f"User:\n\n{example['context']}\n\n"
            f"Question: {example['question']}\n\nAssistant:"
        )
        candidates = cuckoo_extract(prompt, tokenizer, model, device)
        best = candidates[0] if candidates else ""
        correct += int(squad_exact_match(best, example["answers"]["text"]))

        if done % 50 == 0:
            print(f" Cuckoo SQuAD: {done}/{total}, acc so far={correct/done:.4f}")

    return {"accuracy": correct / total, "n": total, "correct": correct}
# UIE (prompting a generative LM)
def load_llm(model_name, device):
    """Load a causal LM and its tokenizer, ready for inference.

    Args:
        model_name: Checkpoint passed to the tokenizer and model loaders.
        device: Target device (e.g. ``"cuda"`` or ``"cpu"``).

    Returns:
        ``(tokenizer, model)`` with the model moved to ``device`` and put in
        eval mode. If the tokenizer has no pad token, its EOS token is reused.
    """
    # NOTE(security): trust_remote_code=True executes Python shipped in the
    # model repository; only use checkpoints from sources you trust.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Half precision only on CUDA: many CPU builds of PyTorch lack fp16
    # kernels (or run them very slowly), so fall back to float32 there.
    dtype = torch.float16 if str(device).startswith("cuda") else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, trust_remote_code=True,
    ).to(device)
    model.eval()

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer, model
def llm_generate(prompt, tokenizer, model, device, max_new=128):
    """Greedy-decode a completion for ``prompt`` and return only the new text.

    Args:
        prompt: User message content.
        tokenizer: Tokenizer matching ``model``.
        model: Causal LM exposing ``generate``.
        device: Device the encoded inputs are moved to.
        max_new: Maximum number of newly generated tokens.

    Returns:
        The decoded continuation (prompt tokens removed, special tokens
        skipped), stripped of surrounding whitespace.
    """
    # Try chat template first, fall back to raw prompt. The broad catch is a
    # deliberate best-effort: any templating failure means "use plain text".
    try:
        text = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )
        templated = True
    except Exception:
        text = prompt
        templated = False

    # A rendered chat template already contains its special tokens; adding
    # them again at tokenization time would duplicate e.g. the BOS token.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=2048,
        add_special_tokens=not templated,
    ).to(device)

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Drop the echoed prompt: keep only tokens past the input length.
    new_tokens = out[0][inputs.input_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# One-shot NER prompt. ``{sentence}`` is the only placeholder and is filled
# with ``str.format``. The requested line format ``entity_text | TYPE`` is the
# format that ``parse_ner_output`` parses.
NER_PROMPT_TEMPLATE = """\
Extract all named entities from the text below.
For each entity, output one line in the format: entity_text | TYPE
TYPE must be exactly one of: PER, ORG, LOC, MISC
If there are no entities, output: NONE
Example:
Text: EU rejects German call to boycott British lamb .
Entities:
EU | ORG
German | MISC
British | MISC
Text: {sentence}
Entities:"""
def parse_ner_output(text):
    """Turn ``entity | TYPE`` lines of LLM output into a set of (entity, label).

    Lines without a ``|`` are ignored. Leading dashes/spaces are dropped, the
    entity is the text before the first ``|`` and the label is the upper-cased
    text after the last ``|``. Only PER/ORG/LOC/MISC labels are kept, and an
    empty entity or the literal ``NONE`` is discarded.
    """
    valid_labels = ("PER", "ORG", "LOC", "MISC")
    entities = set()
    for raw_line in text.strip().split("\n"):
        candidate = raw_line.strip().lstrip("- ").strip()
        if "|" not in candidate:
            continue
        entity = candidate.partition("|")[0].strip()
        label = candidate.rpartition("|")[2].strip().upper()
        if not entity or entity == "NONE":
            continue
        if label in valid_labels:
            entities.add((entity, label))
    return entities
def run_conll_uie(model_name, limit):
    """Evaluate LLM-prompted NER on the tner/conll2003 test split.

    Each sentence is inserted into ``NER_PROMPT_TEMPLATE``, the generated
    text is parsed with ``parse_ner_output`` and compared with the gold
    ``(text, label)`` set.

    Args:
        model_name: Checkpoint passed to ``load_llm``.
        limit: If greater than 0, evaluate only the first ``limit`` examples.

    Returns:
        The ``micro_f1`` result dict.
    """
    dataset = load_dataset("tner/conll2003", split="test")
    if limit > 0:
        dataset = dataset.select(range(min(limit, len(dataset))))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer, model = load_llm(model_name, device)

    predictions, references = [], []
    total = len(dataset)
    for done, example in enumerate(dataset, start=1):
        words = example["tokens"]
        references.append(gold_text_spans_for(words, example["tags"]))

        prompt = NER_PROMPT_TEMPLATE.format(sentence=" ".join(words))
        generated = llm_generate(prompt, tokenizer, model, device, max_new=128)
        predictions.append(parse_ner_output(generated))

        if done % 50 == 0:
            print(f" UIE CoNLL: {done}/{total}")

    return micro_f1(predictions, references)
# Extractive-QA prompt. ``{context}`` and ``{question}`` are filled with
# ``str.format``; the model is asked to reply with the answer span only.
QA_PROMPT_TEMPLATE = """\
Answer the question using only a short span from the context. Output ONLY the answer span, nothing else.
Context: {context}
Question: {question}
Answer:"""
def run_squad_uie(model_name, limit):
    """Evaluate LLM-prompted question answering on the SQuAD validation split.

    The first line of the generated text is taken as the answer and scored
    with ``squad_exact_match`` against the example's gold answer texts.

    Args:
        model_name: Checkpoint passed to ``load_llm``.
        limit: If greater than 0, evaluate only the first ``limit`` examples.

    Returns:
        A dict with ``accuracy``, ``n`` (examples evaluated) and ``correct``.
    """
    dataset = load_dataset("rajpurkar/squad", split="validation")
    if limit > 0:
        dataset = dataset.select(range(min(limit, len(dataset))))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer, model = load_llm(model_name, device)

    total = len(dataset)
    correct = 0
    for done, example in enumerate(dataset, start=1):
        prompt = QA_PROMPT_TEMPLATE.format(
            context=example["context"],
            question=example["question"],
        )
        generated = llm_generate(prompt, tokenizer, model, device, max_new=32)
        # Take only first line
        first_line = generated.split("\n")[0].strip()
        correct += int(squad_exact_match(first_line, example["answers"]["text"]))

        if done % 50 == 0:
            print(f" UIE SQuAD: {done}/{total}, acc so far={correct/done:.4f}")

    return {"accuracy": correct / total, "n": total, "correct": correct}
# ==============================
# Main
# ==============================
def main():
    """Parse command-line options, run the selected evaluation, print the result.

    ``--method`` and ``--task`` choose the runner; the ``neural_ie``/``squad``
    combination is reported as skipped. ``--batch_size`` is only passed to the
    ``neural_ie`` runner.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--method", choices=["neural_ie", "uie", "cuckoo"], required=True
    )
    cli.add_argument("--task", choices=["conll", "squad"], required=True)
    cli.add_argument("--model", required=True)
    cli.add_argument("--limit", type=int, default=200)
    cli.add_argument("--batch_size", type=int, default=32)
    args = cli.parse_args()

    tag = f"[{args.method.upper()} - {args.task.upper()}]"

    if args.method == "neural_ie":
        if args.task == "squad":
            print(f"{tag} Skipped (Neural IE does not support SQuAD)")
            return
        result = run_conll_neural_ie(args.model, args.limit, args.batch_size)
    else:
        # argparse ``choices`` guarantee the (method, task) key exists.
        runners = {
            ("cuckoo", "conll"): run_conll_cuckoo,
            ("cuckoo", "squad"): run_squad_cuckoo,
            ("uie", "conll"): run_conll_uie,
            ("uie", "squad"): run_squad_uie,
        }
        result = runners[(args.method, args.task)](args.model, args.limit)

    print(f"{tag} {result}")


if __name__ == "__main__":
    main()