JermaineAI commited on
Commit
ad18db6
·
1 Parent(s): 3de8536

Fix API model loading: Copy src directory and update Dockerfile

Browse files
Dockerfile CHANGED
@@ -8,6 +8,7 @@ RUN pip install --no-cache-dir -r requirements.txt
8
 
9
  # Copy application
10
  COPY api.py .
 
11
  COPY model/ model/
12
 
13
  # Expose port
 
8
 
9
  # Copy application
10
  COPY api.py .
11
+ COPY src/ src/
12
  COPY model/ model/
13
 
14
  # Expose port
api.py CHANGED
@@ -81,39 +81,8 @@ class LSTMLanguageModel(nn.Module):
81
  # =============================================================================
82
  # Trigram Model
83
  # =============================================================================
84
- class TrigramLM:
85
- def __init__(self, smoothing: float = 1.0):
86
- self.smoothing = smoothing
87
- self.unigram_counts = {}
88
- self.bigram_counts = {}
89
- self.trigram_counts = {}
90
- self.vocab = set()
91
-
92
- def probability(self, w3: str, w1: str, w2: str) -> float:
93
- trigram_count = self.trigram_counts.get((w1, w2, w3), 0)
94
- bigram_count = self.bigram_counts.get((w1, w2), 0)
95
- vocab_size = len(self.vocab)
96
- numerator = trigram_count + self.smoothing
97
- denominator = bigram_count + (self.smoothing * vocab_size)
98
- return numerator / denominator if denominator > 0 else 0.0
99
-
100
- def predict_next_words(self, context: str, top_k: int = 5) -> List[Tuple[str, float]]:
101
- words = context.lower().split()
102
- if len(words) == 0:
103
- w1, w2 = START_TOKEN, START_TOKEN
104
- elif len(words) == 1:
105
- w1, w2 = START_TOKEN, words[0]
106
- else:
107
- w1, w2 = words[-2], words[-1]
108
-
109
- candidates = []
110
- for word in self.vocab:
111
- if word not in (START_TOKEN, END_TOKEN, '<s>', '</s>'):
112
- prob = self.probability(word, w1, w2)
113
- candidates.append((word, prob))
114
-
115
- candidates.sort(key=lambda x: x[1], reverse=True)
116
- return candidates[:top_k]
117
 
118
  # =============================================================================
119
  # Global Models (loaded once at startup)
@@ -123,16 +92,6 @@ word_to_idx = None
123
  idx_to_word = None
124
  trigram_model = None
125
 
126
- # =============================================================================
127
- # Custom Unpickler to fix the 'src' module error
128
- # =============================================================================
129
- class PatchingUnpickler(pickle.Unpickler):
130
- def find_class(self, module, name):
131
- # If the pickle creates a dependency on 'src', redirect it to __main__
132
- if module.startswith("src") and name == "TrigramLM":
133
- return TrigramLM
134
- return super().find_class(module, name)
135
-
136
  @app.on_event("startup")
137
  async def load_models():
138
  global lstm_model, word_to_idx, idx_to_word, trigram_model
@@ -151,11 +110,10 @@ async def load_models():
151
  except Exception as e:
152
  print(f"Failed to load LSTM model: {e}")
153
 
154
- # 2. Load Trigram (Using the Custom Unpickler)
155
  try:
156
  with open('model/trigram_model.pkl', 'rb') as f:
157
- # Use PatchingUnpickler instead of standard pickle.load
158
- trigram_model = PatchingUnpickler(f).load()
159
  print(f"Trigram model loaded! Vocab size: {len(trigram_model.vocab)}")
160
  except Exception as e:
161
  print(f"Failed to load Trigram model: {e}")
 
81
  # =============================================================================
82
  # Trigram Model
83
  # =============================================================================
84
+ # Import directly from src to ensure compatibility with pickle
85
+ from src.trigram_model import TrigramLM
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  # =============================================================================
88
  # Global Models (loaded once at startup)
 
92
  idx_to_word = None
93
  trigram_model = None
94
 
 
 
 
 
 
 
 
 
 
 
95
  @app.on_event("startup")
96
  async def load_models():
97
  global lstm_model, word_to_idx, idx_to_word, trigram_model
 
110
  except Exception as e:
111
  print(f"Failed to load LSTM model: {e}")
112
 
