| """
|
| CAMeLBERT-CA chain-blind inference: NER -> candidate duplets -> biaffine classifier -> <head>...<tail> output.
|
|
|
| Run on raw isnad text (or list of words). Produces a single linearized string of the form
|
| <head> A <tail> B <head> B <tail> C ... #hadith
|
| for all (head, tail) pairs classified as positive by the biaffine layer.
|
|
|
| Weights format:
|
| - Local: best_model.pth (PyTorch state_dict; from CAMeLBERT-CA_duplets_biaffine_train.py).
|
| - Hugging Face: model.safetensors (repo https://huggingface.co/Jehadoumer/ChainBlind-HadithIsnadParser-CAMeLBERT).
|
| """
|
|
|
| import argparse
|
| import os
|
| import re
|
| import torch
|
| import torch.nn as nn
|
| from transformers import BertModel, BertTokenizer
|
|
|
|
|
|
|
|
|
|
|
# Default Hugging Face repo hosting the fine-tuned joint NER + biaffine RE weights.
DEFAULT_HF_MODEL_ID = "Jehadoumer/ChainBlind-HadithIsnadParser-CAMeLBERT"

# Base encoder checkpoint (classical-Arabic CAMeLBERT) for BertModel/BertTokenizer.
BERT_NAME = "CAMeL-Lab/bert-base-arabic-camelbert-ca"

# Maximum subword sequence length, including [CLS]/[SEP]; longer inputs are truncated.
MAX_LENGTH = 128

# Token-level NER tag ids; ids 3/4 (B-/I-NARRATOR) are the spans fed to relation extraction.
LABEL_MAP = {"O": 0, "B-NARRATION": 1, "I-NARRATION": 2, "B-NARRATOR": 3, "I-NARRATOR": 4}
|
|
|
|
|
|
|
|
|
class BiaffineClassifier(nn.Module):
    """Biaffine scorer for (head, tail) entity-pair classification.

    Score per label k: head^T U[:, :, k] tail, optionally plus a linear term
    W [head; tail] and a bias b when ``include_linear`` is True.
    """

    def __init__(self, hidden_size, num_labels, include_linear=True):
        super().__init__()
        self.include_linear = include_linear
        # Bilinear interaction tensor: one (hidden x hidden) matrix per label.
        self.U = nn.Parameter(torch.randn(hidden_size, hidden_size, num_labels))
        if include_linear:
            # Linear term over the concatenated pair representation.
            self.W = nn.Linear(2 * hidden_size, num_labels)
        self.b = nn.Parameter(torch.randn(num_labels))

    def forward(self, head_vector, tail_vector):
        """Return [batch, num_labels] relation logits for each head/tail pair."""
        scores = torch.einsum('bi,ijk,bj->bk', head_vector, self.U, tail_vector)
        if not self.include_linear:
            # Pure bilinear score; bias is only added together with the linear term.
            return scores
        pair = torch.cat([head_vector, tail_vector], dim=-1)
        return scores + self.W(pair) + self.b
|
|
|
|
|
class JointNERREModel(nn.Module):
    """Joint NER + relation extraction: CAMeLBERT encoder, token classifier, biaffine pair scorer."""

    def __init__(self, num_ner_labels, num_re_labels, hidden_size, include_linear=True):
        super().__init__()
        self.bert = BertModel.from_pretrained(BERT_NAME)
        self.ner_classifier = nn.Linear(hidden_size, num_ner_labels)
        self.head_fc = nn.Linear(hidden_size, hidden_size)
        self.tail_fc = nn.Linear(hidden_size, hidden_size)
        self.biaffine_classifier = BiaffineClassifier(hidden_size, num_re_labels, include_linear=include_linear)
        self.tokenizer = BertTokenizer.from_pretrained(BERT_NAME)

    def clean_entity_name(self, name):
        """Strip punctuation/digits, collapse whitespace, keep only Arabic-block characters.

        NOTE(review): the final Arabic-only filter also removes spaces, so
        multi-word names come back concatenated — preserved as-is, since it
        presumably matches the training-time normalization; confirm if reused.
        """
        if not isinstance(name, str):
            raise TypeError(f"Expected a string for entity name, but got {type(name)} with value: {name}")
        name = re.sub(r'[\"\'\[\]\(\)\,\.\d\:\-\_\،\؟\؛\!\@\#\$\%\^\&\*\+\=\{\}\|\<\>\/\\]', '', name)
        name = re.sub(r'\s+', ' ', name).strip()
        return ''.join(c for c in name if '\u0600' <= c <= '\u06FF').strip()

    def extract_entities(self, tokens, labels, sequence_output):
        """Collect NARRATOR spans (label ids 3/4) and mean-pool their token embeddings.

        sequence_output: [1, seq_len, hidden]. Any run of consecutive narrator
        labels becomes one span, so back-to-back B-NARRATOR tags merge into a
        single entity (same as the original extraction logic).
        """
        entities = []
        entity_embeddings = []
        span_tokens = []
        span_vecs = []

        def flush():
            # Close the currently open span, if any.
            if span_tokens:
                entities.append(" ".join(span_tokens))
                entity_embeddings.append(torch.stack(span_vecs).mean(dim=0))
                span_tokens.clear()
                span_vecs.clear()

        for idx, (token, label) in enumerate(zip(tokens, labels)):
            if label in (3, 4):
                span_tokens.append(token)
                span_vecs.append(sequence_output[0, idx, :])
            else:
                flush()
        flush()

        return entities, entity_embeddings

    def generate_entity_pairs(self, entities, embeddings):
        """Return every ordered (head, tail) pair of distinct entities and the matching embedding pairs."""
        pairs = []
        pair_embeddings = []
        for i, head in enumerate(entities):
            for j, tail in enumerate(entities):
                if i == j:
                    continue
                pairs.append((head, tail))
                pair_embeddings.append((embeddings[i], embeddings[j]))
        return pairs, pair_embeddings
