# nextword-pidgin-api / src/trigram_model.py
# JermaineAI
# Fix API model loading: Copy src directory and update Dockerfile
# ad18db6
"""
Trigram Language Model for Next-Word Prediction.
Implements a statistical trigram model with Laplace (add-one) smoothing
for Nigerian English/Pidgin next-word prediction.
Mathematical Foundation:
P(w_n | w_{n-2}, w_{n-1}) = (C(w_{n-2}, w_{n-1}, w_n) + α) / (C(w_{n-2}, w_{n-1}) + α|V|)
Where:
- C(.) = count of n-gram in training corpus
- α = smoothing parameter (1.0 for Laplace)
- |V| = vocabulary size
"""
from collections import Counter
from typing import List, Tuple, Dict, Optional
import math
class TrigramLM:
    """
    Trigram Language Model with Laplace smoothing.

    Attributes:
        smoothing: Smoothing parameter (α). Default 1.0 for add-one smoothing.
        unigram_counts: Counter for single word frequencies.
        bigram_counts: Counter for word pair frequencies.
        trigram_counts: Counter for word triple frequencies.
        vocab: Set of all unique tokens seen in training, including the
            boundary markers '<s>' and '</s>' if they were present.
    """

    # Sentence-boundary markers that can be excluded from predictions.
    _SPECIAL_TOKENS = ('<s>', '</s>')

    def __init__(self, smoothing: float = 1.0):
        """
        Initialize the trigram model.

        Args:
            smoothing: Laplace smoothing parameter. Higher values provide more
                smoothing for unseen n-grams. Default 1.0 (add-one).

        Raises:
            ValueError: If ``smoothing`` is negative (that would produce
                negative "probabilities").
        """
        if smoothing < 0:
            raise ValueError(
                f"smoothing must be non-negative, got {smoothing!r}"
            )
        self.smoothing = smoothing
        self.unigram_counts: Counter = Counter()
        self.bigram_counts: Counter = Counter()
        self.trigram_counts: Counter = Counter()
        self.vocab: set = set()
        self._total_unigrams: int = 0

    def train(self, sentences: List[List[str]]) -> None:
        """
        Train the model by counting n-grams from tokenized sentences.

        Expects sentences with start/end markers already added:
            ['<s>', '<s>', 'word1', 'word2', ..., '</s>']
        Calling ``train`` again adds to the existing counts.
        A summary of the resulting counts is printed to stdout.

        Args:
            sentences: List of tokenized sentences with boundary markers.
                Each sentence must be a sliceable sequence (e.g. a list).
        """
        for sentence in sentences:
            self.vocab.update(sentence)
            self.unigram_counts.update(sentence)
            self._total_unigrams += len(sentence)
            # zip over shifted slices yields every adjacent pair / triple;
            # sentences that are too short simply contribute nothing.
            self.bigram_counts.update(zip(sentence, sentence[1:]))
            self.trigram_counts.update(
                zip(sentence, sentence[1:], sentence[2:])
            )
        print("Training complete:")
        print(f" Vocabulary size: {len(self.vocab):,}")
        print(f" Unique unigrams: {len(self.unigram_counts):,}")
        print(f" Unique bigrams: {len(self.bigram_counts):,}")
        print(f" Unique trigrams: {len(self.trigram_counts):,}")

    def probability(self, w3: str, w1: str, w2: str) -> float:
        """
        Compute P(w3 | w1, w2) with Laplace smoothing.

        Formula:
            P(w3|w1,w2) = (C(w1,w2,w3) + α) / (C(w1,w2) + α|V|)

        Args:
            w3: The word to predict.
            w1: First context word (two positions before w3).
            w2: Second context word (one position before w3).

        Returns:
            Conditional probability P(w3 | w1, w2); 0.0 when the denominator
            is zero (e.g. empty model, or α = 0 with an unseen context).
        """
        trigram_count = self.trigram_counts.get((w1, w2, w3), 0)
        bigram_count = self.bigram_counts.get((w1, w2), 0)
        vocab_size = len(self.vocab)
        # Laplace smoothing
        numerator = trigram_count + self.smoothing
        denominator = bigram_count + (self.smoothing * vocab_size)
        return numerator / denominator if denominator > 0 else 0.0

    def log_probability(self, w3: str, w1: str, w2: str) -> float:
        """
        Compute log P(w3 | w1, w2) for numerical stability.

        Args:
            w3: The word to predict.
            w1: First context word.
            w2: Second context word.

        Returns:
            Natural-log probability, or -inf when the probability is zero.
        """
        prob = self.probability(w3, w1, w2)
        return math.log(prob) if prob > 0 else float('-inf')

    def _rank_candidates(
        self,
        w1: str,
        w2: str,
        top_k: Optional[int],
        exclude_special: bool,
    ) -> List[Tuple[str, float]]:
        """Score every vocab word as the successor of (w1, w2) and rank them.

        ``top_k=None`` returns the full ranking; ``top_k <= 0`` returns [].
        """
        skip = self._SPECIAL_TOKENS if exclude_special else ()
        scored = [
            (word, self.probability(word, w1, w2))
            for word in self.vocab
            if word not in skip
        ]

        # Laplace smoothing gives every unseen continuation exactly the same
        # probability, so ties are the norm. Break them by corpus frequency,
        # then alphabetically, so the ranking never depends on set iteration
        # order (which varies across processes due to str hash randomisation).
        def rank_key(item: Tuple[str, float]) -> Tuple[float, int, str]:
            word, prob = item
            return (-prob, -self.unigram_counts[word], word)

        if top_k is not None and top_k <= 0:
            return []
        ranked = sorted(scored, key=rank_key)
        return ranked if top_k is None else ranked[:top_k]

    def predict_next_words(
        self,
        context: str,
        top_k: int = 5,
        exclude_special: bool = True
    ) -> List[Tuple[str, float]]:
        """
        Predict the top-k most likely next words given a context.

        Args:
            context: Input text (lower-cased, split on whitespace; the last
                two words are used as context). Shorter contexts are padded
                with '<s>' markers, mirroring the training sentence format.
            top_k: Number of predictions to return (<= 0 returns []).
            exclude_special: If True, exclude <s> and </s> from predictions.

        Returns:
            List of (word, probability) tuples, sorted by probability
            descending; ties are ordered by training frequency (descending),
            then alphabetically.

        NOTE(review): only whitespace splitting is done here, so punctuation
        glued to a word (e.g. 'far?') will not match vocab entries if the
        training data kept it as a separate token — confirm against the
        training tokenizer.
        """
        words = context.lower().split()
        padded = ['<s>', '<s>'] + words
        w1, w2 = padded[-2], padded[-1]
        return self._rank_candidates(w1, w2, top_k, exclude_special)

    def sentence_probability(self, tokens: List[str]) -> float:
        """
        Compute the probability of a sentence.

        Args:
            tokens: Tokenized sentence WITH start/end markers.

        Returns:
            Log probability of the sentence (sum of trigram log
            probabilities); -inf for fewer than three tokens.
        """
        if len(tokens) < 3:
            return float('-inf')
        log_prob = 0.0
        for w1, w2, w3 in zip(tokens, tokens[1:], tokens[2:]):
            log_prob += self.log_probability(w3, w1, w2)
        return log_prob

    def perplexity(self, sentences: List[List[str]]) -> float:
        """
        Compute perplexity on a set of sentences.

        Perplexity = exp(-1/N * sum(log P(w_i | w_{i-2}, w_{i-1})))
        Lower perplexity = better model fit.

        Args:
            sentences: List of tokenized sentences with boundary markers.
                Sentences shorter than three tokens are skipped.

        Returns:
            Perplexity score; inf if no trigram was scored or the value
            overflows a float.
        """
        total_log_prob = 0.0
        total_words = 0
        for sentence in sentences:
            for w1, w2, w3 in zip(sentence, sentence[1:], sentence[2:]):
                total_log_prob += self.log_probability(w3, w1, w2)
                total_words += 1
        if total_words == 0:
            return float('inf')
        avg_log_prob = total_log_prob / total_words
        try:
            return math.exp(-avg_log_prob)
        except OverflowError:
            # Extremely small average probabilities exceed float range.
            return float('inf')

    def get_context_distribution(
        self,
        w1: str,
        w2: str,
        top_k: Optional[int] = None
    ) -> List[Tuple[str, float]]:
        """
        Get the probability distribution for a specific bigram context.

        Unlike ``predict_next_words``, the context words are used verbatim
        (no lower-casing), and '<s>' / '</s>' are always excluded.

        Args:
            w1: First context word.
            w2: Second context word.
            top_k: If provided, return only top-k predictions (<= 0 returns []).

        Returns:
            List of (word, probability) tuples, sorted by probability
            descending with deterministic tie-breaking.
        """
        return self._rank_candidates(w1, w2, top_k, exclude_special=True)

    def get_stats(self) -> Dict[str, int]:
        """
        Get model statistics.

        Returns:
            Dictionary of statistics.
        """
        return {
            'vocab_size': len(self.vocab),
            'unique_unigrams': len(self.unigram_counts),
            'unique_bigrams': len(self.bigram_counts),
            'unique_trigrams': len(self.trigram_counts),
            'total_tokens': self._total_unigrams,
        }
if __name__ == "__main__":
    # Smoke test: fit a tiny Pidgin corpus and show the top predictions for
    # a few short prompts.
    demo_corpus = [
        ['<s>', '<s>', 'i', 'dey', 'go', 'market', '</s>'],
        ['<s>', '<s>', 'i', 'dey', 'come', 'back', '</s>'],
        ['<s>', '<s>', 'you', 'dey', 'go', 'where', '?', '</s>'],
        ['<s>', '<s>', 'how', 'far', '?', '</s>'],
        ['<s>', '<s>', 'e', 'don', 'happen', '</s>'],
    ]
    demo_model = TrigramLM(smoothing=1.0)
    demo_model.train(demo_corpus)

    print("\nTest Predictions:")
    for prompt in ("i dey", "you dey", "how"):
        top_three = demo_model.predict_next_words(prompt, top_k=3)
        print(f" '{prompt}' -> {top_three}")