113
+ # 2. Load Trigram
114
  try:
115
  with open('model/trigram_model.pkl', 'rb') as f:
116
+ trigram_model = pickle.load(f)
 
117
  print(f"Trigram model loaded! Vocab size: {len(trigram_model.vocab)}")
118
  except Exception as e:
119
  print(f"Failed to load Trigram model: {e}")
src/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Nigerian English/Pidgin Next-Word Prediction
2
+ # Trigram Language Model Baseline
src/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (156 Bytes). View file
 
src/__pycache__/data_loader.cpython-311.pyc ADDED
Binary file (7.43 kB). View file
 
src/__pycache__/preprocessing.cpython-311.pyc ADDED
Binary file (5.5 kB). View file
 
src/__pycache__/trigram_model.cpython-311.pyc ADDED
Binary file (12.4 kB). View file
 
src/__pycache__/utils.cpython-311.pyc ADDED
Binary file (3.27 kB). View file
 
src/data_loader.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data loading utilities for NaijaSenti and BBC Pidgin datasets.
3
+
4
+ Loads Nigerian Pidgin text from multiple sources for language modeling.
5
+ Sentiment/category labels are ignored.
6
+ """
7
+
8
+ from datasets import load_dataset
9
+ from typing import List, Dict, Any, Optional
10
+ import csv
11
+ import os
12
+
13
+ # Path to BBC Pidgin corpus (relative to project root)
14
+ BBC_PIDGIN_CORPUS_PATH = "bbc_pidgin_scraper/data/pidgin_corpus.csv"
15
+
16
+
17
def load_naijasenti_pcm() -> Dict[str, List[str]]:
    """
    Fetch the NaijaSenti PCM (Nigerian Pidgin) dataset from the Hub.

    Only the raw text is kept; sentiment labels are discarded because they
    are irrelevant for language modeling.

    Returns:
        Mapping of split name ('train', 'test', 'validation') to its texts.
    """
    corpus = load_dataset("mteb/NaijaSenti", "pcm")

    # One list of raw texts per split, dropping every other column.
    return {
        split: [row['text'] for row in corpus[split]]
        for split in corpus.keys()
    }
32
+
33
+
34
def load_bbc_pidgin(limit: Optional[int] = None, project_root: Optional[str] = None) -> List[str]:
    """
    Load BBC Pidgin articles from the locally scraped CSV corpus.

    Each article contributes one string: "headline. body" when both fields
    are present, otherwise just the body. Rows with no body are skipped.

    Args:
        limit: Maximum number of rows to read. None (or 0) reads everything.
        project_root: Project root containing the corpus. Defaults to the
            current working directory.

    Returns:
        List of article texts; empty list when the corpus is missing or
        unreadable (a warning/error is printed instead of raising).
    """
    root = project_root if project_root is not None else os.getcwd()
    path = os.path.join(root, BBC_PIDGIN_CORPUS_PATH)

    if not os.path.exists(path):
        print(f"Warning: BBC Pidgin corpus not found at {path}")
        return []

    articles: List[str] = []
    try:
        with open(path, 'r', encoding='utf-8') as handle:
            for idx, record in enumerate(csv.DictReader(handle)):
                if limit and idx >= limit:
                    break
                title = record.get('headline', '').strip()
                body = record.get('text', '').strip()
                # A row without body text contributes nothing.
                if not body:
                    continue
                articles.append(f"{title}. {body}" if title else body)
    except Exception as e:
        # Best-effort loader: report and fall back to an empty corpus.
        print(f"Error loading BBC Pidgin corpus: {e}")
        return []

    print(f"Loaded {len(articles):,} BBC Pidgin articles")
    return articles
78
+
79
+
80
def load_all_texts(include_bbc: bool = True, bbc_limit: Optional[int] = None) -> List[str]:
    """
    Load and concatenate texts from every available source.

    Pulls the NaijaSenti PCM splits first, then (optionally) appends the
    BBC Pidgin articles, printing progress along the way.

    Args:
        include_bbc: Whether to append BBC Pidgin articles.
        bbc_limit: Cap on the number of BBC articles (None for all).

    Returns:
        Combined list of text strings from all sources.
    """
    corpus: List[str] = []

    # NaijaSenti: flatten every split into one list.
    print("Loading NaijaSenti PCM dataset...")
    for split_name, texts in load_naijasenti_pcm().items():
        corpus.extend(texts)
        print(f" Loaded {len(texts):,} texts from {split_name} split")
    print(f" NaijaSenti total: {len(corpus):,} texts")

    # BBC Pidgin: optional second source.
    if include_bbc:
        print(f"\nLoading BBC Pidgin corpus (limit={bbc_limit})...")
        corpus.extend(load_bbc_pidgin(limit=bbc_limit))

    print(f"\nCombined total: {len(corpus):,} texts")
    return corpus
114
+
115
+
116
def get_dataset_stats(texts: List[str]) -> Dict[str, Any]:
    """
    Summarize a text corpus with simple size statistics.

    Args:
        texts: List of text strings.

    Returns:
        Dict with counts of texts, characters, and whitespace-delimited
        words, plus per-text averages (0 for an empty corpus).
    """
    n_chars = 0
    n_words = 0
    for text in texts:
        n_chars += len(text)
        n_words += len(text.split())

    count = len(texts)
    return {
        'num_texts': count,
        'total_characters': n_chars,
        'total_words': n_words,
        'avg_words_per_text': n_words / count if texts else 0,
        'avg_chars_per_text': n_chars / count if texts else 0,
    }
136
+
137
+
138
if __name__ == "__main__":
    # Smoke test: pull every source and print corpus-size statistics.
    corpus = load_all_texts(include_bbc=True)  # Loads all BBC articles by default
    summary = get_dataset_stats(corpus)
    print("\nDataset Statistics:")
    for key, value in summary.items():
        # Floats get two decimals; integers get thousands separators.
        spec = ".2f" if isinstance(value, float) else ","
        print(f" {key}: {value:{spec}}")
148
+
src/preprocessing.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Text preprocessing pipeline for Nigerian English/Pidgin.
3
+
4
+ Design principles:
5
+ - Preserve linguistic features of Nigerian Pidgin (slang, contractions, code-switching)
6
+ - Remove noise (URLs, usernames) that don't contribute to language modeling
7
+ - Minimal normalization to avoid losing dialectal patterns
8
+ """
9
+
10
+ import re
11
+ from typing import List
12
+
13
+
14
+ # Special tokens for sentence boundaries
15
+ START_TOKEN = "<s>"
16
+ END_TOKEN = "</s>"
17
+
18
+
19
def clean_text(text: str) -> str:
    """
    Clean raw text while keeping Nigerian Pidgin features intact.

    Lowercases, strips URLs and @usernames, unwraps hashtags to their bare
    word, and collapses whitespace. Contractions, slang, and code-switching
    patterns are deliberately left untouched so dialectal signal survives.

    Args:
        text: Raw text string.

    Returns:
        Cleaned text string.
    """
    # (pattern, replacement) pairs applied in order after lowercasing.
    substitutions = (
        (r'https?://\S+', ''),   # URLs
        (r'www\.\S+', ''),       # bare www links
        (r'@\w+', ''),           # Twitter-style usernames
        (r'#(\w+)', r'\1'),      # hashtags -> keep the word
        (r'\s+', ' '),           # collapse whitespace
    )

    cleaned = text.lower()
    for pattern, repl in substitutions:
        cleaned = re.sub(pattern, repl, cleaned)
    return cleaned.strip()