|
|
|
|
|
def _tokenize_isnad_words(words, tokenizer, max_length=MAX_LENGTH):
    """Tokenize a list of words into (input_ids, attention_mask) tensors of shape [1, max_length].

    Mirrors the training-time IsnadDataset logic: subword-tokenize each word,
    truncate to leave room for the special tokens, wrap in [CLS]/[SEP], then
    right-pad with the pad token (mask 0) up to max_length.
    """
    ids = []
    mask = []
    for word in words:
        piece_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(word))
        ids.extend(piece_ids)
        mask.extend([1] * len(piece_ids))

    # Reserve two positions for [CLS] and [SEP]; truncation may split a word.
    ids = [tokenizer.cls_token_id] + ids[:max_length - 2] + [tokenizer.sep_token_id]
    mask = [1] + mask[:max_length - 2] + [1]

    pad = max_length - len(ids)
    ids += [tokenizer.pad_token_id] * pad
    mask += [0] * pad

    return (
        torch.tensor([ids], dtype=torch.long),
        torch.tensor([mask], dtype=torch.long),
    )
|
|
|
|
|
def _tokenize_raw_text(text, tokenizer, max_length=MAX_LENGTH):
    """Whitespace-split a raw string into words, then delegate to _tokenize_isnad_words."""
    return _tokenize_isnad_words(text.split(), tokenizer, max_length)
|
|
|
|
|
| def _format_chain_blind(positive_pairs):
|
| """Format list of (head, tail) into <head> h <tail> t ... #hadith. Normalize entity display (remove ##)."""
|
| parts = []
|
| for head, tail in positive_pairs:
|
| h = head.replace(" ##", "").strip()
|
| t = tail.replace(" ##", "").strip()
|
| if h and t:
|
| parts.append(f"<head> {h} <tail> {t}")
|
| if not parts:
|
| return "<head> <tail> #hadith"
|
| return " ".join(parts) + " #hadith"
|
|
|
|
|
@torch.no_grad()
def run_inference(model, input_ids, attention_mask, device):
    """Run the full pipeline: NER tagging -> narrator spans -> all ordered pairs -> biaffine scoring.

    Returns the list of (head, tail) string pairs whose predicted relation
    class is 1 (positive); empty list when no entities or no pairs are found.
    """
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    # Encode once; the same hidden states feed both NER and span embeddings.
    encoded = model.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
    labels = torch.argmax(model.ner_classifier(encoded), dim=-1)[0].tolist()
    tokens = model.tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

    entities, embeddings = model.extract_entities(tokens, labels, encoded)
    if not entities:
        return []

    pairs, pair_embeddings = model.generate_entity_pairs(entities, embeddings)
    if not pairs:
        return []

    # Score all candidate duplets in one batched biaffine call.
    heads = torch.stack([h for h, _ in pair_embeddings], dim=0)
    tails = torch.stack([t for _, t in pair_embeddings], dim=0)
    predictions = torch.argmax(model.biaffine_classifier(heads, tails), dim=-1).tolist()

    return [pair for pair, pred in zip(pairs, predictions) if pred == 1]
|
|
|
|
|
def _load_filtered(model, state_dict):
    """Drop checkpoint keys the model does not declare, then load non-strictly (tolerates missing keys)."""
    wanted = set(model.state_dict().keys())
    model.load_state_dict({k: v for k, v in state_dict.items() if k in wanted}, strict=False)


