File size: 2,338 Bytes
290f366
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
CaptionIQ — Shared Utility Functions
Load captions, image lists, features, and tokenizer from disk.
"""

import pickle
from typing import Any, Dict, List, Optional, Set


def load_captions(filepath: str) -> Dict[str, List[str]]:
    """
    Load cleaned captions from file.

    Expected format per line:
        image_id<tab>caption text

    Each line is stripped of surrounding whitespace first. Blank lines and
    lines without a tab separator are skipped silently. Only the first tab
    splits the line, so a caption may itself contain tabs.

    Args:
        filepath: path to the UTF-8 encoded captions file.

    Returns:
        dict mapping image_id → list of caption strings, in file order
    """
    captions: Dict[str, List[str]] = {}
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(filepath, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            image_id, separator, caption = line.partition("\t")
            if not separator:
                # No tab on this line: not an "id<tab>caption" record.
                continue
            captions.setdefault(image_id, []).append(caption)
    return captions


def load_image_list(filepath: str) -> Set[str]:
    """
    Load a set of image IDs from a text file (one per line).

    Surrounding whitespace is stripped from every line, and lines that are
    empty after stripping are ignored. Duplicate IDs collapse into one entry.

    Args:
        filepath: path to the UTF-8 encoded list file.

    Returns:
        set of non-empty image ID strings
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(filepath, "r", encoding="utf-8") as f:
        stripped_lines = (line.strip() for line in f)
        return {image_id for image_id in stripped_lines if image_id}


def load_features(pkl_path: str) -> Dict[str, Any]:
    """
    Load pre-extracted image features from a pickle file.

    Args:
        pkl_path: path to the pickle file holding the feature mapping.

    Returns:
        the unpickled object; expected to be a dict mapping
        image_id → numpy array of shape (4096,)
    """
    # NOTE: unpickling can execute arbitrary code, so only load feature
    # files that come from a trusted source.
    with open(pkl_path, "rb") as f:
        return pickle.load(f)


def load_tokenizer(pkl_path: str):
    """
    Read a pickled, fitted Keras Tokenizer from disk and return it.
    """
    with open(pkl_path, "rb") as handle:
        tokenizer = pickle.load(handle)
    return tokenizer


def word_for_id(integer: int, tokenizer) -> Optional[str]:
    """
    Map an integer index back to a word using the tokenizer.

    A reverse (index → word) dict is built from ``tokenizer.word_index`` on
    the first call and stored on the tokenizer as ``_reverse_index``, so
    later calls are a single dict lookup. The cache is not refreshed if
    ``word_index`` changes afterwards.

    Args:
        integer: token index to look up.
        tokenizer: object exposing ``word_index`` (word → index dict) and
            ``num_words`` (int or None).

    Returns:
        the word for ``integer``, or None if the index is not found or is
        greater than or equal to ``num_words``.
    """
    if not hasattr(tokenizer, "_reverse_index"):
        tokenizer._reverse_index = {
            idx: word for word, idx in tokenizer.word_index.items()
        }
    # Respect vocab filtering: indices at or beyond num_words are treated
    # as out of vocabulary even though word_index still lists them.
    num_words = tokenizer.num_words
    if num_words is not None and integer >= num_words:
        return None
    return tokenizer._reverse_index.get(integer)


def get_vocab_size(tokenizer) -> int:
    """Return num_words when it is set, otherwise len(word_index) + 1."""
    limit = tokenizer.num_words
    return len(tokenizer.word_index) + 1 if limit is None else limit