59
+
60
+
61
def tokenize(text: str) -> List[str]:
    """
    Word-level tokenization for Nigerian Pidgin.

    Splits on whitespace, then peels punctuation off each end of a word and
    emits it as separate tokens. Interior apostrophes (don't, I'm) stay
    inside the word so contractions remain single tokens.

    NOTE(review): leading/trailing apostrophes are removed from the word but
    never emitted as tokens, so "'em" becomes "em" — presumably intentional
    (treats them as quote marks), but worth confirming against the vocab of
    the trained model.

    Args:
        text: Cleaned text string.

    Returns:
        List of tokens.
    """
    leading_punct = '.,!?;:"\'-([{'
    trailing_punct = '.,!?;:"\'-)]}"'

    out: List[str] = []
    for raw in text.split():
        word = raw

        # Peel punctuation off the front, emitting each mark in order.
        while word and word[0] in leading_punct:
            head, word = word[0], word[1:]
            if head != "'":  # apostrophes are dropped, not emitted
                out.append(head)

        # Peel punctuation off the back, remembering the original order.
        tail_tokens: List[str] = []
        while word and word[-1] in trailing_punct:
            last, word = word[-1], word[:-1]
            if last != "'":  # apostrophes are dropped, not emitted
                tail_tokens.append(last)

        if word:
            out.append(word)
        out.extend(reversed(tail_tokens))

    return out
