# NOTE(review): three lines of source-viewer residue appeared here (a file-size
# header, a row of commit hashes, and a line-number gutter). They were not part
# of the program and have been reduced to this comment so the module parses.
from typing import Dict, Any
import spacy
from environs import Env
from huggingface_hub import hf_hub_download
from joblib import load
from src.dtos.output.basic import BasicOutput
from src.format import format_model_name_from_path
from src.models.bag_of_words.extractor import BagOfWordsExtractor
from src.models.bag_of_words.formatter import BagOfWordsFormatter
from src.models.bag_of_words.model import BagOfWordsModelContainer
from src.models.bag_of_words.predictor import RelevancePredictor
# Transformer-based spaCy pipeline, loaded once at import time (this is heavy:
# the model load happens as a module-level side effect). The dependency parser
# component is disabled. NOTE(review): SPACY_MODEL is not referenced anywhere
# in this chunk — presumably consumed by the extractor modules; verify before
# removing or making lazy.
SPACY_MODEL = spacy.load('en_core_web_trf', disable=['parser']) # Largest, slowest, most accurate model
class EndpointHandler:
    """Hugging Face inference-endpoint handler for URL-relevance classification.

    On construction it resolves the model location from the ``MODEL_PATH``
    environment variable, downloads the serialized bag-of-words model from the
    Hub, and wires together the extract -> format -> predict pipeline. Calling
    the instance classifies one HTML document.
    """

    def __init__(self, path: str):
        """Load configuration and fetch the model artifact from the Hub.

        ``path`` is the local checkout path supplied by the endpoint runtime;
        the actual model location comes from the ``MODEL_PATH`` env var.
        """
        environment = Env()
        environment.read_env()
        hub_subfolder = environment.str("MODEL_PATH")
        self.model_name = format_model_name_from_path(hub_subfolder)
        artifact_path = hf_hub_download(
            repo_id="PDAP/url-relevance-models",
            filename="model.joblib",
            subfolder=hub_subfolder,
        )
        # The joblib artifact bundles the estimator plus its vocabulary and
        # label encoder; each pipeline stage takes the piece it needs.
        container: BagOfWordsModelContainer = load(artifact_path)
        self.model_container = container
        self.extractor = BagOfWordsExtractor(container.permitted_terms)
        self.formatter = BagOfWordsFormatter(container.term_label_encoder)
        self.predictor = RelevancePredictor(container.model)

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Classify the relevance of one HTML document.

        ``inputs`` is the request payload; the raw HTML is expected under the
        ``"inputs"`` key. Returns a JSON-serializable dict with the
        annotation, confidence, and model name.
        """
        raw_html = inputs["inputs"]
        word_counts = self.extractor.extract_bag_of_words(raw_html)
        feature_matrix = self.formatter.format_bag_of_words(word_counts)
        prediction = self.predictor.predict_relevance(feature_matrix)
        response = BasicOutput(
            annotation=prediction.is_relevant,
            confidence=prediction.probability,
            model=self.model_name,
        )
        return response.model_dump(mode="json")