# ChainBlind-HadithIsnadParser-CAMeLBERT / camelbert_duplets_biaffine_inference.py
# Source: Hugging Face repo upload by Jehadoumer ("Upload 6 files", commit 1a1970b, verified).
"""
CAMeLBERT-CA chain-blind inference: NER -> candidate duplets -> biaffine classifier -> <head>...<tail> output.
Run on raw isnad text (or list of words). Produces a single linearized string of the form
<head> A <tail> B <head> B <tail> C ... #hadith
for all (head, tail) pairs classified as positive by the biaffine layer.
Weights format:
- Local: best_model.pth (PyTorch state_dict; from CAMeLBERT-CA_duplets_biaffine_train.py).
- Hugging Face: model.safetensors (repo https://huggingface.co/Jehadoumer/ChainBlind-HadithIsnadParser-CAMeLBERT).
"""
import argparse
import os
import re
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
# ---------------------------------------------------------------------------
# Config (override via args or env)
# ---------------------------------------------------------------------------
# Hugging Face model: chain-blind CAMeLBERT-CA (NER + biaffine)
DEFAULT_HF_MODEL_ID = "Jehadoumer/ChainBlind-HadithIsnadParser-CAMeLBERT"
# Base encoder repo used to build the architecture and default tokenizer.
BERT_NAME = "CAMeL-Lab/bert-base-arabic-camelbert-ca"
# Maximum subword sequence length including [CLS]/[SEP]; longer inputs are truncated.
MAX_LENGTH = 128
# Token-level NER tag ids; inference only consumes B-/I-NARRATOR (ids 3 and 4).
LABEL_MAP = {"O": 0, "B-NARRATION": 1, "I-NARRATION": 2, "B-NARRATOR": 3, "I-NARRATOR": 4}
# ---------------------------------------------------------------------------
# Model definitions (must match training script)
# ---------------------------------------------------------------------------
class BiaffineClassifier(nn.Module):
    """Biaffine scorer for (head, tail) entity-pair classification.

    score = head^T U tail, plus W @ [head; tail] + b when include_linear is set.
    Parameter names (U, W, b) must stay unchanged so trained checkpoints load.
    """

    def __init__(self, hidden_size, num_labels, include_linear=True):
        super(BiaffineClassifier, self).__init__()
        self.include_linear = include_linear
        # Bilinear tensor: one hidden x hidden interaction matrix per label.
        self.U = nn.Parameter(torch.randn(hidden_size, hidden_size, num_labels))
        if include_linear:
            # Affine term over the concatenated pair representation.
            self.W = nn.Linear(2 * hidden_size, num_labels)
        # NOTE(review): source indentation was flattened; b is kept unconditional
        # (it is only *used* on the linear path, so forward output is unaffected).
        self.b = nn.Parameter(torch.randn(num_labels))

    def forward(self, head_vector, tail_vector):
        """Return [batch, num_labels] relation scores for batched head/tail vectors."""
        pair_score = torch.einsum('bi,ijk,bj->bk', head_vector, self.U, tail_vector)
        if not self.include_linear:
            return pair_score
        affine_score = self.W(torch.cat([head_vector, tail_vector], dim=-1))
        return pair_score + affine_score + self.b
