File size: 3,008 Bytes
a7c784c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""
Sparse representation utilities matching original PromptReps.

Reference: https://github.com/ielab/PromptReps
The original filters sparse logits by:
1. Word-tokenizing the text (lowercased) with NLTK
2. Removing stopwords and punctuation
3. Re-tokenizing each content word independently to get clean token IDs
4. Only keeping logits for those token IDs
"""

import torch
from typing import List, Set, Optional
import string
import logging

logger = logging.getLogger(__name__)

# Import-time setup of the NLTK pieces this module relies on: the word
# tokenizer and the stop list (English stopwords plus ASCII punctuation).
from nltk import word_tokenize
from nltk.corpus import stopwords as _sw_corpus

try:
    STOPWORDS = set(_sw_corpus.words('english')).union(string.punctuation)
except LookupError:
    # NLTK signals missing local data with LookupError: fetch the tokenizer
    # and stopword resources quietly, then build the stop list again.
    import nltk
    nltk.download('punkt_tab', quiet=True)
    nltk.download('stopwords', quiet=True)
    STOPWORDS = set(_sw_corpus.words('english')).union(string.punctuation)


def get_content_token_ids(texts: List[str], tokenizer) -> List[Set[int]]:
    """Collect, per text, the token IDs produced by its content words.

    Follows the filtering recipe of PromptReps' get_valid_tokens_values:
    1. Lowercase the text and split it into words with ``word_tokenize``.
    2. Drop every word found in ``STOPWORDS`` (stopwords + punctuation).
    3. Tokenize each surviving word on its own, without special tokens,
       and pool the resulting IDs into one set.

    Args:
        texts: List of raw text strings.
        tokenizer: HuggingFace-style tokenizer; called on a list of words
            and expected to return a mapping with an ``'input_ids'`` entry.

    Returns:
        One set of token IDs per input text, in input order. A text with
        no content words yields an empty set.
    """
    per_text_ids = []
    for raw_text in texts:
        content_words = [
            word for word in word_tokenize(raw_text.lower())
            if word not in STOPWORDS
        ]
        ids_for_text = set()
        if content_words:
            # One batched call: every word is encoded as its own sequence,
            # which is equivalent to (but faster than) encoding word by word.
            encoded = tokenizer(content_words, add_special_tokens=False)
            ids_for_text = ids_for_text.union(*encoded['input_ids'])
        per_text_ids.append(ids_for_text)
    return per_text_ids


def filter_sparse(
    sparse: torch.Tensor,
    content_token_ids: List[Set[int]],
    exclude_ids: Optional[List[int]] = None,
) -> torch.Tensor:
    """Filter sparse logits to only keep content tokens.

    Entries that are not selected are overwritten with an exact zero rather
    than multiplied by zero, so non-finite logits (``inf`` / ``nan``) at
    filtered-out positions cannot leak into the result as ``nan``.

    Args:
        sparse: [batch, vocab] sparse logit tensor.
        content_token_ids: List of sets of valid token IDs per example;
            indexed by batch row, so it needs one entry per row of ``sparse``.
        exclude_ids: Token IDs to always exclude (e.g., MASK token), even
            when they appear in ``content_token_ids``.

    Returns:
        New tensor with the same shape and dtype as ``sparse``; kept entries
        are unchanged and every other entry is zero.

    Raises:
        IndexError: If ``content_token_ids`` has fewer entries than the batch
            size, or a token ID falls outside the vocab dimension.
    """
    # Boolean keep-mask: cheaper than a float mask and usable with masked_fill.
    keep = torch.zeros_like(sparse, dtype=torch.bool)

    # Build (row, col) index pairs for all content tokens across the batch so
    # the mask is filled with a single indexed assignment.
    rows, cols = [], []
    for i in range(sparse.size(0)):
        ids = content_token_ids[i]
        if ids:
            rows.extend([i] * len(ids))
            cols.extend(ids)
    if rows:
        keep[rows, cols] = True

    # Exclusions are applied last so they override content tokens.
    if exclude_ids:
        keep[:, list(exclude_ids)] = False

    return sparse.masked_fill(~keep, 0)