# Copyright (c) Guangsheng Bao.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
from types import SimpleNamespace

import numpy as np
from scipy.special import expit
from scipy.stats import norm

from utils import load_json


# Considering balanced classification that p(D0) equals to p(D1), we have
#   p(D1|x) = p(x|D1) / (p(x|D1) + p(x|D0))
def compute_prob_norm(x, mu0, sigma0, mu1, sigma1):
    """Return p(D1|x) for two Gaussian class-conditionals with equal priors.

    Args:
        x: Criterion value(s) to classify (scalar or array-like).
        mu0, sigma0: Mean and standard deviation of p(x|D0).
        mu1, sigma1: Mean and standard deviation of p(x|D1).

    Returns:
        The posterior probability of class D1, with the same shape as ``x``.

    The ratio p1 / (p0 + p1) equals sigmoid(log p1 - log p0). Evaluating it
    in log space avoids the 0/0 = nan that the direct density ratio produces
    when ``x`` lies far from both means and both densities underflow to 0.
    """
    log_pdf0 = norm.logpdf(x, loc=mu0, scale=sigma0)
    log_pdf1 = norm.logpdf(x, loc=mu1, scale=sigma1)
    return expit(log_pdf1 - log_pdf0)


class DetectorBase:
    """Base class for detectors that score a text and turn the score into p(D1|text).

    Subclasses implement ``compute_crit``. This class loads the JSON config and
    maps the criterion to a probability with the Gaussian parameters stored in
    ``config.classifier`` (keys ``mu0``, ``sigma0``, ``mu1``, ``sigma1``).
    """

    def __init__(self, config_name):
        self.config_name = config_name
        self.config = self.load_config(config_name)

    def load_config(self, config_name):
        """Load ``./configs/{config_name}.json`` as a SimpleNamespace.

        Top-level string values of the exact form ``${NAME}`` are replaced by
        the environment variable ``NAME``. If that variable is unset,
        ``os.getenv`` yields ``None`` and that ``None`` is stored as-is.
        The path is resolved relative to the current working directory.
        """
        config = load_json(f'./configs/{config_name}.json')
        # Only values are replaced (no keys added or removed), so updating the
        # dict while iterating over it is safe.
        for key, val in config.items():
            if isinstance(val, str) and val.startswith('${') and val.endswith('}'):
                config[key] = os.getenv(val[2:-1])
        return SimpleNamespace(**config)

    def compute_crit(self, text):
        """Return ``(crit, ntoken)`` for ``text``; must be implemented by subclasses."""
        raise NotImplementedError

    def compute_prob(self, text):
        """Score ``text`` and convert the criterion into p(D1|text).

        Returns:
            ``(prob, crit, ntoken)``, where ``crit`` and ``ntoken`` come from
            ``compute_crit`` and ``prob`` from ``compute_prob_norm``.
        """
        crit, ntoken = self.compute_crit(text)
        clf = self.config.classifier
        prob = compute_prob_norm(crit, clf['mu0'], clf['sigma0'], clf['mu1'], clf['sigma1'])
        return prob, crit, ntoken

    def __str__(self):
        return self.config_name


# Detector instances already built by get_detector, keyed by detector name.
CACHE_DETECTORS = {}


def get_detector(name):
    """Return the detector registered under ``name``, creating and caching it on first use.

    Raises:
        NotImplementedError: if ``name`` is not a registered detector.
    """
    # Function-level import; presumably avoids a circular import with the
    # glimpse module (verify before hoisting it to module level).
    from glimpse import Glimpse
    name_detectors = {
        'glimpse': ('glimpse', Glimpse),
    }
    # CACHE_DETECTORS is only mutated, never rebound, so no `global` is needed.
    if name in CACHE_DETECTORS:
        return CACHE_DETECTORS[name]
    if name not in name_detectors:
        raise NotImplementedError(f'Unknown detector: {name!r}')
    config_name, detector_class = name_detectors[name]
    detector = detector_class(config_name)
    CACHE_DETECTORS[name] = detector
    return detector