class JointNERREModel(nn.Module):
    """Joint NER + relation-extraction model: CAMeLBERT encoder, a token-level
    NER classifier, and a biaffine head over narrator-span embeddings.

    Layer/attribute names must match the training script so checkpoints load.
    """

    def __init__(self, num_ner_labels, num_re_labels, hidden_size, include_linear=True):
        super(JointNERREModel, self).__init__()
        self.bert = BertModel.from_pretrained(BERT_NAME)
        self.ner_classifier = nn.Linear(hidden_size, num_ner_labels)
        # head_fc/tail_fc are unused in this inference path but kept so the
        # state_dict layout matches the training checkpoint.
        self.head_fc = nn.Linear(hidden_size, hidden_size)
        self.tail_fc = nn.Linear(hidden_size, hidden_size)
        self.biaffine_classifier = BiaffineClassifier(hidden_size, num_re_labels, include_linear=include_linear)
        self.tokenizer = BertTokenizer.from_pretrained(BERT_NAME)

    def clean_entity_name(self, name):
        """Normalize an entity string: strip punctuation/digits, collapse
        whitespace, then keep only Arabic-block characters (U+0600-U+06FF).

        Note: the final Arabic-only filter also removes the spaces between
        words, so multi-word names come back concatenated.

        Raises:
            TypeError: if name is not a string.
        """
        if not isinstance(name, str):
            raise TypeError(f"Expected a string for entity name, but got {type(name)} with value: {name}")
        name = re.sub(r'[\"\'\[\]\(\)\,\.\d\:\-\_\،\؟\؛\!\@\#\$\%\^\&\*\+\=\{\}\|\<\>\/\\]', '', name)
        name = re.sub(r'\s+', ' ', name).strip()
        name = ''.join([c for c in name if '\u0600' <= c <= '\u06FF']).strip()
        return name

    def extract_entities(self, tokens, labels, sequence_output):
        """Decode narrator spans and their mean embeddings from token-level NER.

        Args:
            tokens: wordpiece tokens (length seq_len).
            labels: predicted tag ids per token (see LABEL_MAP).
            sequence_output: tensor [1, seq_len, hidden] of encoder states.
        Returns:
            (entities, entity_embeddings): space-joined span strings and their
            mean-pooled token embeddings, in order of appearance.

        Fix vs. the original decoder: a B-NARRATOR tag now always STARTS a new
        span (standard BIO semantics). Previously, two adjacent narrators
        separated only by a B- tag were merged into a single span. An
        I-NARRATOR with no open span still leniently opens one, as before.
        """
        entities = []
        entity_embeddings = []
        current_entity = []
        current_entity_embeddings = []

        def _flush():
            # Close the open span, if any, and record its mean embedding.
            if current_entity:
                entities.append(" ".join(current_entity))
                entity_embeddings.append(torch.stack(current_entity_embeddings).mean(dim=0))
                current_entity.clear()
                current_entity_embeddings.clear()

        for idx, (token, label) in enumerate(zip(tokens, labels)):
            if label == 3:  # B-NARRATOR: always begins a fresh span
                _flush()
                current_entity.append(token)
                current_entity_embeddings.append(sequence_output[0, idx, :])
            elif label == 4:  # I-NARRATOR: continue (or leniently open) a span
                current_entity.append(token)
                current_entity_embeddings.append(sequence_output[0, idx, :])
            else:  # O / narration tags terminate the current span
                _flush()
        _flush()
        return entities, entity_embeddings

    def generate_entity_pairs(self, entities, embeddings):
        """Return every ordered pair of distinct entities as candidate relations.

        Returns:
            (entity_pairs, entity_pair_embeddings): parallel lists of
            (head, tail) strings and (head_emb, tail_emb) tensor tuples.
        """
        entity_pairs = []
        entity_pair_embeddings = []
        for i, head in enumerate(entities):
            for j, tail in enumerate(entities):
                if i != j:
                    entity_pairs.append((head, tail))
                    entity_pair_embeddings.append((embeddings[i], embeddings[j]))
        return entity_pairs, entity_pair_embeddings
def _tokenize_isnad_words(words, tokenizer, max_length=MAX_LENGTH):
    """Tokenize a word list (isnad_sequence) into (input_ids, attention_mask).

    Mirrors the training IsnadDataset logic: subword-tokenize each word,
    truncate the body to max_length - 2, wrap with [CLS]/[SEP], then pad.
    Returns two [1, max_length] LongTensors.
    """
    ids = []
    for word in words:
        ids.extend(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(word)))
    ids = ids[:max_length - 2]
    mask = [1] * len(ids)
    ids = [tokenizer.cls_token_id, *ids, tokenizer.sep_token_id]
    mask = [1, *mask, 1]
    pad = max_length - len(ids)
    ids += [tokenizer.pad_token_id] * pad
    mask += [0] * pad
    return (
        torch.tensor([ids], dtype=torch.long),
        torch.tensor([mask], dtype=torch.long),
    )
def _tokenize_raw_text(text, tokenizer, max_length=MAX_LENGTH):
    """Whitespace-split raw text into words, then delegate to _tokenize_isnad_words."""
    return _tokenize_isnad_words(text.split(), tokenizer, max_length)
def _format_chain_blind(positive_pairs):
"""Format list of (head, tail) into <head> h <tail> t ... #hadith. Normalize entity display (remove ##)."""
parts = []
for head, tail in positive_pairs:
h = head.replace(" ##", "").strip()
t = tail.replace(" ##", "").strip()
if h and t:
parts.append(f"<head> {h} <tail> {t}")
if not parts:
return "<head> <tail> #hadith"
return " ".join(parts) + " #hadith"
@torch.no_grad()
def run_inference(model, input_ids, attention_mask, device):
    """
    NER-tag the sequence, decode narrator spans, score every ordered span pair
    with the biaffine head, and return the (head, tail) string pairs whose
    predicted relation label is 1 (positive).
    """
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    hidden_states = model.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
    tag_ids = model.ner_classifier(hidden_states).argmax(dim=-1)[0].tolist()
    tokens = model.tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
    entities, span_vectors = model.extract_entities(tokens, tag_ids, hidden_states)
    if not entities:
        return []
    pairs, pair_vectors = model.generate_entity_pairs(entities, span_vectors)
    if not pairs:
        return []
    head_batch = torch.stack([hv for hv, _ in pair_vectors], dim=0)
    tail_batch = torch.stack([tv for _, tv in pair_vectors], dim=0)
    predictions = model.biaffine_classifier(head_batch, tail_batch).argmax(dim=-1)
    return [pair for pair, label in zip(pairs, predictions.tolist()) if label == 1]