103
+
104
+
105
def preprocess_text(text: str) -> List[str]:
    """
    Full preprocessing pipeline: clean the raw text, then tokenize it.

    Args:
        text: Raw text string.

    Returns:
        List of tokens.
    """
    return tokenize(clean_text(text))
118
+
119
+
120
def add_sentence_markers(tokens: List[str]) -> List[str]:
    """
    Wrap a token list with sentence-boundary markers.

    Two start tokens are prepended so a trigram model has full context for
    the first real word; one end token closes the sentence.

    Args:
        tokens: Tokens of one sentence.

    Returns:
        Marked token list, or [] when the input is empty.
    """
    if not tokens:
        return []
    return [START_TOKEN] * 2 + tokens + [END_TOKEN]
136
+
137
+
138
def preprocess_corpus(texts: List[str]) -> List[List[str]]:
    """
    Preprocess an entire corpus for language-model training.

    Each text is cleaned and tokenized; empty results are dropped and the
    rest get sentence-boundary markers.

    Args:
        texts: Raw text strings.

    Returns:
        Tokenized sentences with boundary markers.
    """
    return [
        add_sentence_markers(toks)
        for text in texts
        if (toks := preprocess_text(text))  # skip texts that tokenize to nothing
    ]
155
+
156
+
157
if __name__ == "__main__":
    # Eyeball the pipeline on representative Nigerian Pidgin inputs.
    samples = [
        "I dey go market, you wan follow?",
        "That guy na correct person sha @handle https://example.com",
        "Wetin you dey do? Abi you no sabi?",
        "E don happen before, no be today matter",
        "How far? Everything dey go well?",
    ]

    print("Preprocessing Examples:\n")
    for sample in samples:
        tokens = preprocess_text(sample)
        marked = add_sentence_markers(tokens)
        print(f"Original: {sample}")
        print(f"Tokens: {tokens}")
        print(f"Marked: {marked}")
        print()
