# Qwen3-8B-Coref-NER / inference.py
# Uploaded by wjbmattingly via huggingface_hub (commit a2f094f, verified)
#!/usr/bin/env python3
"""
Coreference Resolution with Qwen3-8B-Coref-NER
Usage:
    uv run inference.py input.txt
    uv run inference.py input.txt --output resolved.txt
    uv run inference.py input.txt --max-tokens 4096
"""
import argparse
import torch
import re
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
MODEL_ID = "wjbmattingly/Qwen3-8B-Coref-NER"
BASE_MODEL = "Qwen/Qwen3-8B"
def parse_entity_mappings(response):
    """Split a model reply into resolved text plus an entity-variant map.

    Replies may end with a section headed ``NEW ENTITY MAPPINGS:`` whose
    bullet lines look like ``- Name: ["variant", "variant"]``.  Returns a
    tuple ``(resolved_text, {entity_name: [variants, ...]})``; entities
    without at least one quoted variant are dropped.  When the section is
    absent, the stripped reply and an empty dict are returned.
    """
    marker = "NEW ENTITY MAPPINGS:"
    if marker not in response:
        return response.strip(), {}
    pieces = response.split(marker)
    resolved = pieces[0].strip()
    block = pieces[1].strip() if len(pieces) > 1 else ""
    found = {}
    for raw_line in block.split("\n"):
        raw_line = raw_line.strip()
        if not raw_line.startswith("-"):
            continue
        hit = re.match(r'-\s*([^:]+):\s*\[([^\]]*)\]', raw_line)
        if hit is None:
            continue
        quoted = re.findall(r'"([^"]*)"', hit.group(2))
        if quoted:
            found[hit.group(1).strip()] = quoted
    return resolved, found
def format_entities_for_prompt(entities):
    """Render ``{entity: variants}`` as the bulleted list used in prompts.

    Produces a header line followed by one ``- Name: ["v1", "v2"]`` line
    per entity, joined with newlines.
    """
    rendered = ["Entities and their possible references:"]
    for name in entities:
        quoted = ", ".join('"{}"'.format(variant) for variant in entities[name])
        rendered.append("- {}: [{}]".format(name, quoted))
    return "\n".join(rendered)
def load_model():
    """Download/load the base Qwen model and attach the fine-tuned adapter.

    Loads the tokenizer and the base causal-LM in bfloat16 with an automatic
    device map, then wraps the base model with the PEFT adapter from
    MODEL_ID.

    Returns:
        (model, tokenizer): the adapter-wrapped model and its tokenizer.
    """
    print(f"Loading {BASE_MODEL}...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    print(f"Loading adapter {MODEL_ID}...")
    adapted = PeftModel.from_pretrained(base, MODEL_ID)
    return adapted, tokenizer
def _build_prompt(index, paragraph, resolved_paragraphs, cumulative_entities):
    """Build the user prompt for paragraph *index*.

    The first paragraph gets a bare resolve instruction.  Later paragraphs
    are given up to the two preceding already-resolved paragraphs as
    context and, when any entities have been collected, the known-entity
    list so the model reuses existing names.
    """
    if index == 0:
        return f"Resolve all pronouns in this text, replacing them with the full entity names. Also identify any entity references you find.\n\n{paragraph}"
    # Only the two most recent resolved paragraphs are used as context to
    # keep the prompt short.
    context = "\n\n".join(resolved_paragraphs[max(0, index - 2):index])
    if cumulative_entities:
        known_str = format_entities_for_prompt(cumulative_entities)
        return f"Known {known_str}\n\nGiven this context of preceding text (already resolved):\n\n{context}\n\nResolve all pronouns in this paragraph using the known entities. Also identify any NEW entity references:\n\n{paragraph}"
    return f"Given this context of preceding text (already resolved):\n\n{context}\n\nResolve all pronouns in this paragraph. Also identify any NEW entity references:\n\n{paragraph}"


def _generate(prompt, model, tokenizer, max_new_tokens):
    """Run one greedy chat-formatted generation and return the decoded reply.

    Applies the chat template with thinking disabled, generates up to
    *max_new_tokens* deterministically, decodes only the newly generated
    tokens, and strips any leaked ``<think>`` blocks.
    """
    messages = [{"role": "user", "content": prompt}]
    input_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
    )
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    # Robustness: fall back to EOS when the tokenizer defines no pad token;
    # passing pad_token_id=None makes generate() warn on every call.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=pad_id,
        )
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    ).strip()
    # Defensive cleanup: remove <think>...</think> even though thinking is
    # disabled via the chat template.
    return re.sub(r'<think>.*?</think>\s*', '', response, flags=re.DOTALL)


