""" Contains a simple baseline for the QA system. """ import spacy from typing import Dict, Optional, List from src.config.model_configs import SentenceEmbeddingModelConfig from src.models.base_qa_model import QAModel from sentence_transformers import SentenceTransformer, util from src.etl.types import Prediction, QAExample class SentenceEmbeddingQAModel(QAModel): """ Minimal embedding-based baseline: picks the single best matching sentence from the context as the response. Uses sentence-transformers (https://sbert.net/) as embedding-based representations of each of the context sentences as well as the question itself. The sentence associated with the highest cosine similarity score against the question is returned as the response. """ def __init__(self, config: SentenceEmbeddingModelConfig) -> None: super().__init__() assert isinstance( config, SentenceEmbeddingModelConfig ), "Incompatible configuration object." self.config = config self._st_model = SentenceTransformer( model_name_or_path=self.config.sentence_model_name, device=self.config.device, ) self._nlp = spacy.load("en_core_web_sm") def train( self, train_examples: Optional[Dict[str, QAExample]] = None, val_examples: Optional[Dict[str, QAExample]] = None, ) -> None: """ Nothing being explicitly trained for this model. Preserved for API consistency with super-class. """ return def predict(self, examples: Dict[str, QAExample]) -> Dict[str, Prediction]: assert isinstance(examples, dict), "Incompatible input examples type." predictions: Dict[str, Prediction] = {} for qid, example in examples.items(): sentences = self._split_sentences(example.context) if not sentences: predictions[qid] = Prediction.null(question_id=qid) continue q_emb = self._st_model.encode( example.question, convert_to_tensor=True, normalize_embeddings=True ) s_emb = self._st_model.encode( sentences, convert_to_tensor=True, normalize_embeddings=True ) scores = util.cos_sim(q_emb, s_emb).squeeze(0) top_index = int(scores.argmax().item()) best_sentence = sentences[top_index] best_score = float(scores[top_index]) if best_score < self.config.no_answer_threshold: predictions[qid] = Prediction.null(question_id=qid) else: predictions[qid] = Prediction( question_id=qid, predicted_answer=best_sentence, confidence=best_score, is_impossible=False, ) return predictions def _split_sentences(self, text: str) -> List[str]: """spacy-based sentence segmentation""" text = (text or "").strip() if not text: return [] doc = self._nlp(text) return [s.text.strip() for s in doc.sents if s.text.strip()]