# spelling-error-correction / ngram_model.py
# Uploaded by hatakekksheeshh via huggingface_hub (commit e8f701d, verified)
import numpy as np
from collections import defaultdict
class NgramLanguageModel:
    """Add-one (Laplace) smoothed n-gram language model.

    Counts n-grams over a corpus of pre-tokenized sentences and exposes
    smoothed probabilities, sentence (log-)probabilities and perplexity.
    Sentences are padded with n-1 START markers and one END marker.
    """

    def __init__(self, n, sentences, vocabulary):
        """Build the model.

        Args:
            n: n-gram order (1 = unigram, 2 = bigram, ...).
            sentences: iterable of sentences, each a list of word tokens.
            vocabulary: set of word types; START/END markers are added to
                an internal copy (the caller's set is left unmodified).
        """
        # Sentence-boundary markers used for padding.
        self.START = '<s>'
        self.END = '</s>'
        # Copy so we don't mutate the caller's set (the original added the
        # boundary markers to the argument in place — a shared-state bug).
        self.vocabulary = set(vocabulary)
        self.vocabulary.add(self.START)
        self.vocabulary.add(self.END)
        self.n = n
        self.sentences = sentences
        self.vocab_size = len(self.vocabulary)
        # ngram_counts: n-gram tuple -> frequency.
        # context_counts: (n-1)-gram tuple -> frequency (unused when n == 1).
        self.ngram_counts = defaultdict(int)
        self.context_counts = defaultdict(int)
        # Total token count (incl. END markers); set by build_model and used
        # as the unigram denominator.
        self.total_words = 0
        self.build_model()

    def build_model(self):
        """Count all n-grams (and, for n > 1, their contexts) in the corpus."""
        for sentence in self.sentences:
            # n-1 START pads give the first real word a full context;
            # a single END marker lets the model score sentence termination.
            padded_sentence = [self.START] * (self.n - 1) + sentence + [self.END]
            for i in range(len(padded_sentence) - self.n + 1):
                ngram = tuple(padded_sentence[i:i + self.n])
                self.ngram_counts[ngram] += 1
                if self.n > 1:
                    self.context_counts[ngram[:-1]] += 1
        # Precompute the unigram denominator once, instead of re-summing all
        # counts on every get_probability call (was O(#unique n-grams) per lookup).
        self.total_words = sum(self.ngram_counts.values())
        print(f"{self.n}-gram model built!")
        print(f"Unique {self.n}-grams: {len(self.ngram_counts):,}")

    def get_probability(self, ngram):
        """Return the add-one smoothed probability of `ngram`.

        Args:
            ngram: sequence of n tokens (list or tuple).

        Returns:
            float probability; never zero thanks to Laplace smoothing.
        """
        ngram = tuple(ngram)
        count = self.ngram_counts.get(ngram, 0)
        if self.n == 1:
            # Unigram: P(w) = (count(w) + 1) / (total_words + V)
            return (count + 1) / (self.total_words + self.vocab_size)
        # N-gram: P(w_n | context) = (count(context, w_n) + 1) / (count(context) + V)
        context_count = self.context_counts.get(ngram[:-1], 0)
        return (count + 1) / (context_count + self.vocab_size)

    def get_sentence_probability(self, sentence):
        """Return (probability, log2-probability) of a tokenized sentence.

        Accumulates in log2 space to avoid underflow during the product;
        note the returned linear probability can still underflow to 0.0
        for very long sentences.
        """
        padded_sentence = [self.START] * (self.n - 1) + sentence + [self.END]
        log_prob = 0.0
        for i in range(len(padded_sentence) - self.n + 1):
            log_prob += np.log2(self.get_probability(padded_sentence[i:i + self.n]))
        return 2 ** log_prob, log_prob

    def get_perplexity(self, sentence):
        """Return the per-token perplexity of a tokenized sentence.

        N = len(sentence) + 1 counts the END marker as a predicted token,
        matching the number of n-grams scored in get_sentence_probability.
        """
        _, log_prob = self.get_sentence_probability(sentence)
        N = len(sentence) + 1
        return 2 ** (-log_prob / N)