src/trigram_model.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Trigram Language Model for Next-Word Prediction.
3
+
4
+ Implements a statistical trigram model with Laplace (add-one) smoothing
5
+ for Nigerian English/Pidgin next-word prediction.
6
+
7
+ Mathematical Foundation:
8
+ P(w_n | w_{n-2}, w_{n-1}) = (C(w_{n-2}, w_{n-1}, w_n) + α) / (C(w_{n-2}, w_{n-1}) + α|V|)
9
+
10
+ Where:
11
+ - C(.) = count of n-gram in training corpus
12
+ - α = smoothing parameter (1.0 for Laplace)
13
+ - |V| = vocabulary size
14
+ """
15
+
16
+ from collections import Counter
17
+ from typing import List, Tuple, Dict, Optional
18
+ import math
19
+
20
+
21
class TrigramLM:
    """
    Statistical trigram language model with Laplace (add-alpha) smoothing.

    Estimates conditional next-word probabilities from n-gram counts:

        P(w3 | w1, w2) = (C(w1, w2, w3) + alpha) / (C(w1, w2) + alpha * |V|)

    Attributes (names are part of the pickle format — do not rename):
        smoothing: Smoothing constant alpha (1.0 == add-one smoothing).
        unigram_counts: Counter of single-token frequencies.
        bigram_counts: Counter of adjacent token-pair frequencies.
        trigram_counts: Counter of token-triple frequencies.
        vocab: Set of every distinct token seen during training.
    """

    def __init__(self, smoothing: float = 1.0):
        """
        Create an untrained model.

        Args:
            smoothing: Laplace smoothing constant. Larger values push more
                probability mass toward unseen n-grams. Default 1.0.
        """
        self.smoothing = smoothing
        self.unigram_counts: Counter = Counter()
        self.bigram_counts: Counter = Counter()
        self.trigram_counts: Counter = Counter()
        self.vocab: set = set()
        self._total_unigrams: int = 0  # running token total over all sentences

    def train(self, sentences: List[List[str]]) -> None:
        """
        Accumulate n-gram counts from tokenized sentences.

        Sentences must already carry boundary markers:
        ['<s>', '<s>', 'word1', ..., '</s>'].

        Args:
            sentences: Tokenized sentences with boundary markers.
        """
        for sent in sentences:
            self.vocab.update(sent)
            self.unigram_counts.update(sent)
            self._total_unigrams += len(sent)
            # Sliding windows of width 2 and 3 built from offset zips.
            self.bigram_counts.update(zip(sent, sent[1:]))
            self.trigram_counts.update(zip(sent, sent[1:], sent[2:]))

        print(f"Training complete:")
        print(f" Vocabulary size: {len(self.vocab):,}")
        print(f" Unique unigrams: {len(self.unigram_counts):,}")
        print(f" Unique bigrams: {len(self.bigram_counts):,}")
        print(f" Unique trigrams: {len(self.trigram_counts):,}")

    def probability(self, w3: str, w1: str, w2: str) -> float:
        """
        Return P(w3 | w1, w2) under Laplace smoothing.

        Args:
            w3: The word to predict.
            w1: Context word two positions before w3.
            w2: Context word one position before w3.

        Returns:
            Smoothed conditional probability (0.0 for an empty vocabulary).
        """
        tri = self.trigram_counts.get((w1, w2, w3), 0)
        bi = self.bigram_counts.get((w1, w2), 0)
        numerator = tri + self.smoothing
        denominator = bi + self.smoothing * len(self.vocab)
        return numerator / denominator if denominator > 0 else 0.0

    def log_probability(self, w3: str, w1: str, w2: str) -> float:
        """
        Return log P(w3 | w1, w2), or -inf when the probability is zero.

        Args:
            w3: The word to predict.
            w1: First context word.
            w2: Second context word.

        Returns:
            Natural log probability.
        """
        p = self.probability(w3, w1, w2)
        return math.log(p) if p > 0 else float('-inf')

    def predict_next_words(
        self,
        context: str,
        top_k: int = 5,
        exclude_special: bool = True
    ) -> List[Tuple[str, float]]:
        """
        Rank the top-k most likely next words for a text context.

        The last two whitespace tokens of the (lowercased) context form the
        bigram condition; missing context positions are padded with '<s>'.

        Args:
            context: Input text.
            top_k: Number of predictions to return.
            exclude_special: If True, never predict '<s>' or '</s>'.

        Returns:
            (word, probability) pairs sorted by probability descending.
        """
        words = context.lower().split()
        if not words:
            w1 = w2 = '<s>'
        elif len(words) == 1:
            w1, w2 = '<s>', words[0]
        else:
            w1, w2 = words[-2], words[-1]

        banned = {'<s>', '</s>'} if exclude_special else set()
        ranked = sorted(
            ((word, self.probability(word, w1, w2))
             for word in self.vocab if word not in banned),
            key=lambda pair: pair[1],
            reverse=True,
        )
        return ranked[:top_k]

    def sentence_probability(self, tokens: List[str]) -> float:
        """
        Return the log probability of a whole sentence.

        Args:
            tokens: Tokenized sentence WITH start/end markers.

        Returns:
            Sum of trigram log probabilities; -inf for fewer than 3 tokens.
        """
        if len(tokens) < 3:
            return float('-inf')
        return sum(
            self.log_probability(w3, w1, w2)
            for w1, w2, w3 in zip(tokens, tokens[1:], tokens[2:])
        )

    def perplexity(self, sentences: List[List[str]]) -> float:
        """
        Compute perplexity over a set of sentences (lower is better).

        Perplexity = exp(-mean log P(w_i | w_{i-2}, w_{i-1})), taken over
        every predictable position in sentences of length >= 3.

        Args:
            sentences: Tokenized sentences with boundary markers.

        Returns:
            Perplexity score; inf when no position is predictable.
        """
        total_log = 0.0
        n_predicted = 0
        for sent in sentences:
            if len(sent) < 3:
                continue
            for w1, w2, w3 in zip(sent, sent[1:], sent[2:]):
                total_log += self.log_probability(w3, w1, w2)
                n_predicted += 1

        if n_predicted == 0:
            return float('inf')
        return math.exp(-(total_log / n_predicted))

    def get_context_distribution(
        self,
        w1: str,
        w2: str,
        top_k: Optional[int] = None
    ) -> List[Tuple[str, float]]:
        """
        Return the next-word distribution for an explicit bigram context.

        Args:
            w1: First context word.
            w2: Second context word.
            top_k: If truthy, return only the top-k entries.

        Returns:
            (word, probability) pairs sorted by probability descending,
            excluding the boundary markers.
        """
        ranked = sorted(
            ((word, self.probability(word, w1, w2))
             for word in self.vocab if word not in ('<s>', '</s>')),
            key=lambda pair: pair[1],
            reverse=True,
        )
        return ranked[:top_k] if top_k else ranked

    def get_stats(self) -> Dict[str, int]:
        """
        Return count statistics describing the trained model.

        Returns:
            Dictionary of vocabulary/n-gram sizes and total token count.
        """
        return {
            'vocab_size': len(self.vocab),
            'unique_unigrams': len(self.unigram_counts),
            'unique_bigrams': len(self.bigram_counts),
            'unique_trigrams': len(self.trigram_counts),
            'total_tokens': self._total_unigrams,
        }
257
+
258
+
259
if __name__ == "__main__":
    # Smoke test on a tiny hand-built Pidgin corpus.
    demo_corpus = [
        ['<s>', '<s>', 'i', 'dey', 'go', 'market', '</s>'],
        ['<s>', '<s>', 'i', 'dey', 'come', 'back', '</s>'],
        ['<s>', '<s>', 'you', 'dey', 'go', 'where', '?', '</s>'],
        ['<s>', '<s>', 'how', 'far', '?', '</s>'],
        ['<s>', '<s>', 'e', 'don', 'happen', '</s>'],
    ]

    lm = TrigramLM(smoothing=1.0)
    lm.train(demo_corpus)

    print("\nTest Predictions:")
    for ctx in ["i dey", "you dey", "how"]:
        preds = lm.predict_next_words(ctx, top_k=3)
        print(f" '{ctx}' -> {preds}")
src/utils.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions for the next-word prediction system.
3
+ """
4
+
5
+ from typing import List, Tuple
6
+ import math
7
+
8
+
9
def format_predictions(predictions: List[Tuple[str, float]], show_percent: bool = True) -> str:
    """
    Render prediction results as one indented line per word.

    Args:
        predictions: (word, probability) tuples.
        show_percent: If True, format probabilities as percentages;
            otherwise show six decimal places.

    Returns:
        Newline-joined display string (empty for no predictions).
    """
    if show_percent:
        rows = [f" {word}: {prob*100:.2f}%" for word, prob in predictions]
    else:
        rows = [f" {word}: {prob:.6f}" for word, prob in predictions]
    return "\n".join(rows)
