"""
Contains a simple baseline for the QA system.
"""
import spacy
from typing import Dict, Optional, List
from src.config.model_configs import SentenceEmbeddingModelConfig
from src.models.base_qa_model import QAModel
from sentence_transformers import SentenceTransformer, util
from src.etl.types import Prediction, QAExample
class SentenceEmbeddingQAModel(QAModel):
    """
    Minimal embedding-based baseline: picks the single best matching sentence from the
    context as the response. Uses sentence-transformers (https://sbert.net/) as
    embedding-based representations of each of the context sentences as well as
    the question itself. The sentence associated with the highest cosine similarity score
    against the question is returned as the response.
    """

    def __init__(self, config: SentenceEmbeddingModelConfig) -> None:
        """
        Load the sentence-transformer encoder and the spaCy sentence splitter.

        Args:
            config: Configuration providing ``sentence_model_name``, ``device``
                and ``no_answer_threshold``.

        Raises:
            TypeError: If ``config`` is not a ``SentenceEmbeddingModelConfig``.
        """
        super().__init__()
        # Raise instead of assert so the validation survives `python -O`
        # (asserts are stripped when optimizations are enabled).
        if not isinstance(config, SentenceEmbeddingModelConfig):
            raise TypeError("Incompatible configuration object.")
        self.config = config
        self._st_model = SentenceTransformer(
            model_name_or_path=self.config.sentence_model_name,
            device=self.config.device,
        )
        # Small English pipeline, used only for sentence segmentation.
        self._nlp = spacy.load("en_core_web_sm")

    def train(
        self,
        train_examples: Optional[Dict[str, QAExample]] = None,
        val_examples: Optional[Dict[str, QAExample]] = None,
    ) -> None:
        """
        Nothing being explicitly trained for this model. Preserved for API consistency with super-class.
        """
        return

    def predict(self, examples: Dict[str, QAExample]) -> Dict[str, Prediction]:
        """
        Predict one best-matching context sentence per example.

        Args:
            examples: Mapping of question id -> ``QAExample``.

        Returns:
            Mapping of question id -> ``Prediction``. A null prediction is
            produced when the context yields no sentences, or when the best
            cosine similarity falls below ``config.no_answer_threshold``.

        Raises:
            TypeError: If ``examples`` is not a dict.
        """
        # Raise instead of assert so the validation survives `python -O`.
        if not isinstance(examples, dict):
            raise TypeError("Incompatible input examples type.")
        predictions: Dict[str, Prediction] = {}
        for qid, example in examples.items():
            sentences = self._split_sentences(example.context)
            if not sentences:
                # Empty / whitespace-only context: nothing to rank.
                predictions[qid] = Prediction.null(question_id=qid)
                continue
            # normalize_embeddings=True makes cosine similarity a plain dot
            # product over unit vectors.
            q_emb = self._st_model.encode(
                example.question, convert_to_tensor=True, normalize_embeddings=True
            )
            s_emb = self._st_model.encode(
                sentences, convert_to_tensor=True, normalize_embeddings=True
            )
            # Shape after squeeze: (num_sentences,) similarity scores.
            scores = util.cos_sim(q_emb, s_emb).squeeze(0)
            top_index = int(scores.argmax().item())
            best_sentence = sentences[top_index]
            best_score = float(scores[top_index])
            if best_score < self.config.no_answer_threshold:
                # Best match is too weak: treat the question as unanswerable.
                predictions[qid] = Prediction.null(question_id=qid)
            else:
                predictions[qid] = Prediction(
                    question_id=qid,
                    predicted_answer=best_sentence,
                    confidence=best_score,
                    is_impossible=False,
                )
        return predictions

    def _split_sentences(self, text: str) -> List[str]:
        """spaCy-based sentence segmentation; returns non-empty, stripped sentences."""
        text = (text or "").strip()
        if not text:
            return []
        doc = self._nlp(text)
        return [s.text.strip() for s in doc.sents if s.text.strip()]