File size: 3,172 Bytes
461f64f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""
Contains a simple baseline for the QA system.
"""

import spacy
from typing import Dict, Optional, List
from src.config.model_configs import SentenceEmbeddingModelConfig
from src.models.base_qa_model import QAModel
from sentence_transformers import SentenceTransformer, util
from src.etl.types import Prediction, QAExample


class SentenceEmbeddingQAModel(QAModel):
    """
    Minimal embedding-based baseline: picks the single best matching sentence from the
    context as the response. Uses sentence-transformers (https://sbert.net/) as
    embedding-based representations of each of the context sentences as well as
    the question itself. The sentence associated with the highest cosine similarity score
    against the question is returned as the response.
    """

    def __init__(self, config: SentenceEmbeddingModelConfig) -> None:
        """
        Load the sentence-transformer encoder and the spaCy sentence segmenter.

        Args:
            config: Supplies the sentence-transformer model name, the target
                device, and the ``no_answer_threshold`` used in ``predict``.

        Raises:
            TypeError: If ``config`` is not a ``SentenceEmbeddingModelConfig``.
        """
        super().__init__()
        # Explicit raise instead of `assert`: assertions are stripped under
        # `python -O`, which would silently disable this validation.
        if not isinstance(config, SentenceEmbeddingModelConfig):
            raise TypeError("Incompatible configuration object.")
        self.config = config
        self._st_model = SentenceTransformer(
            model_name_or_path=self.config.sentence_model_name,
            device=self.config.device,
        )
        # Small English pipeline; only its sentence segmentation is used here.
        self._nlp = spacy.load("en_core_web_sm")

    def train(
        self,
        train_examples: Optional[Dict[str, QAExample]] = None,
        val_examples: Optional[Dict[str, QAExample]] = None,
    ) -> None:
        """
        Nothing being explicitly trained for this model. Preserved for API consistency with super-class.
        """
        return

    def predict(self, examples: Dict[str, QAExample]) -> Dict[str, Prediction]:
        """
        Answer each example with its best-matching context sentence.

        For every example, the context is segmented into sentences, the
        question and sentences are embedded (L2-normalized), and the sentence
        with the highest cosine similarity to the question is selected. If the
        context yields no sentences, or the best score falls below
        ``config.no_answer_threshold``, a null prediction is emitted instead.

        Args:
            examples: Mapping of question id -> ``QAExample``.

        Returns:
            Mapping of question id -> ``Prediction`` (possibly null).

        Raises:
            TypeError: If ``examples`` is not a dict.
        """
        # Explicit raise instead of `assert` (stripped under `python -O`).
        if not isinstance(examples, dict):
            raise TypeError("Incompatible input examples type.")

        predictions: Dict[str, Prediction] = {}
        for qid, example in examples.items():
            sentences = self._split_sentences(example.context)
            if not sentences:
                # No usable context at all -> unanswerable.
                predictions[qid] = Prediction.null(question_id=qid)
                continue

            # With normalized embeddings, cosine similarity reduces to a dot
            # product; scores are comparable across examples.
            q_emb = self._st_model.encode(
                example.question, convert_to_tensor=True, normalize_embeddings=True
            )
            s_emb = self._st_model.encode(
                sentences, convert_to_tensor=True, normalize_embeddings=True
            )
            scores = util.cos_sim(q_emb, s_emb).squeeze(0)
            top_index = int(scores.argmax().item())
            best_sentence = sentences[top_index]
            best_score = float(scores[top_index])

            if best_score < self.config.no_answer_threshold:
                # Low-confidence match is treated as "no answer".
                predictions[qid] = Prediction.null(question_id=qid)
            else:
                predictions[qid] = Prediction(
                    question_id=qid,
                    predicted_answer=best_sentence,
                    confidence=best_score,
                    is_impossible=False,
                )
        return predictions

    def _split_sentences(self, text: str) -> List[str]:
        """
        spaCy-based sentence segmentation.

        Returns a list of stripped, non-empty sentences; an empty list when
        ``text`` is None, empty, or whitespace-only (the spaCy pipeline is not
        invoked in that case).
        """
        text = (text or "").strip()
        if not text:
            return []
        doc = self._nlp(text)
        return [s.text.strip() for s in doc.sents if s.text.strip()]