Feature Extraction
Transformers
Safetensors
English
diffretriever
information-retrieval
dense-retrieval
sparse-retrieval
colbert
diffusion-language-model
lora
custom_code
Instructions to use ielabgroup/diffretriever-llada-8b-single with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use ielabgroup/diffretriever-llada-8b-single with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline
pipe = pipeline("feature-extraction", model="ielabgroup/diffretriever-llada-8b-single", trust_remote_code=True)

# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("ielabgroup/diffretriever-llada-8b-single", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
File size: 3,008 Bytes
a7c784c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 | """
Sparse representation utilities matching original PromptReps.
Reference: https://github.com/ielab/PromptReps
The original filters sparse logits by:
1. Word-tokenizing the text (lowercased) with NLTK
2. Removing stopwords and punctuation
3. Re-tokenizing each content word independently to get clean token IDs
4. Only keeping logits for those token IDs
"""
import torch
from typing import List, Set, Optional
import string
import logging
logger = logging.getLogger(__name__)
from nltk import word_tokenize
from nltk.corpus import stopwords as _sw_corpus

# NLTK ships its code and its data packages separately, and a missing data
# package only surfaces as LookupError on first *use*, never on import.
# Exercise both resources here so that a missing tokenizer model also triggers
# the download fallback, instead of failing later inside
# get_content_token_ids().
try:
    word_tokenize("probe")
    STOPWORDS = set(_sw_corpus.words('english') + list(string.punctuation))
except LookupError:
    import nltk
    nltk.download('punkt_tab', quiet=True)
    nltk.download('stopwords', quiet=True)
    # No second tokenizer probe on purpose: if the download could not fetch
    # the tokenizer model (e.g. offline) the module must still import, so that
    # helpers which do not tokenize remain usable.
    STOPWORDS = set(_sw_corpus.words('english') + list(string.punctuation))
def get_content_token_ids(texts: List[str], tokenizer) -> List[Set[int]]:
    """Collect, for each text, the token IDs of its non-stopword words.

    Intended to mirror PromptReps' ``get_valid_tokens_values``:
        1. Lowercase the text and split it into words with ``word_tokenize``.
        2. Discard every word found in ``STOPWORDS`` (stopwords + punctuation).
        3. Tokenize each surviving word on its own, without special tokens,
           and pool the resulting IDs.

    Args:
        texts: List of raw text strings.
        tokenizer: HuggingFace tokenizer; called on a list of words and
            expected to return a mapping with an ``'input_ids'`` entry.

    Returns:
        One set of token IDs per input text, in input order. A text with no
        content words yields an empty set.
    """
    per_text_ids = []
    for raw in texts:
        content_words = [
            word for word in word_tokenize(raw.lower()) if word not in STOPWORDS
        ]
        if not content_words:
            per_text_ids.append(set())
            continue
        # One batched tokenizer call per text; each word is still encoded
        # independently because it is its own batch element.
        encoded = tokenizer(content_words, add_special_tokens=False)['input_ids']
        per_text_ids.append({tid for piece in encoded for tid in piece})
    return per_text_ids
def filter_sparse(
    sparse: torch.Tensor,
    content_token_ids: List[Set[int]],
    exclude_ids: Optional[List[int]] = None,
) -> torch.Tensor:
    """Zero out every sparse logit that does not belong to a content token.

    Args:
        sparse: [batch, vocab] sparse logit tensor.
        content_token_ids: One set of token IDs to keep per batch row; must
            provide at least ``sparse.size(0)`` entries.
        exclude_ids: Token IDs to always drop, even when they appear in a
            row's content set (e.g., the MASK token).

    Returns:
        A new tensor with the same shape and dtype as ``sparse`` in which
        only the kept positions retain their value; all others are exactly 0.

    Raises:
        IndexError: If ``content_token_ids`` has fewer entries than ``sparse``
            has rows.
    """
    batch_size = sparse.size(0)
    if len(content_token_ids) < batch_size:
        raise IndexError(
            f"content_token_ids has {len(content_token_ids)} entries but "
            f"sparse has {batch_size} rows"
        )

    # Flatten the per-row sets into parallel (row, col) index lists so the
    # keep-mask can be written with a single advanced-indexing assignment.
    rows, cols = [], []
    for row, ids in zip(range(batch_size), content_token_ids):
        rows.extend([row] * len(ids))
        cols.extend(ids)

    keep = torch.zeros_like(sparse, dtype=torch.bool)
    if rows:
        keep[rows, cols] = True

    if exclude_ids:
        for eid in exclude_ids:
            keep[:, eid] = False

    # masked_fill rather than `sparse * mask`: multiplying by zero would turn
    # inf / -inf / NaN logits at dropped positions into NaN instead of 0.
    return sparse.masked_fill(~keep, 0)
|