27
+
28
+
29
def calculate_entropy(probabilities: List[float]) -> float:
    """
    Shannon entropy of a probability distribution, in bits.

    H(X) = -sum(p * log2(p)); zero (or negative) entries contribute nothing.

    Args:
        probabilities: List of probabilities.

    Returns:
        Entropy in bits (0.0 for an empty list).
    """
    bits = 0.0
    for p in probabilities:
        if p <= 0:
            continue  # log2 undefined; treat as zero contribution
        bits -= p * math.log2(p)
    return bits
46
+
47
+
48
def top_k_accuracy(
    model,
    test_sentences: List[List[str]],
    k: int = 5
) -> float:
    """
    Fraction of positions whose true next word is in the model's top-k.

    Every predictable position (index >= 2) of every sentence with at least
    three tokens counts as one attempt; an attempt is a hit when the true
    word appears among the model's top-k candidates for its bigram context.

    Args:
        model: Trained model exposing get_context_distribution(w1, w2, top_k).
        test_sentences: Tokenized sentences with boundary markers.
        k: Number of top predictions to consider.

    Returns:
        Accuracy in [0, 1]; 0.0 when there are no attempts.
    """
    hits = 0
    attempts = 0

    for sentence in test_sentences:
        if len(sentence) < 3:
            continue  # too short to contain a trigram

        # Walk (w1, w2, target) windows across the sentence.
        for w1, w2, target in zip(sentence, sentence[1:], sentence[2:]):
            ranked = model.get_context_distribution(w1, w2, top_k=k)
            if target in (word for word, _ in ranked):
                hits += 1
            attempts += 1

    return hits / attempts if attempts else 0.0