Feature Extraction
Transformers
Safetensors
English
diffretriever
information-retrieval
dense-retrieval
sparse-retrieval
colbert
diffusion-language-model
lora
custom_code
Instructions to use ielabgroup/diffretriever-llada-8b-single with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use ielabgroup/diffretriever-llada-8b-single with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="ielabgroup/diffretriever-llada-8b-single", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("ielabgroup/diffretriever-llada-8b-single", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| """ | |
| Sparse representation utilities matching original PromptReps. | |
| Reference: https://github.com/ielab/PromptReps | |
| The original filters sparse logits by: | |
| 1. Word-tokenizing the text (lowercased) with NLTK | |
| 2. Removing stopwords and punctuation | |
| 3. Re-tokenizing each content word independently to get clean token IDs | |
| 4. Only keeping logits for those token IDs | |
| """ | |
| import torch | |
| from typing import List, Set, Optional | |
| import string | |
| import logging | |
| logger = logging.getLogger(__name__) | |
import nltk


def _ensure_nltk_data(resource_path: str, package: str) -> None:
    """Download NLTK data ``package`` unless ``resource_path`` is already installed.

    Each resource is checked on its own. The ``punkt_tab`` tokenizer data is
    only loaded lazily by ``word_tokenize``, so its absence would not surface
    while the stopword list is being built. It has to be checked explicitly.

    Args:
        resource_path: Path understood by ``nltk.data.find``
            (e.g. ``'corpora/stopwords'``).
        package: Identifier passed to ``nltk.download`` when the resource
            is missing.
    """
    try:
        nltk.data.find(resource_path)
    except LookupError:
        # nltk.download reports failure (e.g. no network) by returning False
        # rather than raising; surface it so the later LookupError is explainable.
        if not nltk.download(package, quiet=True):
            logger.warning("Could not download NLTK resource %r", package)


_ensure_nltk_data('tokenizers/punkt_tab', 'punkt_tab')
_ensure_nltk_data('corpora/stopwords', 'stopwords')

from nltk import word_tokenize  # noqa: E402
from nltk.corpus import stopwords as _sw_corpus  # noqa: E402

# Words dropped before sparse filtering: NLTK English stopwords plus every
# single ASCII punctuation character.
STOPWORDS = set(_sw_corpus.words('english') + list(string.punctuation))
def get_content_token_ids(texts: List[str], tokenizer) -> List[Set[int]]:
    """Map each text to the set of token IDs of its content words.

    Mirrors PromptReps' ``get_valid_tokens_values``:

    1. Lowercase the text and split it with NLTK ``word_tokenize``.
    2. Drop every word contained in ``STOPWORDS`` (English stopwords plus
       single punctuation characters).
    3. Tokenize each surviving word independently (no special tokens) and
       pool all resulting IDs.

    Args:
        texts: Raw text strings.
        tokenizer: HuggingFace tokenizer. It is called at most once per text,
            on the list of content words, with ``add_special_tokens=False``.

    Returns:
        One set of token IDs per input text, in input order. A text with no
        content words yields an empty set.
    """
    per_text_ids: List[Set[int]] = []
    for raw_text in texts:
        content_words = [
            word
            for word in word_tokenize(raw_text.lower())
            if word not in STOPWORDS
        ]
        pooled: Set[int] = set()
        # Only call the tokenizer when there is something to encode.
        if content_words:
            # One batched call is much cheaper than encoding word by word,
            # and each word is still tokenized independently within the batch.
            encoded = tokenizer(content_words, add_special_tokens=False)
            for word_ids in encoded['input_ids']:
                pooled.update(word_ids)
        per_text_ids.append(pooled)
    return per_text_ids
def filter_sparse(
    sparse: torch.Tensor,
    content_token_ids: List[Set[int]],
    exclude_ids: Optional[List[int]] = None,
) -> torch.Tensor:
    """Keep only content-token logits in a batch of sparse representations.

    Args:
        sparse: [batch, vocab] sparse logit tensor.
        content_token_ids: One set of valid token IDs per row of ``sparse``.
            An empty set zeroes the whole row. Must have at least
            ``sparse.size(0)`` entries.
        exclude_ids: Token IDs to always zero, even when listed in
            ``content_token_ids`` (e.g., the MASK token).

    Returns:
        Tensor with the same shape, dtype and device as ``sparse``. Every
        non-content entry is exactly 0 and kept entries are unchanged.
    """
    keep = torch.zeros(sparse.shape, dtype=torch.bool, device=sparse.device)

    # Collect (row, col) coordinates of every allowed token across the batch
    # so the mask is filled with one indexed assignment.
    rows: List[int] = []
    cols: List[int] = []
    for i in range(sparse.size(0)):
        ids = content_token_ids[i]
        if ids:
            rows.extend([i] * len(ids))
            cols.extend(ids)
    if rows:
        keep[
            torch.tensor(rows, dtype=torch.long, device=sparse.device),
            torch.tensor(cols, dtype=torch.long, device=sparse.device),
        ] = True

    if exclude_ids:
        keep[:, list(exclude_ids)] = False

    # masked_fill instead of multiplying by a 0/1 mask: inf * 0 and -inf * 0
    # evaluate to NaN, which would leave "zeroed" entries non-zero. The
    # gradient still flows only through the kept entries.
    return sparse.masked_fill(~keep, 0)