# glimpse / detector_base.py
# Guangsheng Bao
# update classifier
# 382f974
# Copyright (c) Guangsheng Bao.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
from utils import load_json
from types import SimpleNamespace
import numpy as np
from scipy.stats import norm
def compute_prob_norm(x, mu0, sigma0, mu1, sigma1):
    """Return p(D1 | x) under two Gaussian class-conditional densities.

    Assuming balanced classes, i.e. p(D0) = p(D1), Bayes' rule gives

        p(D1 | x) = p(x | D1) / (p(x | D1) + p(x | D0))

    with p(x | Di) = N(x; mu_i, sigma_i). The ratio is evaluated in log space
    as sigmoid(log p(x | D1) - log p(x | D0)). This is algebraically identical,
    but it stays finite when ``x`` lies so far from both means that both
    densities underflow to 0. In that case the direct ratio would be 0/0 = nan.

    Args:
        x: Scalar or array-like criterion value(s).
        mu0, sigma0: Mean and standard deviation of the D0 density.
        mu1, sigma1: Mean and standard deviation of the D1 density.

    Returns:
        Probability in [0, 1], or an array of them broadcast like ``x``.
    """
    from scipy.special import expit

    log_pdf0 = norm.logpdf(x, loc=mu0, scale=sigma0)
    log_pdf1 = norm.logpdf(x, loc=mu1, scale=sigma1)
    # sigmoid(a - b) == e^a / (e^a + e^b), computed without overflow/underflow.
    return expit(log_pdf1 - log_pdf0)
class DetectorBase:
    """Base class for detectors configured from ``./configs/<name>.json``.

    Subclasses implement :meth:`compute_crit`. :meth:`compute_prob` then maps
    the resulting criterion to a probability, using the Gaussian parameters
    stored under the config's ``classifier`` entry.
    """

    def __init__(self, config_name):
        """Load the JSON config named ``config_name``.

        Args:
            config_name: Base name of the file under ``./configs/``, without
                the ``.json`` extension.
        """
        self.config_name = config_name
        self.config = self.load_config(config_name)

    @staticmethod
    def _resolve_env_vars(config):
        """Replace top-level ``'${VAR}'`` string values with ``os.getenv('VAR')``.

        Only top-level values are inspected; nested dicts/lists are left as-is.
        An unset variable still resolves to ``None``, but a warning is emitted.
        This makes the misconfiguration visible at load time rather than as an
        obscure failure on ``None`` later.

        Args:
            config: Dict parsed from the JSON config; updated in place.

        Returns:
            The same dict, for convenience.
        """
        import warnings

        # Reassigning existing keys does not change the dict's size, so it is
        # safe while iterating over items().
        for key, val in config.items():
            if isinstance(val, str) and val.startswith('${') and val.endswith('}'):
                var_name = val[2:-1]
                resolved = os.getenv(var_name)
                if resolved is None:
                    warnings.warn(
                        f"Config entry '{key}' references environment variable "
                        f"'{var_name}', which is not set; using None."
                    )
                config[key] = resolved
        return config

    def load_config(self, config_name):
        """Read ``./configs/<config_name>.json`` and expose it as attributes.

        The path is relative, so it is presumably resolved against the current
        working directory. Top-level ``'${VAR}'`` values are substituted from
        the environment (see :meth:`_resolve_env_vars`).

        Returns:
            A ``SimpleNamespace`` whose attributes are the config's top-level keys.
        """
        config = load_json(f'./configs/{config_name}.json')
        return SimpleNamespace(**self._resolve_env_vars(config))

    def compute_crit(self, text):
        """Return ``(crit, ntoken)`` for ``text``; subclasses must override.

        ``ntoken`` is presumably the number of tokens scored; confirm this
        against the concrete subclass.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def compute_prob(self, text):
        """Score ``text`` and convert its criterion into p(D1 | crit).

        Returns:
            ``(prob, crit, ntoken)``. ``prob`` comes from ``compute_prob_norm``
            using ``mu0``/``sigma0``/``mu1``/``sigma1`` from
            ``self.config.classifier``.
        """
        crit, ntoken = self.compute_crit(text)
        params = self.config.classifier
        prob = compute_prob_norm(
            crit, params['mu0'], params['sigma0'], params['mu1'], params['sigma1']
        )
        return prob, crit, ntoken

    def __str__(self):
        return self.config_name
# Module-level cache of detector instances, keyed by registry name.
CACHE_DETECTORS = {}


def get_detector(name):
    """Return the detector registered under ``name``, creating it on first use.

    Instances are stored in ``CACHE_DETECTORS``, so repeated calls with the
    same name return the same object.

    Args:
        name: Registry key; currently only ``'glimpse'`` is registered.

    Returns:
        The cached or newly constructed detector instance.

    Raises:
        NotImplementedError: if ``name`` is not a registered detector.
    """
    # Serve cache hits first, so a hit never depends on importing the detector
    # implementations. CACHE_DETECTORS is only mutated, never rebound, so no
    # `global` statement is needed.
    if name in CACHE_DETECTORS:
        return CACHE_DETECTORS[name]

    # Imported lazily, presumably to avoid a circular import with glimpse.py.
    from glimpse import Glimpse

    name_detectors = {
        'glimpse': ('glimpse', Glimpse),
    }
    if name not in name_detectors:
        raise NotImplementedError(
            f"Unknown detector {name!r}; available: {sorted(name_detectors)}"
        )
    config_name, detector_class = name_detectors[name]
    detector = detector_class(config_name)
    CACHE_DETECTORS[name] = detector
    return detector