|
|
| import torch |
| import numpy as np |
| import re |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer, PreTrainedModel, PretrainedConfig |
|
|
class SemanticHighlighterConfig(PretrainedConfig):
    """Configuration for :class:`SemanticHighlighter`.

    Holds the name of the cross-encoder checkpoint used as the scoring
    backbone (default: ``BAAI/bge-reranker-base``).
    """

    model_type = "semantic_highlighter"

    def __init__(self, base_model_name="BAAI/bge-reranker-base", **kwargs):
        # All remaining options are standard PretrainedConfig fields.
        super().__init__(**kwargs)
        # Checkpoint identifier of the underlying reranker model.
        self.base_model_name = base_model_name
|
|
class SemanticHighlighter(PreTrainedModel):
    """Arabic/English semantic highlighter.

    Wraps a cross-encoder sequence-classification model that scores
    (question, sentence) pairs and returns the sentences of a context
    whose relevance probability exceeds a threshold.
    """

    config_class = SemanticHighlighterConfig

    def __init__(self, config):
        super().__init__(config)
        # Prefer the checkpoint path HF sets on save/load (_name_or_path);
        # fall back to the configured base model so that building the model
        # from a fresh SemanticHighlighterConfig() also works.  The original
        # code read only config._name_or_path and ignored base_model_name.
        model_name = getattr(config, "_name_or_path", None) or config.base_model_name
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=1,  # single relevance logit per (question, sentence) pair
            ignore_mismatched_sizes=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Target device for inference; applied lazily in process().
        self.device_type = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Delegate to the wrapped cross-encoder (keeps HF Trainer compatibility)."""
        return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    def _split_sentences(self, text, language="ar"):
        """Split *text* into sentences on terminal punctuation and newlines.

        Arabic mode additionally splits on the Arabic question mark (؟).
        Fragments of 5 characters or fewer are discarded as noise.
        """
        if language == "ar":
            pattern = r'[.؟!。\n]'
        else:
            pattern = r'[.?!\n]'
        stripped = (s.strip() for s in re.split(pattern, text))
        return [s for s in stripped if len(s) > 5]

    def _score_sentence(self, question, sentence):
        """Return the relevance probability (sigmoid of the logit) of one pair."""
        inputs = self.tokenizer(
            question, sentence,
            truncation=True,
            max_length=256,
            # Dynamic padding: a single pair needs no padding at all, and the
            # attention mask makes the score identical to max_length padding.
            padding=True,
            return_tensors='pt',
        ).to(self.device_type)

        with torch.no_grad():
            logits = self.model(**inputs).logits
        # Sigmoid maps the single logit to a [0, 1] relevance probability.
        return torch.sigmoid(logits.squeeze()).item()

    def process(self, question, context, threshold=0.5, language="auto", return_sentence_metrics=False):
        """
        Process question and context to highlight relevant sentences.

        Args:
            question: The query/question string
            context: The context text to search for relevant sentences
            threshold: Minimum probability to consider a sentence relevant (default: 0.5)
            language: Language of the text ("ar", "en", or "auto" for detection)
            return_sentence_metrics: If True, include sentence probabilities in output

        Returns:
            dict with keys:
            - highlighted_sentences: List of sentences above threshold
            - all_sentences: All sentences in the context
            - sentence_probabilities: (if return_sentence_metrics=True) probability scores
        """
        self.model.to(self.device_type)
        self.model.eval()

        if language == "auto":
            # Heuristic: >30% characters in the Arabic Unicode block => Arabic.
            arabic_chars = len(re.findall(r'[\u0600-\u06FF]', context))
            language = "ar" if arabic_chars > len(context) * 0.3 else "en"

        sentences = self._split_sentences(context, language)

        probabilities = []
        highlighted = []
        for sentence in sentences:
            prob = self._score_sentence(question, sentence)
            probabilities.append(prob)
            if prob >= threshold:
                highlighted.append(sentence)

        result = {
            "highlighted_sentences": highlighted,
            "all_sentences": sentences,
        }
        if return_sentence_metrics:
            result["sentence_probabilities"] = probabilities
        return result
|
|