# squad2-qa/src/models/sentence_embedding_model.py
"""
Contains a simple baseline for the QA system.
"""
import spacy
from typing import Dict, Optional, List
from src.config.model_configs import SentenceEmbeddingModelConfig
from src.models.base_qa_model import QAModel
from sentence_transformers import SentenceTransformer, util
from src.etl.types import Prediction, QAExample
class SentenceEmbeddingQAModel(QAModel):
    """
    Embedding-similarity baseline for the QA system.

    The context is segmented into sentences, then both the question and every
    candidate sentence are embedded with a sentence-transformers model
    (https://sbert.net/). The sentence whose embedding has the highest cosine
    similarity to the question embedding is returned verbatim as the answer;
    if that similarity falls below a configured threshold (or the context has
    no sentences), a null "no answer" prediction is emitted instead.
    """

    def __init__(self, config: SentenceEmbeddingModelConfig) -> None:
        """Store the config and load the embedding model plus the spacy sentencizer.

        Args:
            config: Must be a SentenceEmbeddingModelConfig; supplies the
                sentence-transformers model name, device, and the
                no-answer similarity threshold.
        """
        super().__init__()
        assert isinstance(config, SentenceEmbeddingModelConfig), "Incompatible configuration object."
        self.config = config
        # Sentence embedder used for both questions and context sentences.
        self._st_model = SentenceTransformer(
            model_name_or_path=self.config.sentence_model_name,
            device=self.config.device,
        )
        # Small English spacy pipeline, used only for sentence segmentation.
        self._nlp = spacy.load("en_core_web_sm")

    def train(
        self,
        train_examples: Optional[Dict[str, QAExample]] = None,
        val_examples: Optional[Dict[str, QAExample]] = None,
    ) -> None:
        """No-op: this baseline has no trainable parameters.

        Kept so the class satisfies the QAModel training interface.
        """
        return

    def predict(self, examples: Dict[str, QAExample]) -> Dict[str, Prediction]:
        """Answer every example by nearest-sentence retrieval.

        Args:
            examples: Mapping of question id -> QAExample.

        Returns:
            Mapping of question id -> Prediction (null prediction when the
            context yields no sentences or the best similarity is below the
            configured no-answer threshold).
        """
        assert isinstance(examples, dict), "Incompatible input examples type."
        return {
            question_id: self._answer_one(question_id, example)
            for question_id, example in examples.items()
        }

    def _answer_one(self, question_id: str, example: QAExample) -> Prediction:
        """Score every context sentence against the question; pick the best."""
        candidates = self._split_sentences(example.context)
        if not candidates:
            # Context produced no usable sentences — nothing to retrieve.
            return Prediction.null(question_id=question_id)
        question_vec = self._st_model.encode(
            example.question, convert_to_tensor=True, normalize_embeddings=True
        )
        candidate_vecs = self._st_model.encode(
            candidates, convert_to_tensor=True, normalize_embeddings=True
        )
        # Shape (1, n_sentences) -> (n_sentences,); one score per candidate.
        similarities = util.cos_sim(question_vec, candidate_vecs).squeeze(0)
        winner = int(similarities.argmax().item())
        confidence = float(similarities[winner])
        if confidence < self.config.no_answer_threshold:
            # Best match is still too weak: treat as unanswerable.
            return Prediction.null(question_id=question_id)
        return Prediction(
            question_id=question_id,
            predicted_answer=candidates[winner],
            confidence=confidence,
            is_impossible=False,
        )

    def _split_sentences(self, text: str) -> List[str]:
        """Return the non-empty, stripped sentences of *text* via spacy."""
        cleaned = (text or "").strip()
        if not cleaned:
            return []
        sentences: List[str] = []
        for span in self._nlp(cleaned).sents:
            stripped = span.text.strip()
            if stripped:
                sentences.append(stripped)
        return sentences