def load_model(device, model_id=None, checkpoint_path=None):
    """Load JointNERREModel from Hugging Face (default) or from a local checkpoint file."""
    model = JointNERREModel(num_ner_labels=5, num_re_labels=2, hidden_size=768)

    def _load_filtered(state_dict):
        # Keep only keys our architecture defines (ignore MLM head / extra keys
        # from full BERT checkpoints); strict=False tolerates any missing ones.
        wanted = set(model.state_dict().keys())
        model.load_state_dict({k: v for k, v in state_dict.items() if k in wanted}, strict=False)

    if model_id:
        try:
            from huggingface_hub import hf_hub_download
        except ImportError:
            raise ImportError("Loading from Hugging Face requires huggingface_hub: pip install huggingface_hub")
        # Hub repo has model.safetensors (see https://huggingface.co/Jehadoumer/ChainBlind-HadithIsnadParser-CAMeLBERT). Fallback: pytorch_model.bin.
        path = None
        for filename in ("model.safetensors", "pytorch_model.bin"):
            try:
                path = hf_hub_download(repo_id=model_id, filename=filename)
                break
            except Exception:
                continue
        if path is None:
            raise FileNotFoundError(f"No model weights (model.safetensors or pytorch_model.bin) in repo {model_id}")
        if path.endswith(".safetensors"):
            from safetensors.torch import load_file
            _load_filtered(load_file(path))
        else:
            _load_filtered(torch.load(path, map_location=device, weights_only=True))
        model.tokenizer = BertTokenizer.from_pretrained(model_id)
    elif checkpoint_path and os.path.isfile(checkpoint_path):
        _load_filtered(torch.load(checkpoint_path, map_location=device, weights_only=True))
        model.tokenizer = BertTokenizer.from_pretrained(BERT_NAME)
    else:
        raise ValueError("Provide --model-id (Hugging Face repo) or --checkpoint (path to best_model.pth)")
    model.to(device)
    model.eval()
    return model
def main():
    """CLI entry point: parse args, load the model, tokenize, print chain-blind output."""
    parser = argparse.ArgumentParser(description="CAMeLBERT-CA biaffine inference -> <head>...<tail> chain-blind output")
    parser.add_argument("--model-id", type=str, default=DEFAULT_HF_MODEL_ID,
                        help="Hugging Face model id (default: %(default)s)")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Optional path to local best_model.pth (overrides --model-id if set)")
    parser.add_argument("--input", type=str, default=None, help="Raw isnad text (space-separated words)")
    parser.add_argument("--input-file", type=str, default=None, help="JSONL file with 'isnad_sequence' (list of words) per line")
    parser.add_argument("--max-length", type=int, default=MAX_LENGTH, help="Max sequence length")
    parser.add_argument("--no-cuda", action="store_true", help="Force CPU")
    args = parser.parse_args()

    use_cuda = torch.cuda.is_available() and not args.no_cuda
    device = torch.device("cuda" if use_cuda else "cpu")
    # A local checkpoint takes precedence over the Hub repo.
    model = load_model(
        device,
        model_id=None if args.checkpoint else args.model_id,
        checkpoint_path=args.checkpoint,
    )
    tokenizer = model.tokenizer

    if args.input_file:
        # Batch mode: one JSONL record per line, one output line per record.
        import json
        with open(args.input_file, "r", encoding="utf-8") as f:
            for raw_line in f:
                raw_line = raw_line.strip()
                if not raw_line:
                    continue
                record = json.loads(raw_line)
                words = record.get("isnad_sequence", record.get("isnad_text", "").split())
                if isinstance(words, str):
                    words = words.split()
                ids, mask = _tokenize_isnad_words(words, tokenizer, args.max_length)
                print(_format_chain_blind(run_inference(model, ids, mask, device)))
        return

    if args.input:
        ids, mask = _tokenize_raw_text(args.input, tokenizer, args.max_length)
    else:
        # Example from README/paper
        example = (
            "حدثنا يوسف القاضي ، ثنا عبد الواحد بن غياث ، ثنا حماد بن سلمة ، "
            "عن أيوب وعبيد الله بن عمر ، عن نافع ، عن ابن عمر قال : قال زيد بن ثابت"
        )
        ids, mask = _tokenize_isnad_words(example.split(), tokenizer, args.max_length)
    print(_format_chain_blind(run_inference(model, ids, mask, device)))


if __name__ == "__main__":
    main()