File size: 1,786 Bytes
6766ca8
 
c0210b5
5ba6210
ec19c57
 
6766ca8
ee1c4f9
 
 
 
 
0caea66
ee1c4f9
6766ca8
0caea66
6766ca8
 
ec19c57
 
 
 
6766ca8
ec19c57
ee1c4f9
ec19c57
 
 
 
 
0caea66
 
 
 
6766ca8
 
 
0caea66
 
 
ee1c4f9
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from typing import Dict, Any

import spacy
from environs import Env
from huggingface_hub import hf_hub_download
from joblib import load

from src.dtos.output.basic import BasicOutput

from src.format import format_model_name_from_path
from src.models.bag_of_words.extractor import BagOfWordsExtractor
from src.models.bag_of_words.formatter import BagOfWordsFormatter
from src.models.bag_of_words.model import BagOfWordsModelContainer
from src.models.bag_of_words.predictor import RelevancePredictor

# spaCy 'en_core_web_trf' pipeline, loaded once at module import time, so
# importing this module pays the (large) model-load cost up front. The 'parser'
# pipeline component is disabled.
# NOTE(review): SPACY_MODEL is not referenced anywhere else in this file —
# presumably it is imported by another module or loaded here to warm spaCy up
# at endpoint start-up; confirm before removing or making it lazy.
SPACY_MODEL = spacy.load('en_core_web_trf', disable=['parser'])  # Largest, slowest, most accurate model


class EndpointHandler:
    """Endpoint that scores an HTML document with a bag-of-words relevance model.

    On construction, a serialized ``BagOfWordsModelContainer`` is downloaded
    from the Hugging Face Hub and wired into three stages: an extractor
    (HTML -> bag of words), a formatter (bag of words -> CSR matrix) and a
    predictor (CSR matrix -> relevance output). Each call runs one document
    through those stages and returns a JSON-mode dump of ``BasicOutput``.
    """

    # Hub repository the model artifact is fetched from when the optional
    # MODEL_REPO_ID environment variable is not set.
    DEFAULT_REPO_ID = "PDAP/url-relevance-models"

    def __init__(self, path: str) -> None:
        """Download the configured model and build the inference pipeline.

        Args:
            path: Accepted for interface compatibility with the caller; it is
                not used — the model is always fetched from the Hub.

        Configuration is read via ``Env`` (process environment plus a ``.env``
        file loaded by ``Env.read_env``):
            MODEL_PATH: required (no default is given to ``env.str``); the
                subfolder of the Hub repo that holds ``model.joblib``. Also used
                to derive ``self.model_name``.
            MODEL_REPO_ID: optional; Hub repo id, defaults to
                ``DEFAULT_REPO_ID``.
        """
        env = Env()
        env.read_env()

        model_path = env.str("MODEL_PATH")
        repo_id = env.str("MODEL_REPO_ID", self.DEFAULT_REPO_ID)
        self.model_name = format_model_name_from_path(model_path)
        downloaded_model_path = hf_hub_download(
            repo_id=repo_id,
            subfolder=model_path,
            filename="model.joblib",
        )
        # SECURITY: joblib.load unpickles the file, which can execute arbitrary
        # code. Only point MODEL_REPO_ID / MODEL_PATH at repositories you trust.
        self.model_container: BagOfWordsModelContainer = load(downloaded_model_path)
        self.extractor = BagOfWordsExtractor(self.model_container.permitted_terms)
        self.formatter = BagOfWordsFormatter(self.model_container.term_label_encoder)
        self.predictor = RelevancePredictor(self.model_container.model)

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Score a single HTML document for relevance.

        Args:
            inputs: Request payload; the HTML document is read from
                ``inputs["inputs"]``.

        Returns:
            ``BasicOutput`` dumped with ``model_dump(mode="json")``:
            ``annotation`` comes from the predictor's ``is_relevant``,
            ``confidence`` from its ``probability``, and ``model`` is the name
            derived from ``MODEL_PATH``. The values are not necessarily
            strings, hence ``Dict[str, Any]`` rather than ``Dict[str, str]``.

        Raises:
            KeyError: if ``inputs`` has no ``"inputs"`` key.
        """
        html = inputs["inputs"]
        bag_of_words = self.extractor.extract_bag_of_words(html)
        csr = self.formatter.format_bag_of_words(bag_of_words)
        output = self.predictor.predict_relevance(csr)
        return BasicOutput(
            annotation=output.is_relevant,
            confidence=output.probability,
            model=self.model_name,
        ).model_dump(mode="json")