# arabic-semantic-highlighter / modeling_highlighter.py
# Author: HeshamHaroon
# Commit 4b4c26b (verified): "Upload Arabic Semantic Highlighter model"
import torch
import numpy as np
import re
from transformers import AutoModelForSequenceClassification, AutoTokenizer, PreTrainedModel, PretrainedConfig
class SemanticHighlighterConfig(PretrainedConfig):
    """Configuration object for ``SemanticHighlighter``.

    Adds a single field on top of ``PretrainedConfig``.

    Args:
        base_model_name: Name of the base checkpoint the highlighter was
            built from (presumably a Hugging Face Hub id). Defaults to
            "BAAI/bge-reranker-base".
        **config_kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "semantic_highlighter"

    def __init__(self, base_model_name="BAAI/bge-reranker-base", **config_kwargs):
        # Let the parent consume every standard config option first, then
        # record the extra field this config adds.
        super().__init__(**config_kwargs)
        self.base_model_name = base_model_name
class SemanticHighlighter(PreTrainedModel):
    """Arabic Semantic Highlighter - highlights relevant sentences given a query.

    Wraps a single-logit sequence-classification model and its tokenizer.
    A context is split into sentences. Each sentence is scored against the
    question as a (question, sentence) pair. Sentences whose sigmoid
    probability reaches a threshold are returned as highlights.
    """

    config_class = SemanticHighlighterConfig

    # Sentence delimiters per language, compiled once. ASCII "?" is part of
    # the Arabic set because Arabic text typed on Latin/mixed keyboards often
    # uses it instead of "؟"; without it such questions were never split.
    _SENTENCE_DELIMITERS = {
        "ar": re.compile(r'[.؟?!。\n]'),
        "en": re.compile(r'[.?!\n]'),
    }
    # Stripped fragments of this many characters or fewer are discarded.
    _MIN_SENTENCE_CHARS = 5
    # Token cap for each (question, sentence) pair.
    _MAX_LENGTH = 256
    # Characters in the basic Arabic Unicode block (U+0600-U+06FF).
    _ARABIC_CHAR = re.compile(r'[\u0600-\u06FF]')

    def __init__(self, config):
        super().__init__(config)
        # `_name_or_path` is filled in when the config comes from
        # `from_pretrained`. A freshly constructed config leaves it empty, and
        # loading "" would fail, so fall back to the base checkpoint it records.
        name_or_path = getattr(config, "_name_or_path", "") or getattr(
            config, "base_model_name", ""
        )
        self.model = AutoModelForSequenceClassification.from_pretrained(
            name_or_path,
            num_labels=1,
            ignore_mismatched_sizes=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(name_or_path)
        self.device_type = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Delegate directly to the wrapped sequence-classification model."""
        return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    def _detect_language(self, text):
        """Return "ar" if more than 30% of ``text``'s characters are Arabic, else "en"."""
        arabic_chars = len(self._ARABIC_CHAR.findall(text))
        return "ar" if arabic_chars > len(text) * 0.3 else "en"

    def _split_sentences(self, text, language="ar"):
        """Split ``text`` into stripped sentences.

        Uses the Arabic delimiter set when ``language == "ar"`` and the
        English set otherwise. Fragments of ``_MIN_SENTENCE_CHARS``
        characters or fewer are dropped.
        """
        pattern = self._SENTENCE_DELIMITERS["ar" if language == "ar" else "en"]
        pieces = (piece.strip() for piece in pattern.split(text))
        return [piece for piece in pieces if len(piece) > self._MIN_SENTENCE_CHARS]

    def _score_sentences(self, question, sentences, batch_size=16):
        """Score every sentence in ``sentences`` against ``question``.

        Args:
            question: The query string.
            sentences: Sequence of sentence strings.
            batch_size: Number of (question, sentence) pairs per forward pass.

        Returns:
            A list of floats in [0, 1], one per sentence, in input order.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        probabilities = []
        for start in range(0, len(sentences), batch_size):
            batch = list(sentences[start:start + batch_size])
            inputs = self.tokenizer(
                [question] * len(batch), batch,
                truncation=True,
                max_length=self._MAX_LENGTH,
                # Pad only to the longest pair in this batch rather than to a
                # fixed 256 tokens. The attention mask hides the padding, so
                # scores should match per-pair scoring up to float noise.
                padding=True,
                return_tensors='pt',
            ).to(self.device_type)
            with torch.no_grad():
                logits = self.model(**inputs).logits.squeeze(-1)
            # torch.sigmoid is numerically stable, whereas np.exp(-logit)
            # overflowed (with a RuntimeWarning) on very negative logits.
            # float64 keeps the precision of the previous Python-float math.
            probabilities.extend(torch.sigmoid(logits.double()).tolist())
        return probabilities

    def _score_sentence(self, question, sentence):
        """Return the relevance probability (0-1) of one ``sentence`` for ``question``."""
        return self._score_sentences(question, [sentence])[0]

    def process(self, question, context, threshold=0.5, language="auto",
                return_sentence_metrics=False, batch_size=16):
        """
        Process question and context to highlight relevant sentences.

        Args:
            question: The query/question string
            context: The context text to search for relevant sentences
            threshold: Minimum probability to consider a sentence relevant (default: 0.5)
            language: Language of the text ("ar", "en", or "auto" for detection)
            return_sentence_metrics: If True, include sentence probabilities in output
            batch_size: Number of sentences scored per forward pass (default: 16)

        Returns:
            dict with keys:
                - highlighted_sentences: List of sentences with probability >= threshold
                - all_sentences: All sentences in the context
                - sentence_probabilities: (if return_sentence_metrics=True) probability scores,
                  aligned with all_sentences
        """
        self.model.to(self.device_type)
        self.model.eval()

        if language == "auto":
            language = self._detect_language(context)

        sentences = self._split_sentences(context, language)
        probabilities = self._score_sentences(question, sentences, batch_size=batch_size)
        highlighted = [
            sentence
            for sentence, prob in zip(sentences, probabilities)
            if prob >= threshold
        ]

        result = {
            "highlighted_sentences": highlighted,
            "all_sentences": sentences,
        }
        if return_sentence_metrics:
            result["sentence_probabilities"] = probabilities
        return result