Spaces:
Sleeping
Sleeping
| import math | |
| import re | |
| import dill | |
| from collections import Counter | |
class SpamNaiveBayes:
    """Multinomial Naive Bayes spam classifier with additive smoothing.

    A label equal to ``1`` means spam; every other label is counted as ham.
    All fitted state is kept in plain Python containers and floats.
    """

    def __init__(self, alpha=1):
        """Create an untrained model.

        Args:
            alpha: Additive smoothing constant added to every token count.
                It must be positive by the time ``train`` is called.
        """
        self.alpha = alpha
        self.vocab = set()   # becomes a sorted list of tokens after train()
        self.log_spam = {}   # token -> log P(token | spam)
        self.log_ham = {}    # token -> log P(token | ham)
        self.P_spam = 0      # prior P(spam), fraction of spam documents
        self.P_ham = 0       # prior P(ham)
        self.unk_spam = 0    # log-probability used for unseen tokens (spam)
        self.unk_ham = 0     # log-probability used for unseen tokens (ham)

    def tokenize(self, text):
        """Split *text* into lowercase word tokens plus ``!``, ``?`` and ``.``.

        The input is passed through ``str()`` first, so non-string values
        (e.g. ``None`` or numbers) are tokenized by their string form.
        """
        return re.findall(r"\w+|[!?.]", str(text).lower())

    def train(self, texts, labels):
        """Fit the model from scratch on labelled messages.

        Every call discards the previously fitted vocabulary, counts and
        priors, so the method can safely be called more than once.

        Args:
            texts: Iterable of messages (any iterable, including generators).
            labels: Iterable of labels aligned with *texts*; ``1`` is spam,
                anything else is ham.

        Raises:
            ValueError: If ``alpha`` is not positive, if *texts* and *labels*
                differ in length, if there is no training data, or if the
                training texts contain no tokens at all.
        """
        if self.alpha <= 0:
            raise ValueError("alpha must be positive")

        # Materialize once: supports generators and lets us compare lengths
        # instead of letting zip() silently drop unmatched items.
        texts = list(texts)
        labels = list(labels)
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels differ in length: "
                f"{len(texts)} != {len(labels)}"
            )
        if not labels:
            raise ValueError("cannot train on an empty dataset")

        # Single pass: tokenize each message once and feed both the
        # vocabulary and the per-class token counts.
        vocab = set()
        wc_spam = Counter()
        wc_ham = Counter()
        spam_docs = 0
        for txt, lab in zip(texts, labels):
            toks = self.tokenize(txt)
            vocab.update(toks)
            if lab == 1:
                spam_docs += 1
                wc_spam.update(toks)
            else:
                wc_ham.update(toks)

        if not vocab:
            raise ValueError("training texts contain no tokens")

        total_docs = len(labels)
        ham_docs = total_docs - spam_docs
        self.vocab = sorted(vocab)

        # Class priors.
        self.P_spam = spam_docs / total_docs
        self.P_ham = ham_docs / total_docs

        # Smoothed denominators: class token total plus alpha per vocab entry.
        V = len(self.vocab)
        total_spam = sum(wc_spam.values()) + self.alpha * V
        total_ham = sum(wc_ham.values()) + self.alpha * V

        self.log_spam = {
            w: math.log((wc_spam[w] + self.alpha) / total_spam)
            for w in self.vocab
        }
        self.log_ham = {
            w: math.log((wc_ham[w] + self.alpha) / total_ham)
            for w in self.vocab
        }

        # Out-of-vocabulary tokens are scored as if they had a zero count.
        self.unk_spam = math.log(self.alpha / total_spam)
        self.unk_ham = math.log(self.alpha / total_ham)
        print("Training Complete.")

    def predict(self, text):
        """Classify *text*; return ``1`` for spam and ``0`` for ham.

        Scores are summed log-probabilities. Ties resolve to ham because
        spam is returned only when its score is strictly greater.
        """
        toks = self.tokenize(text)
        # The small epsilon keeps log() defined when a class prior is 0.
        s_spam = math.log(self.P_spam + 1e-12)
        s_ham = math.log(self.P_ham + 1e-12)
        for t in toks:
            s_spam += self.log_spam.get(t, self.unk_spam)
            s_ham += self.log_ham.get(t, self.unk_ham)
        return 1 if s_spam > s_ham else 0
if __name__ == "__main__":
    # Imported here so that merely importing this module does not
    # require the `datasets` package.
    from datasets import load_dataset

    print("Loading data...")
    ds = load_dataset("mshenoda/spam-messages")

    # Collect messages and binary labels (1 = spam, 0 = ham) in one pass
    # over the training split.
    texts = []
    labels = []
    for row in ds['train']:
        texts.append(row['text'])
        raw_label = row['label']
        # String labels are mapped by name; anything else is cast to int.
        binary = (
            (1 if raw_label.lower() in ['spam', '1'] else 0)
            if isinstance(raw_label, str)
            else int(raw_label)
        )
        labels.append(binary)

    print("Training clean model...")
    model = SpamNaiveBayes()
    model.train(texts, labels)

    # Serialize the fitted model with dill.
    with open("model_nb_clean.pkl", "wb") as f:
        dill.dump(model, f)
    print("✅ Success! 'model_nb_clean.pkl' created. Upload this file to Hugging Face.")