File size: 2,115 Bytes
ad18db6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""
Utility functions for the next-word prediction system.
"""

from typing import List, Tuple
import math


def format_predictions(predictions: List[Tuple[str, float]], show_percent: bool = True) -> str:
    """
    Render a list of (word, probability) predictions as display text.

    Args:
        predictions: List of (word, probability) tuples.
        show_percent: If True, probabilities are shown as percentages
            (two decimals); otherwise as raw values with six decimals.

    Returns:
        One indented line per prediction, joined with newlines
        (empty string for an empty list).
    """
    if show_percent:
        rendered = [f"  {word}: {prob*100:.2f}%" for word, prob in predictions]
    else:
        rendered = [f"  {word}: {prob:.6f}" for word, prob in predictions]
    return "\n".join(rendered)


def calculate_entropy(probabilities: List[float]) -> float:
    """
    Compute the Shannon entropy of a probability distribution.

    H(X) = -sum(p * log2(p)), skipping zero-probability entries
    (their contribution is 0 by convention, and log2(0) is undefined).

    Args:
        probabilities: List of probabilities.

    Returns:
        Entropy in bits (0.0 for an empty list).
    """
    # Seed with 0.0 so an empty/all-zero input yields a float, matching
    # the accumulator-based formulation; left-to-right addition order is
    # identical, so results are bit-for-bit the same.
    return -sum((p * math.log2(p) for p in probabilities if p > 0), 0.0)


def top_k_accuracy(
    model, 
    test_sentences: List[List[str]], 
    k: int = 5
) -> float:
    """
    Compute top-k accuracy of a trigram model on test data.

    For each position with two preceding words, checks whether the true
    next word appears in the model's top-k predictions for that context.

    Args:
        model: Trained TrigramLM instance.
        test_sentences: List of tokenized sentences with markers.
        k: Number of top predictions to consider.

    Returns:
        Fraction of positions where the true word was in the top-k
        predictions (0.0 when there are no scorable positions).
    """
    hits = 0
    attempts = 0

    for tokens in test_sentences:
        # Need at least two context words plus one target.
        if len(tokens) < 3:
            continue

        for idx, target in enumerate(tokens[2:], start=2):
            prev2, prev1 = tokens[idx - 2], tokens[idx - 1]
            candidates = model.get_context_distribution(prev2, prev1, top_k=k)
            if any(word == target for word, _ in candidates):
                hits += 1
            attempts += 1

    if attempts == 0:
        return 0.0
    return hits / attempts