| |
"""
Coreference Resolution with Qwen3-8B-Coref-NER

Usage:
    uv run inference.py input.txt
    uv run inference.py input.txt --output resolved.txt
    uv run inference.py input.txt --max-tokens 4096
"""
|
|
| import argparse |
| import torch |
| import re |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| from peft import PeftModel |
|
|
# Identifier of the base causal LM; load_model() passes it to both
# AutoTokenizer.from_pretrained and AutoModelForCausalLM.from_pretrained.
BASE_MODEL = "Qwen/Qwen3-8B"
# Identifier of the adapter that load_model() applies on top of the base
# model through PeftModel.from_pretrained.
MODEL_ID = "wjbmattingly/Qwen3-8B-Coref-NER"
|
|
|
|
# Header that separates the resolved text from the entity list in a response.
_MAPPINGS_MARKER = "NEW ENTITY MAPPINGS:"
# One mapping line: `- <entity name>: ["variant", ...]`. The name is matched
# lazily up to the colon that introduces the bracketed list, so an entity
# name that itself contains a colon is kept whole instead of being rejected.
_MAPPING_LINE_RE = re.compile(r'-\s*(.+?):\s*\[([^\]]*)\]')
# A single double-quoted variant inside the bracketed list.
_QUOTED_VARIANT_RE = re.compile(r'"([^"]*)"')


def parse_entity_mappings(response):
    """Parse a model response into resolved text and entity mappings.

    The response is expected to contain the resolved text, optionally
    followed by a "NEW ENTITY MAPPINGS:" header and lines of the form
    ``- Entity Name: ["variant 1", "variant 2"]``.

    Args:
        response: Raw text generated by the model.

    Returns:
        A ``(resolved_text, entities)`` tuple. ``resolved_text`` is the
        stripped text before the first header (or the whole stripped
        response when there is no header). ``entities`` maps each entity
        name to its list of quoted variants; lines that do not match the
        expected shape, or that list no quoted variants, are ignored.
    """
    # Split only at the FIRST header: if the header is repeated, the mapping
    # lines that follow the later occurrence are still parsed rather than
    # being silently dropped.
    resolved_text, marker, mappings_text = response.partition(_MAPPINGS_MARKER)
    if not marker:
        return response.strip(), {}

    entities = {}
    for raw_line in mappings_text.split("\n"):
        match = _MAPPING_LINE_RE.match(raw_line.strip())
        if match is None:
            continue
        variants = _QUOTED_VARIANT_RE.findall(match.group(2))
        if variants:
            entities[match.group(1).strip()] = variants
    return resolved_text.strip(), entities
|
|
|
|
def format_entities_for_prompt(entities):
    """Render known entities as the bulleted block embedded in the prompt.

    Args:
        entities: Mapping of entity name to an iterable of variant strings.

    Returns:
        A newline-joined string: a fixed header line followed by one
        ``- name: ["variant", ...]`` line per entity, in mapping order.
    """

    def _bullet(name, aliases):
        # Each variant is wrapped in double quotes and comma-separated.
        quoted = ", ".join([f'"{alias}"' for alias in aliases])
        return f"- {name}: [{quoted}]"

    return "\n".join(
        ["Entities and their possible references:"]
        + [_bullet(name, aliases) for name, aliases in entities.items()]
    )