def load_model(device, model_id=None, checkpoint_path=None):
    """Build a JointNERREModel and load weights from the HF Hub or a local checkpoint.

    Exactly one source is used: ``model_id`` when given (tries
    model.safetensors, then pytorch_model.bin), otherwise ``checkpoint_path``.
    Raises ValueError when neither is usable. Returns the model on ``device``
    in eval mode.
    """
    model = JointNERREModel(num_ner_labels=5, num_re_labels=2, hidden_size=768)

    if model_id:
        try:
            from huggingface_hub import hf_hub_download
        except ImportError:
            raise ImportError("Loading from Hugging Face requires huggingface_hub: pip install huggingface_hub")

        # Best-effort download: prefer safetensors, fall back to the .bin file.
        path = None
        for filename in ("model.safetensors", "pytorch_model.bin"):
            try:
                path = hf_hub_download(repo_id=model_id, filename=filename)
            except Exception:
                continue
            break
        if path is None:
            raise FileNotFoundError(f"No model weights (model.safetensors or pytorch_model.bin) in repo {model_id}")

        if path.endswith(".safetensors"):
            from safetensors.torch import load_file
            state_dict = load_file(path)
        else:
            state_dict = torch.load(path, map_location=device, weights_only=True)
        _load_filtered(model, state_dict)
        model.tokenizer = BertTokenizer.from_pretrained(model_id)
    elif checkpoint_path and os.path.isfile(checkpoint_path):
        state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
        _load_filtered(model, state_dict)
        model.tokenizer = BertTokenizer.from_pretrained(BERT_NAME)
    else:
        raise ValueError("Provide --model-id (Hugging Face repo) or --checkpoint (path to best_model.pth)")

    model.to(device)
    model.eval()
    return model
|
|
|
|
|
def main():
    """CLI entry point: parse args, load the model, tokenize input, print chain-blind output.

    Input precedence: --input-file (JSONL, one line printed per record),
    else --input (raw text), else a built-in demo isnad.
    """
    parser = argparse.ArgumentParser(description="CAMeLBERT-CA biaffine inference -> <head>...<tail> chain-blind output")
    parser.add_argument("--model-id", type=str, default=DEFAULT_HF_MODEL_ID,
                        help="Hugging Face model id (default: %(default)s)")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Optional path to local best_model.pth (overrides --model-id if set)")
    parser.add_argument("--input", type=str, default=None, help="Raw isnad text (space-separated words)")
    parser.add_argument("--input-file", type=str, default=None, help="JSONL file with 'isnad_sequence' (list of words) per line")
    parser.add_argument("--max-length", type=int, default=MAX_LENGTH, help="Max sequence length")
    parser.add_argument("--no-cuda", action="store_true", help="Force CPU")
    args = parser.parse_args()

    use_cuda = torch.cuda.is_available() and not args.no_cuda
    device = torch.device("cuda" if use_cuda else "cpu")
    # A local checkpoint takes precedence over the Hugging Face repo.
    model = load_model(
        device,
        model_id=None if args.checkpoint else args.model_id,
        checkpoint_path=args.checkpoint,
    )
    tokenizer = model.tokenizer

    if args.input_file:
        import json
        with open(args.input_file, "r", encoding="utf-8") as f:
            for raw_line in f:
                raw_line = raw_line.strip()
                if not raw_line:
                    continue
                item = json.loads(raw_line)
                # Accept either a word list ('isnad_sequence') or raw text ('isnad_text').
                words = item.get("isnad_sequence", item.get("isnad_text", "").split())
                if isinstance(words, str):
                    words = words.split()
                input_ids, attention_mask = _tokenize_isnad_words(words, tokenizer, args.max_length)
                print(_format_chain_blind(run_inference(model, input_ids, attention_mask, device)))
        return

    if args.input:
        input_ids, attention_mask = _tokenize_raw_text(args.input, tokenizer, args.max_length)
    else:
        # Built-in demo isnad used when no input is supplied.
        example = (
            "حدثنا يوسف القاضي ، ثنا عبد الواحد بن غياث ، ثنا حماد بن سلمة ، "
            "عن أيوب وعبيد الله بن عمر ، عن نافع ، عن ابن عمر قال : قال زيد بن ثابت"
        )
        input_ids, attention_mask = _tokenize_isnad_words(example.split(), tokenizer, args.max_length)

    print(_format_chain_blind(run_inference(model, input_ids, attention_mask, device)))
|
|
|
|
|
# Script entry point: run CLI inference when executed directly.
if __name__ == "__main__":
    main()
|
|
|