def process_document(text, model, tokenizer, max_new_tokens=2048):
    """Process a document paragraph by paragraph with entity tracking.

    Splits *text* on blank lines, resolves pronouns in each paragraph with
    the model, and accumulates entity-variant mappings across paragraphs so
    later prompts can reuse earlier entity names.

    Args:
        text: Full document text; paragraphs are separated by "\\n\\n".
        model: Causal LM (PEFT-wrapped) used for generation.
        tokenizer: Matching tokenizer with a chat template.
        max_new_tokens: Generation budget per paragraph.

    Returns:
        dict with keys "paragraphs" (original), "resolved_paragraphs",
        "paragraph_entities" (per-paragraph new mappings), and
        "cumulative_entities" (entity name -> list of variants).
    """
    paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
    resolved_paragraphs = []
    paragraph_entities = []
    cumulative = {}  # entity name -> set of variant strings seen so far
    for i, paragraph in enumerate(paragraphs):
        print(f"Processing paragraph {i+1}/{len(paragraphs)}...")
        prompt = _build_prompt(i, paragraph, resolved_paragraphs, cumulative)
        response = _generate(prompt, model, tokenizer, max_new_tokens)
        resolved, new_entities = parse_entity_mappings(response)
        resolved_paragraphs.append(resolved)
        paragraph_entities.append(new_entities)
        # Merge this paragraph's mappings into the running entity table.
        for entity_name, variants in new_entities.items():
            cumulative.setdefault(entity_name, set()).update(variants)
        if new_entities:
            print(f"  New entities: {new_entities}")
    return {
        "paragraphs": paragraphs,
        "resolved_paragraphs": resolved_paragraphs,
        "paragraph_entities": paragraph_entities,
        # Convert variant sets to lists for JSON-friendly output.
        "cumulative_entities": {k: list(v) for k, v in cumulative.items()},
    }
def main():
    """CLI entry point: read a text file, resolve pronouns, print and optionally save.

    Reads the input file, runs paragraph-wise coreference resolution, prints
    the resolved text and the collected entity table, and writes the
    resolved text to ``--output`` when given.
    """
    parser = argparse.ArgumentParser(description="Coreference Resolution")
    parser.add_argument("input", help="Input text file")
    parser.add_argument("-o", "--output", help="Output file for resolved text")
    parser.add_argument("--max-tokens", type=int, default=2048, help="Max tokens to generate")
    args = parser.parse_args()

    # Fix: read explicitly as UTF-8 so behavior doesn't depend on the
    # platform's default locale encoding (e.g. cp1252 on Windows).
    with open(args.input, "r", encoding="utf-8") as f:
        text = f.read().strip()

    model, tokenizer = load_model()
    result = process_document(text, model, tokenizer, args.max_tokens)

    resolved_text = "\n\n".join(result["resolved_paragraphs"])
    banner = "=" * 60
    print("\n" + banner)
    print("RESOLVED TEXT:")
    print(banner)
    print(resolved_text)
    print("\n" + banner)
    print("ENTITIES FOUND:")
    print(banner)
    for entity, variants in result["cumulative_entities"].items():
        print(f"  {entity}: {variants}")

    if args.output:
        # Same encoding fix on the write side.
        with open(args.output, "w", encoding="utf-8") as f:
            f.write(resolved_text)
        print(f"\nSaved to {args.output}")


if __name__ == "__main__":
    main()