File size: 1,531 Bytes
3ccd59f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import pickle
import numpy as np
from collections import Counter
import os
class EndpointHandler:
    """Character n-gram language identifier served as a callable handler.

    The handler loads a pickled dictionary of model parameters once, then
    turns each request's text into an n-gram count vector, applies a linear
    layer followed by a softmax, and returns one score per language label.
    """

    def __init__(self, path=""):
        """Load the model parameters from ``lang_id_model.pkl`` inside *path*.

        Args:
            path: Directory where the model files are located.

        Raises:
            OSError: If the model file cannot be opened.
            KeyError: If the pickled dictionary lacks one of the expected
                keys (``W``, ``feature_map``, ``ngram_length``, ``lang_list``).
        """
        model_file = os.path.join(path, "lang_id_model.pkl")
        # SECURITY NOTE: unpickling executes arbitrary code embedded in the
        # file. Only point ``path`` at model artifacts from a trusted source.
        with open(model_file, "rb") as f:
            self.model_data = pickle.load(f)

        # Weight matrix; it is multiplied with a vector whose element 0 is
        # the bias term and whose remaining elements are n-gram counts.
        self.W = self.model_data["W"]
        # Maps an n-gram string to its index in the count vector.
        self.feature_map = self.model_data["feature_map"]
        # Number of characters per n-gram.
        self.ngram_length = self.model_data["ngram_length"]
        # Labels; entry i is reported alongside score i.
        self.lang_list = self.model_data["lang_list"]

    def __call__(self, data):
        """Score the request text against every known language.

        Args:
            data: A dictionary whose ``"inputs"`` key holds the text to
                classify. A bare string is also accepted. The dictionary is
                not modified.

        Returns:
            A list with one ``{"label": str, "score": float}`` dictionary per
            entry of ``lang_list``, in ``lang_list`` order. Scores are
            softmax probabilities.

        Raises:
            TypeError: If the resolved input is not a string.
        """
        # Read without popping so the caller's payload is left intact.
        inputs = data.get("inputs", data) if isinstance(data, dict) else data
        if not isinstance(inputs, str):
            # Sequences of strings would otherwise be sliced element-wise and
            # produce meaningless scores instead of an error.
            raise TypeError(
                "EndpointHandler expects 'inputs' to be a string, got "
                + type(inputs).__name__
            )

        # 1. Vectorize text: zipping the text with its shifted copies yields
        #    every run of ``ngram_length`` consecutive characters.
        shifted = [inputs[i:] for i in range(self.ngram_length)]
        counts = Counter("".join(chars) for chars in zip(*shifted))

        x = np.zeros(len(self.feature_map))
        for ngram, count in counts.items():
            idx = self.feature_map.get(ngram)
            if idx is not None:  # n-grams unknown to the model are ignored
                x[idx] = count

        # 2. Prepend the constant bias feature and compute the raw scores.
        x_aug = np.insert(x, 0, 1)
        z = self.W.dot(x_aug)

        # 3. Numerically stable softmax: subtracting the maximum keeps the
        #    exponentials from overflowing; exponentiate only once.
        exp_z = np.exp(z - np.max(z))
        probs = exp_z / exp_z.sum()

        # 4. Return a list of label/score pairs.
        return [
            {"label": label, "score": float(probs[i])}
            for i, label in enumerate(self.lang_list)
        ]