Spaces:
Running
Running
Commit ·
ad18db6
1
Parent(s): 3de8536
Fix API model loading: Copy src directory and update Dockerfile
Browse files- Dockerfile +1 -0
- api.py +4 -46
- src/__init__.py +2 -0
- src/__pycache__/__init__.cpython-311.pyc +0 -0
- src/__pycache__/data_loader.cpython-311.pyc +0 -0
- src/__pycache__/preprocessing.cpython-311.pyc +0 -0
- src/__pycache__/trigram_model.cpython-311.pyc +0 -0
- src/__pycache__/utils.cpython-311.pyc +0 -0
- src/data_loader.py +148 -0
- src/preprocessing.py +174 -0
- src/trigram_model.py +276 -0
- src/utils.py +85 -0
Dockerfile
CHANGED
|
@@ -8,6 +8,7 @@ RUN pip install --no-cache-dir -r requirements.txt
|
|
| 8 |
|
| 9 |
# Copy application
|
| 10 |
COPY api.py .
|
|
|
|
| 11 |
COPY model/ model/
|
| 12 |
|
| 13 |
# Expose port
|
|
|
|
| 8 |
|
| 9 |
# Copy application
|
| 10 |
COPY api.py .
|
| 11 |
+
COPY src/ src/
|
| 12 |
COPY model/ model/
|
| 13 |
|
| 14 |
# Expose port
|
api.py
CHANGED
|
@@ -81,39 +81,8 @@ class LSTMLanguageModel(nn.Module):
|
|
| 81 |
# =============================================================================
|
| 82 |
# Trigram Model
|
| 83 |
# =============================================================================
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
self.smoothing = smoothing
|
| 87 |
-
self.unigram_counts = {}
|
| 88 |
-
self.bigram_counts = {}
|
| 89 |
-
self.trigram_counts = {}
|
| 90 |
-
self.vocab = set()
|
| 91 |
-
|
| 92 |
-
def probability(self, w3: str, w1: str, w2: str) -> float:
|
| 93 |
-
trigram_count = self.trigram_counts.get((w1, w2, w3), 0)
|
| 94 |
-
bigram_count = self.bigram_counts.get((w1, w2), 0)
|
| 95 |
-
vocab_size = len(self.vocab)
|
| 96 |
-
numerator = trigram_count + self.smoothing
|
| 97 |
-
denominator = bigram_count + (self.smoothing * vocab_size)
|
| 98 |
-
return numerator / denominator if denominator > 0 else 0.0
|
| 99 |
-
|
| 100 |
-
def predict_next_words(self, context: str, top_k: int = 5) -> List[Tuple[str, float]]:
|
| 101 |
-
words = context.lower().split()
|
| 102 |
-
if len(words) == 0:
|
| 103 |
-
w1, w2 = START_TOKEN, START_TOKEN
|
| 104 |
-
elif len(words) == 1:
|
| 105 |
-
w1, w2 = START_TOKEN, words[0]
|
| 106 |
-
else:
|
| 107 |
-
w1, w2 = words[-2], words[-1]
|
| 108 |
-
|
| 109 |
-
candidates = []
|
| 110 |
-
for word in self.vocab:
|
| 111 |
-
if word not in (START_TOKEN, END_TOKEN, '<s>', '</s>'):
|
| 112 |
-
prob = self.probability(word, w1, w2)
|
| 113 |
-
candidates.append((word, prob))
|
| 114 |
-
|
| 115 |
-
candidates.sort(key=lambda x: x[1], reverse=True)
|
| 116 |
-
return candidates[:top_k]
|
| 117 |
|
| 118 |
# =============================================================================
|
| 119 |
# Global Models (loaded once at startup)
|
|
@@ -123,16 +92,6 @@ word_to_idx = None
|
|
| 123 |
idx_to_word = None
|
| 124 |
trigram_model = None
|
| 125 |
|
| 126 |
-
# =============================================================================
|
| 127 |
-
# Custom Unpickler to fix the 'src' module error
|
| 128 |
-
# =============================================================================
|
| 129 |
-
class PatchingUnpickler(pickle.Unpickler):
|
| 130 |
-
def find_class(self, module, name):
|
| 131 |
-
# If the pickle creates a dependency on 'src', redirect it to __main__
|
| 132 |
-
if module.startswith("src") and name == "TrigramLM":
|
| 133 |
-
return TrigramLM
|
| 134 |
-
return super().find_class(module, name)
|
| 135 |
-
|
| 136 |
@app.on_event("startup")
|
| 137 |
async def load_models():
|
| 138 |
global lstm_model, word_to_idx, idx_to_word, trigram_model
|
|
@@ -151,11 +110,10 @@ async def load_models():
|
|
| 151 |
except Exception as e:
|
| 152 |
print(f"Failed to load LSTM model: {e}")
|
| 153 |
|
| 154 |
-
# 2. Load Trigram
|
| 155 |
try:
|
| 156 |
with open('model/trigram_model.pkl', 'rb') as f:
|
| 157 |
-
|
| 158 |
-
trigram_model = PatchingUnpickler(f).load()
|
| 159 |
print(f"Trigram model loaded! Vocab size: {len(trigram_model.vocab)}")
|
| 160 |
except Exception as e:
|
| 161 |
print(f"Failed to load Trigram model: {e}")
|
|
|
|
| 81 |
# =============================================================================
|
| 82 |
# Trigram Model
|
| 83 |
# =============================================================================
|
| 84 |
+
# Import directly from src to ensure compatibility with pickle
|
| 85 |
+
from src.trigram_model import TrigramLM
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
# =============================================================================
|
| 88 |
# Global Models (loaded once at startup)
|
|
|
|
| 92 |
idx_to_word = None
|
| 93 |
trigram_model = None
|
| 94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
@app.on_event("startup")
|
| 96 |
async def load_models():
|
| 97 |
global lstm_model, word_to_idx, idx_to_word, trigram_model
|
|
|
|
| 110 |
except Exception as e:
|
| 111 |
print(f"Failed to load LSTM model: {e}")
|
| 112 |
|
| 113 |
+
# 2. Load Trigram
|
| 114 |
try:
|
| 115 |
with open('model/trigram_model.pkl', 'rb') as f:
|
| 116 |
+
trigram_model = pickle.load(f)
|
|
|
|
| 117 |
print(f"Trigram model loaded! Vocab size: {len(trigram_model.vocab)}")
|
| 118 |
except Exception as e:
|
| 119 |
print(f"Failed to load Trigram model: {e}")
|
src/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Nigerian English/Pidgin Next-Word Prediction
|
| 2 |
+
# Trigram Language Model Baseline
|
src/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (156 Bytes). View file
|
|
|
src/__pycache__/data_loader.cpython-311.pyc
ADDED
|
Binary file (7.43 kB). View file
|
|
|
src/__pycache__/preprocessing.cpython-311.pyc
ADDED
|
Binary file (5.5 kB). View file
|
|
|
src/__pycache__/trigram_model.cpython-311.pyc
ADDED
|
Binary file (12.4 kB). View file
|
|
|
src/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (3.27 kB). View file
|
|
|
src/data_loader.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data loading utilities for NaijaSenti and BBC Pidgin datasets.
|
| 3 |
+
|
| 4 |
+
Loads Nigerian Pidgin text from multiple sources for language modeling.
|
| 5 |
+
Sentiment/category labels are ignored.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from datasets import load_dataset
|
| 9 |
+
from typing import List, Dict, Any, Optional
|
| 10 |
+
import csv
|
| 11 |
+
import os
|
| 12 |
+
|
| 13 |
+
# Path to BBC Pidgin corpus (relative to project root)
|
| 14 |
+
BBC_PIDGIN_CORPUS_PATH = "bbc_pidgin_scraper/data/pidgin_corpus.csv"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def load_naijasenti_pcm() -> Dict[str, List[str]]:
|
| 18 |
+
"""
|
| 19 |
+
Load the NaijaSenti PCM (Nigerian Pidgin) dataset.
|
| 20 |
+
|
| 21 |
+
Returns:
|
| 22 |
+
Dict with keys 'train', 'test', 'validation' containing text lists.
|
| 23 |
+
"""
|
| 24 |
+
dataset = load_dataset("mteb/NaijaSenti", "pcm")
|
| 25 |
+
|
| 26 |
+
result = {}
|
| 27 |
+
for split in dataset.keys():
|
| 28 |
+
# Extract text field, ignore sentiment labels
|
| 29 |
+
result[split] = [example['text'] for example in dataset[split]]
|
| 30 |
+
|
| 31 |
+
return result
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def load_bbc_pidgin(limit: Optional[int] = None, project_root: Optional[str] = None) -> List[str]:
|
| 35 |
+
"""
|
| 36 |
+
Load BBC Pidgin articles from the scraped corpus.
|
| 37 |
+
|
| 38 |
+
The corpus contains headlines and article texts scraped from BBC Pidgin.
|
| 39 |
+
We concatenate headline + text for each article.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
limit: Maximum number of articles to load. None for all.
|
| 43 |
+
project_root: Path to project root. Defaults to current working directory.
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
List of article texts (headline + body combined).
|
| 47 |
+
"""
|
| 48 |
+
if project_root is None:
|
| 49 |
+
project_root = os.getcwd()
|
| 50 |
+
|
| 51 |
+
corpus_path = os.path.join(project_root, BBC_PIDGIN_CORPUS_PATH)
|
| 52 |
+
|
| 53 |
+
if not os.path.exists(corpus_path):
|
| 54 |
+
print(f"Warning: BBC Pidgin corpus not found at {corpus_path}")
|
| 55 |
+
return []
|
| 56 |
+
|
| 57 |
+
texts = []
|
| 58 |
+
try:
|
| 59 |
+
with open(corpus_path, 'r', encoding='utf-8') as f:
|
| 60 |
+
reader = csv.DictReader(f)
|
| 61 |
+
for i, row in enumerate(reader):
|
| 62 |
+
if limit and i >= limit:
|
| 63 |
+
break
|
| 64 |
+
# Combine headline and text
|
| 65 |
+
headline = row.get('headline', '').strip()
|
| 66 |
+
text = row.get('text', '').strip()
|
| 67 |
+
if headline and text:
|
| 68 |
+
combined = f"{headline}. {text}"
|
| 69 |
+
texts.append(combined)
|
| 70 |
+
elif text:
|
| 71 |
+
texts.append(text)
|
| 72 |
+
except Exception as e:
|
| 73 |
+
print(f"Error loading BBC Pidgin corpus: {e}")
|
| 74 |
+
return []
|
| 75 |
+
|
| 76 |
+
print(f"Loaded {len(texts):,} BBC Pidgin articles")
|
| 77 |
+
return texts
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def load_all_texts(include_bbc: bool = True, bbc_limit: Optional[int] = None) -> List[str]:
|
| 81 |
+
"""
|
| 82 |
+
Load all text from all sources combined.
|
| 83 |
+
|
| 84 |
+
Combines NaijaSenti PCM dataset with BBC Pidgin articles
|
| 85 |
+
for maximum data coverage.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
include_bbc: Whether to include BBC Pidgin articles.
|
| 89 |
+
bbc_limit: Maximum number of BBC articles to include.
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
List of all text strings from all sources.
|
| 93 |
+
"""
|
| 94 |
+
all_texts = []
|
| 95 |
+
|
| 96 |
+
# Load NaijaSenti
|
| 97 |
+
print("Loading NaijaSenti PCM dataset...")
|
| 98 |
+
splits = load_naijasenti_pcm()
|
| 99 |
+
for split_name, texts in splits.items():
|
| 100 |
+
all_texts.extend(texts)
|
| 101 |
+
print(f" Loaded {len(texts):,} texts from {split_name} split")
|
| 102 |
+
|
| 103 |
+
naija_total = len(all_texts)
|
| 104 |
+
print(f" NaijaSenti total: {naija_total:,} texts")
|
| 105 |
+
|
| 106 |
+
# Load BBC Pidgin
|
| 107 |
+
if include_bbc:
|
| 108 |
+
print(f"\nLoading BBC Pidgin corpus (limit={bbc_limit})...")
|
| 109 |
+
bbc_texts = load_bbc_pidgin(limit=bbc_limit)
|
| 110 |
+
all_texts.extend(bbc_texts)
|
| 111 |
+
|
| 112 |
+
print(f"\nCombined total: {len(all_texts):,} texts")
|
| 113 |
+
return all_texts
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def get_dataset_stats(texts: List[str]) -> Dict[str, Any]:
|
| 117 |
+
"""
|
| 118 |
+
Compute basic statistics about the dataset.
|
| 119 |
+
|
| 120 |
+
Args:
|
| 121 |
+
texts: List of text strings.
|
| 122 |
+
|
| 123 |
+
Returns:
|
| 124 |
+
Dictionary of statistics.
|
| 125 |
+
"""
|
| 126 |
+
total_chars = sum(len(t) for t in texts)
|
| 127 |
+
total_words = sum(len(t.split()) for t in texts)
|
| 128 |
+
|
| 129 |
+
return {
|
| 130 |
+
'num_texts': len(texts),
|
| 131 |
+
'total_characters': total_chars,
|
| 132 |
+
'total_words': total_words,
|
| 133 |
+
'avg_words_per_text': total_words / len(texts) if texts else 0,
|
| 134 |
+
'avg_chars_per_text': total_chars / len(texts) if texts else 0,
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
if __name__ == "__main__":
|
| 139 |
+
# Quick test
|
| 140 |
+
texts = load_all_texts(include_bbc=True) # Loads all BBC articles by default
|
| 141 |
+
stats = get_dataset_stats(texts)
|
| 142 |
+
print("\nDataset Statistics:")
|
| 143 |
+
for key, value in stats.items():
|
| 144 |
+
if isinstance(value, float):
|
| 145 |
+
print(f" {key}: {value:.2f}")
|
| 146 |
+
else:
|
| 147 |
+
print(f" {key}: {value:,}")
|
| 148 |
+
|
src/preprocessing.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Text preprocessing pipeline for Nigerian English/Pidgin.
|
| 3 |
+
|
| 4 |
+
Design principles:
|
| 5 |
+
- Preserve linguistic features of Nigerian Pidgin (slang, contractions, code-switching)
|
| 6 |
+
- Remove noise (URLs, usernames) that don't contribute to language modeling
|
| 7 |
+
- Minimal normalization to avoid losing dialectal patterns
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import re
|
| 11 |
+
from typing import List
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Special tokens for sentence boundaries
|
| 15 |
+
START_TOKEN = "<s>"
|
| 16 |
+
END_TOKEN = "</s>"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def clean_text(text: str) -> str:
|
| 20 |
+
"""
|
| 21 |
+
Clean text while preserving Nigerian Pidgin features.
|
| 22 |
+
|
| 23 |
+
Operations:
|
| 24 |
+
1. Lowercase (case doesn't matter for prediction)
|
| 25 |
+
2. Remove URLs
|
| 26 |
+
3. Remove @usernames (Twitter-style)
|
| 27 |
+
4. Normalize whitespace
|
| 28 |
+
|
| 29 |
+
Preserved:
|
| 30 |
+
- Contractions (don't, I'm, na'm)
|
| 31 |
+
- Slang (abi, sha, sef)
|
| 32 |
+
- Code-switching patterns
|
| 33 |
+
- Pidgin grammar structures
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
text: Raw text string.
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
Cleaned text string.
|
| 40 |
+
"""
|
| 41 |
+
# Lowercase
|
| 42 |
+
text = text.lower()
|
| 43 |
+
|
| 44 |
+
# Remove URLs
|
| 45 |
+
text = re.sub(r'https?://\S+', '', text)
|
| 46 |
+
text = re.sub(r'www\.\S+', '', text)
|
| 47 |
+
|
| 48 |
+
# Remove @usernames
|
| 49 |
+
text = re.sub(r'@\w+', '', text)
|
| 50 |
+
|
| 51 |
+
# Remove hashtags but keep the word
|
| 52 |
+
text = re.sub(r'#(\w+)', r'\1', text)
|
| 53 |
+
|
| 54 |
+
# Normalize whitespace
|
| 55 |
+
text = re.sub(r'\s+', ' ', text)
|
| 56 |
+
text = text.strip()
|
| 57 |
+
|
| 58 |
+
return text
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def tokenize(text: str) -> List[str]:
|
| 62 |
+
"""
|
| 63 |
+
Word-level tokenization for Nigerian Pidgin.
|
| 64 |
+
|
| 65 |
+
Handles:
|
| 66 |
+
- Standard word boundaries
|
| 67 |
+
- Punctuation as separate tokens
|
| 68 |
+
- Preserves contractions as single tokens
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
text: Cleaned text string.
|
| 72 |
+
|
| 73 |
+
Returns:
|
| 74 |
+
List of tokens.
|
| 75 |
+
"""
|
| 76 |
+
# Split on whitespace first
|
| 77 |
+
words = text.split()
|
| 78 |
+
|
| 79 |
+
tokens = []
|
| 80 |
+
for word in words:
|
| 81 |
+
# Handle punctuation attached to words
|
| 82 |
+
# Keep contractions together (don't, I'm, etc.)
|
| 83 |
+
|
| 84 |
+
# Strip leading punctuation
|
| 85 |
+
while word and word[0] in '.,!?;:"\'-([{':
|
| 86 |
+
if word[0] not in "'": # Keep leading apostrophe for contractions
|
| 87 |
+
tokens.append(word[0])
|
| 88 |
+
word = word[1:]
|
| 89 |
+
|
| 90 |
+
# Strip trailing punctuation
|
| 91 |
+
trailing = []
|
| 92 |
+
while word and word[-1] in '.,!?;:"\'-)]}"':
|
| 93 |
+
if word[-1] not in "'": # Keep trailing apostrophe for contractions
|
| 94 |
+
trailing.insert(0, word[-1])
|
| 95 |
+
word = word[:-1]
|
| 96 |
+
|
| 97 |
+
if word:
|
| 98 |
+
tokens.append(word)
|
| 99 |
+
|
| 100 |
+
tokens.extend(trailing)
|
| 101 |
+
|
| 102 |
+
return tokens
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def preprocess_text(text: str) -> List[str]:
|
| 106 |
+
"""
|
| 107 |
+
Full preprocessing pipeline: clean + tokenize.
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
text: Raw text string.
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
List of tokens.
|
| 114 |
+
"""
|
| 115 |
+
cleaned = clean_text(text)
|
| 116 |
+
tokens = tokenize(cleaned)
|
| 117 |
+
return tokens
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def add_sentence_markers(tokens: List[str]) -> List[str]:
|
| 121 |
+
"""
|
| 122 |
+
Add start/end markers for sentence boundary modeling.
|
| 123 |
+
|
| 124 |
+
For trigram models, we need context at sentence boundaries.
|
| 125 |
+
We add two start tokens to provide full context for the first word.
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
tokens: List of tokens from a sentence.
|
| 129 |
+
|
| 130 |
+
Returns:
|
| 131 |
+
Tokens with boundary markers.
|
| 132 |
+
"""
|
| 133 |
+
if not tokens:
|
| 134 |
+
return []
|
| 135 |
+
return [START_TOKEN, START_TOKEN] + tokens + [END_TOKEN]
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def preprocess_corpus(texts: List[str]) -> List[List[str]]:
|
| 139 |
+
"""
|
| 140 |
+
Preprocess entire corpus for language model training.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
texts: List of raw text strings.
|
| 144 |
+
|
| 145 |
+
Returns:
|
| 146 |
+
List of tokenized sentences with boundary markers.
|
| 147 |
+
"""
|
| 148 |
+
processed = []
|
| 149 |
+
for text in texts:
|
| 150 |
+
tokens = preprocess_text(text)
|
| 151 |
+
if tokens: # Skip empty results
|
| 152 |
+
marked = add_sentence_markers(tokens)
|
| 153 |
+
processed.append(marked)
|
| 154 |
+
return processed
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
if __name__ == "__main__":
|
| 158 |
+
# Test preprocessing on Nigerian Pidgin examples
|
| 159 |
+
test_texts = [
|
| 160 |
+
"I dey go market, you wan follow?",
|
| 161 |
+
"That guy na correct person sha @handle https://example.com",
|
| 162 |
+
"Wetin you dey do? Abi you no sabi?",
|
| 163 |
+
"E don happen before, no be today matter",
|
| 164 |
+
"How far? Everything dey go well?",
|
| 165 |
+
]
|
| 166 |
+
|
| 167 |
+
print("Preprocessing Examples:\n")
|
| 168 |
+
for text in test_texts:
|
| 169 |
+
tokens = preprocess_text(text)
|
| 170 |
+
marked = add_sentence_markers(tokens)
|
| 171 |
+
print(f"Original: {text}")
|
| 172 |
+
print(f"Tokens: {tokens}")
|
| 173 |
+
print(f"Marked: {marked}")
|
| 174 |
+
print()
|
src/trigram_model.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Trigram Language Model for Next-Word Prediction.
|
| 3 |
+
|
| 4 |
+
Implements a statistical trigram model with Laplace (add-one) smoothing
|
| 5 |
+
for Nigerian English/Pidgin next-word prediction.
|
| 6 |
+
|
| 7 |
+
Mathematical Foundation:
|
| 8 |
+
P(w_n | w_{n-2}, w_{n-1}) = (C(w_{n-2}, w_{n-1}, w_n) + α) / (C(w_{n-2}, w_{n-1}) + α|V|)
|
| 9 |
+
|
| 10 |
+
Where:
|
| 11 |
+
- C(.) = count of n-gram in training corpus
|
| 12 |
+
- α = smoothing parameter (1.0 for Laplace)
|
| 13 |
+
- |V| = vocabulary size
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from collections import Counter
|
| 17 |
+
from typing import List, Tuple, Dict, Optional
|
| 18 |
+
import math
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class TrigramLM:
|
| 22 |
+
"""
|
| 23 |
+
Trigram Language Model with Laplace smoothing.
|
| 24 |
+
|
| 25 |
+
Attributes:
|
| 26 |
+
smoothing: Smoothing parameter (α). Default 1.0 for add-one smoothing.
|
| 27 |
+
unigram_counts: Counter for single word frequencies.
|
| 28 |
+
bigram_counts: Counter for word pair frequencies.
|
| 29 |
+
trigram_counts: Counter for word triple frequencies.
|
| 30 |
+
vocab: Set of all unique words in training corpus.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(self, smoothing: float = 1.0):
|
| 34 |
+
"""
|
| 35 |
+
Initialize the trigram model.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
smoothing: Laplace smoothing parameter. Higher values provide more
|
| 39 |
+
smoothing for unseen n-grams. Default 1.0 (add-one).
|
| 40 |
+
"""
|
| 41 |
+
self.smoothing = smoothing
|
| 42 |
+
self.unigram_counts: Counter = Counter()
|
| 43 |
+
self.bigram_counts: Counter = Counter()
|
| 44 |
+
self.trigram_counts: Counter = Counter()
|
| 45 |
+
self.vocab: set = set()
|
| 46 |
+
self._total_unigrams: int = 0
|
| 47 |
+
|
| 48 |
+
def train(self, sentences: List[List[str]]) -> None:
|
| 49 |
+
"""
|
| 50 |
+
Train the model by counting n-grams from tokenized sentences.
|
| 51 |
+
|
| 52 |
+
Expects sentences with start/end markers already added:
|
| 53 |
+
['<s>', '<s>', 'word1', 'word2', ..., '</s>']
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
sentences: List of tokenized sentences with boundary markers.
|
| 57 |
+
"""
|
| 58 |
+
for sentence in sentences:
|
| 59 |
+
# Build vocabulary
|
| 60 |
+
self.vocab.update(sentence)
|
| 61 |
+
|
| 62 |
+
# Count unigrams
|
| 63 |
+
for token in sentence:
|
| 64 |
+
self.unigram_counts[token] += 1
|
| 65 |
+
self._total_unigrams += 1
|
| 66 |
+
|
| 67 |
+
# Count bigrams
|
| 68 |
+
for i in range(len(sentence) - 1):
|
| 69 |
+
bigram = (sentence[i], sentence[i + 1])
|
| 70 |
+
self.bigram_counts[bigram] += 1
|
| 71 |
+
|
| 72 |
+
# Count trigrams
|
| 73 |
+
for i in range(len(sentence) - 2):
|
| 74 |
+
trigram = (sentence[i], sentence[i + 1], sentence[i + 2])
|
| 75 |
+
self.trigram_counts[trigram] += 1
|
| 76 |
+
|
| 77 |
+
print(f"Training complete:")
|
| 78 |
+
print(f" Vocabulary size: {len(self.vocab):,}")
|
| 79 |
+
print(f" Unique unigrams: {len(self.unigram_counts):,}")
|
| 80 |
+
print(f" Unique bigrams: {len(self.bigram_counts):,}")
|
| 81 |
+
print(f" Unique trigrams: {len(self.trigram_counts):,}")
|
| 82 |
+
|
| 83 |
+
def probability(self, w3: str, w1: str, w2: str) -> float:
|
| 84 |
+
"""
|
| 85 |
+
Compute P(w3 | w1, w2) with Laplace smoothing.
|
| 86 |
+
|
| 87 |
+
Formula:
|
| 88 |
+
P(w3|w1,w2) = (C(w1,w2,w3) + α) / (C(w1,w2) + α|V|)
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
w3: The word to predict.
|
| 92 |
+
w1: First context word (two positions before w3).
|
| 93 |
+
w2: Second context word (one position before w3).
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
Conditional probability P(w3 | w1, w2).
|
| 97 |
+
"""
|
| 98 |
+
trigram_count = self.trigram_counts.get((w1, w2, w3), 0)
|
| 99 |
+
bigram_count = self.bigram_counts.get((w1, w2), 0)
|
| 100 |
+
vocab_size = len(self.vocab)
|
| 101 |
+
|
| 102 |
+
# Laplace smoothing
|
| 103 |
+
numerator = trigram_count + self.smoothing
|
| 104 |
+
denominator = bigram_count + (self.smoothing * vocab_size)
|
| 105 |
+
|
| 106 |
+
return numerator / denominator if denominator > 0 else 0.0
|
| 107 |
+
|
| 108 |
+
def log_probability(self, w3: str, w1: str, w2: str) -> float:
|
| 109 |
+
"""
|
| 110 |
+
Compute log P(w3 | w1, w2) for numerical stability.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
w3: The word to predict.
|
| 114 |
+
w1: First context word.
|
| 115 |
+
w2: Second context word.
|
| 116 |
+
|
| 117 |
+
Returns:
|
| 118 |
+
Log probability.
|
| 119 |
+
"""
|
| 120 |
+
prob = self.probability(w3, w1, w2)
|
| 121 |
+
return math.log(prob) if prob > 0 else float('-inf')
|
| 122 |
+
|
| 123 |
+
def predict_next_words(
|
| 124 |
+
self,
|
| 125 |
+
context: str,
|
| 126 |
+
top_k: int = 5,
|
| 127 |
+
exclude_special: bool = True
|
| 128 |
+
) -> List[Tuple[str, float]]:
|
| 129 |
+
"""
|
| 130 |
+
Predict the top-k most likely next words given a context.
|
| 131 |
+
|
| 132 |
+
Args:
|
| 133 |
+
context: Input text (will use last two words as context).
|
| 134 |
+
top_k: Number of predictions to return.
|
| 135 |
+
exclude_special: If True, exclude <s> and </s> from predictions.
|
| 136 |
+
|
| 137 |
+
Returns:
|
| 138 |
+
List of (word, probability) tuples, sorted by probability descending.
|
| 139 |
+
"""
|
| 140 |
+
# Tokenize and extract last two words
|
| 141 |
+
words = context.lower().split()
|
| 142 |
+
|
| 143 |
+
if len(words) == 0:
|
| 144 |
+
w1, w2 = '<s>', '<s>'
|
| 145 |
+
elif len(words) == 1:
|
| 146 |
+
w1, w2 = '<s>', words[0]
|
| 147 |
+
else:
|
| 148 |
+
w1, w2 = words[-2], words[-1]
|
| 149 |
+
|
| 150 |
+
# Compute probability for each word in vocabulary
|
| 151 |
+
candidates = []
|
| 152 |
+
for word in self.vocab:
|
| 153 |
+
if exclude_special and word in ('<s>', '</s>'):
|
| 154 |
+
continue
|
| 155 |
+
prob = self.probability(word, w1, w2)
|
| 156 |
+
candidates.append((word, prob))
|
| 157 |
+
|
| 158 |
+
# Sort by probability descending
|
| 159 |
+
candidates.sort(key=lambda x: x[1], reverse=True)
|
| 160 |
+
|
| 161 |
+
return candidates[:top_k]
|
| 162 |
+
|
| 163 |
+
def sentence_probability(self, tokens: List[str]) -> float:
|
| 164 |
+
"""
|
| 165 |
+
Compute the probability of a sentence.
|
| 166 |
+
|
| 167 |
+
Args:
|
| 168 |
+
tokens: Tokenized sentence WITH start/end markers.
|
| 169 |
+
|
| 170 |
+
Returns:
|
| 171 |
+
Log probability of the sentence.
|
| 172 |
+
"""
|
| 173 |
+
if len(tokens) < 3:
|
| 174 |
+
return float('-inf')
|
| 175 |
+
|
| 176 |
+
log_prob = 0.0
|
| 177 |
+
for i in range(2, len(tokens)):
|
| 178 |
+
log_prob += self.log_probability(tokens[i], tokens[i-2], tokens[i-1])
|
| 179 |
+
|
| 180 |
+
return log_prob
|
| 181 |
+
|
| 182 |
+
def perplexity(self, sentences: List[List[str]]) -> float:
|
| 183 |
+
"""
|
| 184 |
+
Compute perplexity on a set of sentences.
|
| 185 |
+
|
| 186 |
+
Perplexity = exp(-1/N * sum(log P(w_i | w_{i-2}, w_{i-1})))
|
| 187 |
+
|
| 188 |
+
Lower perplexity = better model fit.
|
| 189 |
+
|
| 190 |
+
Args:
|
| 191 |
+
sentences: List of tokenized sentences with boundary markers.
|
| 192 |
+
|
| 193 |
+
Returns:
|
| 194 |
+
Perplexity score.
|
| 195 |
+
"""
|
| 196 |
+
total_log_prob = 0.0
|
| 197 |
+
total_words = 0
|
| 198 |
+
|
| 199 |
+
for sentence in sentences:
|
| 200 |
+
if len(sentence) < 3:
|
| 201 |
+
continue
|
| 202 |
+
for i in range(2, len(sentence)):
|
| 203 |
+
total_log_prob += self.log_probability(
|
| 204 |
+
sentence[i], sentence[i-2], sentence[i-1]
|
| 205 |
+
)
|
| 206 |
+
total_words += 1
|
| 207 |
+
|
| 208 |
+
if total_words == 0:
|
| 209 |
+
return float('inf')
|
| 210 |
+
|
| 211 |
+
avg_log_prob = total_log_prob / total_words
|
| 212 |
+
return math.exp(-avg_log_prob)
|
| 213 |
+
|
| 214 |
+
def get_context_distribution(
|
| 215 |
+
self,
|
| 216 |
+
w1: str,
|
| 217 |
+
w2: str,
|
| 218 |
+
top_k: Optional[int] = None
|
| 219 |
+
) -> List[Tuple[str, float]]:
|
| 220 |
+
"""
|
| 221 |
+
Get the probability distribution for a specific bigram context.
|
| 222 |
+
|
| 223 |
+
Args:
|
| 224 |
+
w1: First context word.
|
| 225 |
+
w2: Second context word.
|
| 226 |
+
top_k: If provided, return only top-k predictions.
|
| 227 |
+
|
| 228 |
+
Returns:
|
| 229 |
+
List of (word, probability) tuples.
|
| 230 |
+
"""
|
| 231 |
+
candidates = []
|
| 232 |
+
for word in self.vocab:
|
| 233 |
+
if word not in ('<s>', '</s>'):
|
| 234 |
+
prob = self.probability(word, w1, w2)
|
| 235 |
+
candidates.append((word, prob))
|
| 236 |
+
|
| 237 |
+
candidates.sort(key=lambda x: x[1], reverse=True)
|
| 238 |
+
|
| 239 |
+
if top_k:
|
| 240 |
+
return candidates[:top_k]
|
| 241 |
+
return candidates
|
| 242 |
+
|
| 243 |
+
def get_stats(self) -> Dict[str, int]:
|
| 244 |
+
"""
|
| 245 |
+
Get model statistics.
|
| 246 |
+
|
| 247 |
+
Returns:
|
| 248 |
+
Dictionary of statistics.
|
| 249 |
+
"""
|
| 250 |
+
return {
|
| 251 |
+
'vocab_size': len(self.vocab),
|
| 252 |
+
'unique_unigrams': len(self.unigram_counts),
|
| 253 |
+
'unique_bigrams': len(self.bigram_counts),
|
| 254 |
+
'unique_trigrams': len(self.trigram_counts),
|
| 255 |
+
'total_tokens': self._total_unigrams,
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
if __name__ == "__main__":
|
| 260 |
+
# Quick test with sample data
|
| 261 |
+
sample_sentences = [
|
| 262 |
+
['<s>', '<s>', 'i', 'dey', 'go', 'market', '</s>'],
|
| 263 |
+
['<s>', '<s>', 'i', 'dey', 'come', 'back', '</s>'],
|
| 264 |
+
['<s>', '<s>', 'you', 'dey', 'go', 'where', '?', '</s>'],
|
| 265 |
+
['<s>', '<s>', 'how', 'far', '?', '</s>'],
|
| 266 |
+
['<s>', '<s>', 'e', 'don', 'happen', '</s>'],
|
| 267 |
+
]
|
| 268 |
+
|
| 269 |
+
model = TrigramLM(smoothing=1.0)
|
| 270 |
+
model.train(sample_sentences)
|
| 271 |
+
|
| 272 |
+
print("\nTest Predictions:")
|
| 273 |
+
contexts = ["i dey", "you dey", "how"]
|
| 274 |
+
for ctx in contexts:
|
| 275 |
+
preds = model.predict_next_words(ctx, top_k=3)
|
| 276 |
+
print(f" '{ctx}' -> {preds}")
|
src/utils.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility functions for the next-word prediction system.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import List, Tuple
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def format_predictions(predictions: List[Tuple[str, float]], show_percent: bool = True) -> str:
|
| 10 |
+
"""
|
| 11 |
+
Format prediction results for display.
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
predictions: List of (word, probability) tuples.
|
| 15 |
+
show_percent: If True, show as percentage.
|
| 16 |
+
|
| 17 |
+
Returns:
|
| 18 |
+
Formatted string.
|
| 19 |
+
"""
|
| 20 |
+
lines = []
|
| 21 |
+
for word, prob in predictions:
|
| 22 |
+
if show_percent:
|
| 23 |
+
lines.append(f" {word}: {prob*100:.2f}%")
|
| 24 |
+
else:
|
| 25 |
+
lines.append(f" {word}: {prob:.6f}")
|
| 26 |
+
return "\n".join(lines)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def calculate_entropy(probabilities: List[float]) -> float:
|
| 30 |
+
"""
|
| 31 |
+
Calculate entropy of a probability distribution.
|
| 32 |
+
|
| 33 |
+
H(X) = -sum(p * log2(p))
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
probabilities: List of probabilities.
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
Entropy in bits.
|
| 40 |
+
"""
|
| 41 |
+
entropy = 0.0
|
| 42 |
+
for p in probabilities:
|
| 43 |
+
if p > 0:
|
| 44 |
+
entropy -= p * math.log2(p)
|
| 45 |
+
return entropy
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def top_k_accuracy(
|
| 49 |
+
model,
|
| 50 |
+
test_sentences: List[List[str]],
|
| 51 |
+
k: int = 5
|
| 52 |
+
) -> float:
|
| 53 |
+
"""
|
| 54 |
+
Calculate top-k accuracy on test data.
|
| 55 |
+
|
| 56 |
+
Measures what fraction of true next words appear in top-k predictions.
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
model: Trained TrigramLM instance.
|
| 60 |
+
test_sentences: List of tokenized sentences with markers.
|
| 61 |
+
k: Number of top predictions to consider.
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
Accuracy as fraction between 0 and 1.
|
| 65 |
+
"""
|
| 66 |
+
correct = 0
|
| 67 |
+
total = 0
|
| 68 |
+
|
| 69 |
+
for sentence in test_sentences:
|
| 70 |
+
if len(sentence) < 3:
|
| 71 |
+
continue
|
| 72 |
+
|
| 73 |
+
for i in range(2, len(sentence)):
|
| 74 |
+
w1, w2 = sentence[i-2], sentence[i-1]
|
| 75 |
+
true_word = sentence[i]
|
| 76 |
+
|
| 77 |
+
# Get top-k predictions
|
| 78 |
+
preds = model.get_context_distribution(w1, w2, top_k=k)
|
| 79 |
+
pred_words = [w for w, _ in preds]
|
| 80 |
+
|
| 81 |
+
if true_word in pred_words:
|
| 82 |
+
correct += 1
|
| 83 |
+
total += 1
|
| 84 |
+
|
| 85 |
+
return correct / total if total > 0 else 0.0
|