|
|
|
|
def load_model():
    """Load the base model, apply the adapter, and return it with its tokenizer.

    The tokenizer and the causal LM are both loaded from ``BASE_MODEL``
    (bfloat16 weights, ``device_map="auto"``); the adapter identified by
    ``MODEL_ID`` is then applied with ``PeftModel.from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print(f"Loading {BASE_MODEL}...")
    tok = AutoTokenizer.from_pretrained(BASE_MODEL)
    base_kwargs = {"torch_dtype": torch.bfloat16, "device_map": "auto"}
    base = AutoModelForCausalLM.from_pretrained(BASE_MODEL, **base_kwargs)

    print(f"Loading adapter {MODEL_ID}...")
    adapted = PeftModel.from_pretrained(base, MODEL_ID)
    return adapted, tok
|
|
|
|
def process_document(text, model, tokenizer, max_new_tokens=2048):
    """Resolve coreferences paragraph by paragraph while tracking entities.

    The text is split into paragraphs on blank lines. The first paragraph is
    sent to the model on its own; each later paragraph is sent together with
    up to two preceding resolved paragraphs as context and, when any have
    been collected, the entities found so far.

    Args:
        text: Full document text.
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer exposing ``apply_chat_template``, ``decode``
            and ``pad_token_id``.
        max_new_tokens: Generation budget for each paragraph.

    Returns:
        A dict with keys:
            "paragraphs": the stripped, non-empty input paragraphs.
            "resolved_paragraphs": the resolved text for each paragraph.
            "paragraph_entities": per paragraph, the entity mapping parsed
                from the model response.
            "cumulative_entities": entity name -> list of all variants seen,
                de-duplicated, in first-seen order.
    """
    paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
    resolved_paragraphs = []
    paragraph_entities = []
    # Variants are accumulated in insertion-ordered lists rather than sets:
    # set iteration order for strings varies with hash randomization, which
    # would make the prompt (and thus the output) differ between runs even
    # though decoding is greedy.
    cumulative_entities = {}

    for i, paragraph in enumerate(paragraphs):
        print(f"Processing paragraph {i+1}/{len(paragraphs)}...")

        if i == 0:
            prompt = f"Resolve all pronouns in this text, replacing them with the full entity names. Also identify any entity references you find.\n\n{paragraph}"
        else:
            # Only the two most recent resolved paragraphs are used as context.
            context = "\n\n".join(resolved_paragraphs[max(0, i-2):i])
            if cumulative_entities:
                known_str = format_entities_for_prompt(cumulative_entities)
                prompt = f"Known {known_str}\n\nGiven this context of preceding text (already resolved):\n\n{context}\n\nResolve all pronouns in this paragraph using the known entities. Also identify any NEW entity references:\n\n{paragraph}"
            else:
                prompt = f"Given this context of preceding text (already resolved):\n\n{context}\n\nResolve all pronouns in this paragraph. Also identify any NEW entity references:\n\n{paragraph}"

        messages = [{"role": "user", "content": prompt}]
        input_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
        )
        inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        prompt_length = inputs["input_ids"].shape[1]
        response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

        # Drop any <think>...</think> block that is still emitted.
        response = re.sub(r'<think>.*?</think>\s*', '', response, flags=re.DOTALL)

        resolved, new_entities = parse_entity_mappings(response)
        resolved_paragraphs.append(resolved)
        paragraph_entities.append(new_entities)

        # Merge this paragraph's entities, keeping first-seen order and
        # skipping variants that are already known.
        for entity_name, variants in new_entities.items():
            known_variants = cumulative_entities.setdefault(entity_name, [])
            for variant in variants:
                if variant not in known_variants:
                    known_variants.append(variant)

        if new_entities:
            print(f" New entities: {new_entities}")

    return {
        "paragraphs": paragraphs,
        "resolved_paragraphs": resolved_paragraphs,
        "paragraph_entities": paragraph_entities,
        "cumulative_entities": {k: list(v) for k, v in cumulative_entities.items()},
    }
|
|
|
|
def main():
    """Command-line entry point: resolve coreferences in a text file.

    Reads the input file, loads the model, processes the document, prints
    the resolved text and the collected entities, and optionally writes the
    resolved text to the file given with ``-o/--output``.
    """
    parser = argparse.ArgumentParser(description="Coreference Resolution")
    parser.add_argument("input", help="Input text file")
    parser.add_argument("-o", "--output", help="Output file for resolved text")
    parser.add_argument("--max-tokens", type=int, default=2048, help="Max tokens to generate")
    args = parser.parse_args()

    # Explicit UTF-8: without it open() uses the locale's preferred encoding,
    # which can fail on or corrupt non-ASCII text on some platforms.
    with open(args.input, "r", encoding="utf-8") as f:
        text = f.read().strip()

    model, tokenizer = load_model()

    result = process_document(text, model, tokenizer, args.max_tokens)

    resolved_text = "\n\n".join(result["resolved_paragraphs"])

    banner = "="*60
    print("\n" + banner)
    print("RESOLVED TEXT:")
    print(banner)
    print(resolved_text)

    print("\n" + banner)
    print("ENTITIES FOUND:")
    print(banner)
    for entity, variants in result["cumulative_entities"].items():
        print(f" {entity}: {variants}")

    if args.output:
        with open(args.output, "w", encoding="utf-8") as f:
            f.write(resolved_text)
        print(f"\nSaved to {args.output}")
|
|
|
|
# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|