# Hugging Face file-view header (scraped): uploaded by wshuai190, commit a7c784c (verified),
# "Add self-contained DiffRetriever (trust_remote_code: code + config + adapter/)", 3.01 kB.
"""
Sparse representation utilities matching original PromptReps.
Reference: https://github.com/ielab/PromptReps
The original filters sparse logits by:
1. Word-tokenizing the text (lowercased) with NLTK
2. Removing stopwords and punctuation
3. Re-tokenizing each content word independently to get clean token IDs
4. Only keeping logits for those token IDs
"""
import torch
from typing import List, Set, Optional
import string
import logging
logger = logging.getLogger(__name__)
from nltk import word_tokenize
from nltk.corpus import stopwords as _sw_corpus


def _load_stopwords() -> Set[str]:
    """Return NLTK's English stopword list plus ASCII punctuation as one set.

    Both NLTK resources this module relies on (the ``stopwords`` corpus and the
    ``punkt_tab`` model behind ``word_tokenize``) are looked up lazily by NLTK,
    so a missing one only surfaces as ``LookupError`` when it is first used.
    Both are touched here; if either is missing, both packages are downloaded
    once and the stopword list is reloaded.
    """
    try:
        english = _sw_corpus.words('english')
        # word_tokenize does not fail on import when its model is missing,
        # only when called -- probe it now rather than on the first real text.
        word_tokenize('probe')
    except LookupError:
        import nltk

        logger.info("NLTK data missing; downloading 'punkt_tab' and 'stopwords'")
        nltk.download('punkt_tab', quiet=True)
        nltk.download('stopwords', quiet=True)
        english = _sw_corpus.words('english')
    return set(english + list(string.punctuation))


# Words and single punctuation characters dropped before re-tokenization.
STOPWORDS = _load_stopwords()
def get_content_token_ids(texts: List[str], tokenizer) -> List[Set[int]]:
    """Extract content token IDs from texts (stopwords removed, word-level tokenization).

    Mirrors PromptReps' ``get_valid_tokens_values``:
      1. ``word_tokenize(text.lower())``
      2. drop every token found in ``STOPWORDS`` (stopwords + punctuation)
      3. tokenize each remaining word on its own, without special tokens,
         and collect the union of the resulting ids

    All surviving words of one text are encoded in a single batched tokenizer
    call. Each word is still tokenized independently, so the ids equal those
    from encoding words one at a time.

    Args:
        texts: List of raw text strings.
        tokenizer: HuggingFace-style tokenizer. It must accept a list of
            strings plus ``add_special_tokens=False`` and return a mapping with
            ``'input_ids'`` (one id list per input string).

    Returns:
        One set of token IDs per input text, in input order. The set is empty
        when a text has no content words.
    """
    all_token_ids: List[Set[int]] = []
    for text in texts:
        # Collect into a set: a word repeated in the text maps to the same ids,
        # so re-encoding it only costs tokenizer time.
        words = {w for w in word_tokenize(text.lower()) if w not in STOPWORDS}
        token_ids: Set[int] = set()
        if words:  # skip the tokenizer call entirely when nothing survives
            # sorted() only makes the batch order deterministic; the union
            # below does not depend on it.
            batch_ids = tokenizer(sorted(words), add_special_tokens=False)['input_ids']
            for ids in batch_ids:
                token_ids.update(ids)
        all_token_ids.append(token_ids)
    return all_token_ids
def filter_sparse(
    sparse: torch.Tensor,
    content_token_ids: List[Set[int]],
    exclude_ids: Optional[List[int]] = None,
) -> torch.Tensor:
    """Filter sparse logits to only keep content tokens.

    Args:
        sparse: [batch, vocab] sparse logit tensor.
        content_token_ids: List of sets of valid token IDs per example; entry
            ``i`` applies to row ``i`` of ``sparse``.
        exclude_ids: Token IDs to always exclude (e.g., MASK token). ``None``
            entries are ignored, so an absent special-token id can be passed
            through safely.

    Returns:
        Tensor of the same shape and dtype as ``sparse``. Content entries are
        kept unchanged and every other entry is exactly 0, even where ``sparse``
        holds ``inf``/``nan``.
    """
    # Boolean keep-mask on the same device. masked_fill (not multiplication) is
    # used below because inf * 0 and nan * 0 would produce nan, not zero.
    keep = torch.zeros(sparse.shape, dtype=torch.bool, device=sparse.device)

    # Build (row, col) index pairs for all content tokens across the batch so
    # the mask is filled with a single advanced-indexing assignment.
    rows: List[int] = []
    cols: List[int] = []
    for row in range(sparse.size(0)):
        ids = content_token_ids[row]
        rows.extend([row] * len(ids))
        cols.extend(ids)
    if rows:
        keep[rows, cols] = True

    if exclude_ids:
        # A None id would be interpreted as a new axis by tensor indexing and
        # clear the whole mask, so drop it instead.
        valid_excludes = [eid for eid in exclude_ids if eid is not None]
        if valid_excludes:
            keep[:, valid_excludes] = False

    return sparse.masked_fill(~keep, 0)