File size: 1,434 Bytes
6766ca8
 
c0210b5
5ba6210
ec19c57
 
6766ca8
0caea66
 
 
 
6766ca8
0caea66
6766ca8
 
ec19c57
 
 
 
6766ca8
ec19c57
 
 
 
 
 
0caea66
 
 
 
6766ca8
 
 
0caea66
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from typing import Dict, Any

import spacy
from environs import Env
from huggingface_hub import hf_hub_download
from joblib import load

from src.models.bag_of_words.model import BagOfWordsModelContainer
from src.extract.core import BagOfWordsExtractor
from src.format.core import BagOfWordsFormatter
from src.predict.core import RelevancePredictor

# spaCy English pipeline "en_core_web_trf" (largest, slowest, most accurate of
# the en_core_web_* models), loaded once at import time with the dependency
# parser component disabled.
# NOTE(review): nothing in the visible code of this module references
# SPACY_MODEL; presumably other code imports it — confirm before removing,
# since loading it makes importing this module slow.
SPACY_MODEL = spacy.load(
    "en_core_web_trf",
    disable=["parser"],
)


class EndpointHandler:
    """Inference handler that runs an HTML page through a relevance model.

    Pipeline per request: raw HTML -> bag of words (``BagOfWordsExtractor``)
    -> model input matrix (``BagOfWordsFormatter``) -> prediction
    (``RelevancePredictor``). All three stages are configured from one
    ``BagOfWordsModelContainer`` downloaded from the Hugging Face Hub when
    the handler is constructed.
    """

    # Hub repository holding the serialized model containers; the subfolder
    # inside it is selected by the MODEL_PATH environment variable.
    MODEL_REPO_ID = "PDAP/url-relevance-models"

    def __init__(self, path: str = ""):
        """Download the model container and build the prediction pipeline.

        Args:
            path: Path supplied by the serving runtime. It is not used here;
                the model is fetched from ``MODEL_REPO_ID`` instead. Defaults
                to ``""`` so the handler can be built without it.

        Raises:
            environs.EnvError: presumably, when ``MODEL_PATH`` is unset
                (``env.str`` is called without a default) — confirm.
        """
        env = Env()
        # Pull variables from a local .env file (if any) into the environment.
        env.read_env()

        model_path = env.str("MODEL_PATH")
        downloaded_model_path = hf_hub_download(
            repo_id=self.MODEL_REPO_ID,
            subfolder=model_path,
            filename="model.joblib"
        )
        # NOTE(security): joblib.load unpickles arbitrary objects. This is only
        # acceptable because the file comes from a fixed, trusted Hub repo.
        self.model_container: BagOfWordsModelContainer = load(downloaded_model_path)
        self.extractor = BagOfWordsExtractor(self.model_container.permitted_terms)
        self.formatter = BagOfWordsFormatter(self.model_container.term_label_encoder)
        self.predictor = RelevancePredictor(self.model_container.model)

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Score a single request payload.

        Args:
            inputs: Request payload; the page HTML is read from its
                ``"inputs"`` key.

        Returns:
            The predictor's output serialized via ``model_dump(mode="json")``
            (presumably a pydantic model), i.e. a JSON-compatible dict. Its
            value types depend on the output model, hence ``Any``.

        Raises:
            KeyError: if ``inputs`` has no ``"inputs"`` key.
        """
        try:
            html = inputs["inputs"]
        except KeyError:
            # Replace the bare KeyError('inputs') with an actionable message.
            raise KeyError(
                "request payload is missing the 'inputs' key (expected the page HTML)"
            ) from None
        bag_of_words = self.extractor.extract_bag_of_words(html)
        csr = self.formatter.format_bag_of_words(bag_of_words)
        output = self.predictor.predict_relevance(csr)
        return output.model_dump(mode="json")