Batching

* small batch size processing for SPLADE re-ranking to work within CPU limitations (see the sketch below)
* reduce default number of context documents to 5
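
The batching leans on torch's DataLoader accepting a plain Python list of strings as a map-style dataset and yielding fixed-size lists of strings. A minimal standalone sketch of that behaviour (the example texts below are made up, not code from this repository):

from torch.utils.data import DataLoader

texts = [f"candidate document {i}" for i in range(10)]  # any list of strings works as a dataset

# shuffle=False preserves the input order, so outputs computed per batch can be
# stacked back together and still line up with the original documents.
for batch in DataLoader(dataset=texts, shuffle=False, batch_size=4):
    print(len(batch), batch)  # lists of 4, 4, and finally 2 strings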
ask_candid/retrieval/elastic.py
CHANGED

@@ -299,7 +299,7 @@ def cosine_rescore(query: str, contexts: List[str]) -> List[float]:
 def reranker(
     query_results: Iterable[ElasticHitsResult],
     search_text: Optional[str] = None,
-    max_num_results: int =
+    max_num_results: int = 5
 ) -> Iterator[ElasticHitsResult]:
     """Reranks Elasticsearch hits coming from multiple indices/queries which may have scores on different scales.
     This will shuffle results
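
For context on what the smaller default affects: per the docstring, reranker merges hits from multiple indices whose raw scores sit on different scales and limits how many survive. A rough sketch of that idea, not the repository's actual implementation (the grouping by index and the (doc_id, score) shape are assumptions made for illustration):

from typing import Dict, List, Tuple

def rerank_sketch(hits_by_index: Dict[str, List[Tuple[str, float]]],
                  max_num_results: int = 5) -> List[Tuple[str, float]]:
    # Min-max normalize each index's scores so scores on different scales
    # become comparable, then merge and keep the top max_num_results.
    merged: List[Tuple[str, float]] = []
    for hits in hits_by_index.values():
        scores = [score for _, score in hits]
        lo, hi = min(scores), max(scores)
        span = (hi - lo) or 1.0  # avoid dividing by zero when all scores match
        merged.extend((doc_id, (score - lo) / span) for doc_id, score in hits)
    merged.sort(key=lambda pair: pair[1], reverse=True)
    return merged[:max_num_results]

top_hits = rerank_sketch({
    "organizations": [("org-1", 14.2), ("org-2", 9.8), ("org-3", 3.1)],
    "news": [("news-1", 0.92), ("news-2", 0.41), ("news-3", 0.15)],
})  # at most 5 hits come back with the new default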
ask_candid/retrieval/sparse_lexical.py
CHANGED

@@ -1,11 +1,15 @@
 from typing import List, Dict
 
+from tqdm.auto import tqdm
+
 from transformers import AutoModelForMaskedLM, AutoTokenizer
+from torch.utils.data import DataLoader
 from torch.nn import functional as F
 import torch
 
 
 class SpladeEncoder:
+    batch_size = 4
 
     def __init__(self):
         model_id = "naver/splade-v3"

@@ -16,13 +20,16 @@ class SpladeEncoder:
 
     @torch.no_grad()
     def forward(self, texts: List[str]):
-        tokens = self.tokenizer(texts, return_tensors='pt', truncation=True, padding=True)
-        output = self.model(**tokens)
-        vec = torch.max(
-            torch.log(1 + torch.relu(output.logits)) * tokens.attention_mask.unsqueeze(-1),
-            dim=1
-        )[0].squeeze()
-        return vec
+        vectors = []
+        for batch in tqdm(DataLoader(dataset=texts, shuffle=False, batch_size=self.batch_size), desc="Re-ranking"):
+            tokens = self.tokenizer(batch, return_tensors='pt', truncation=True, padding=True)
+            output = self.model(**tokens)
+            vec = torch.max(
+                torch.log(1 + torch.relu(output.logits)) * tokens.attention_mask.unsqueeze(-1),
+                dim=1
+            )[0].squeeze()
+            vectors.append(vec)
+        return torch.vstack(vectors)
 
     def query_reranking(self, query: str, documents: List[str]):
         vec = self.forward([query, *documents])

@@ -31,7 +38,7 @@ class SpladeEncoder:
         return (xQ * xD).sum(dim=-1).cpu().tolist()
 
     def token_expand(self, query: str) -> Dict[str, float]:
-        vec = self.forward([query])
+        vec = self.forward([query]).squeeze()
        cols = vec.nonzero().squeeze().cpu().tolist()
        weights = vec[cols].cpu().tolist()
 
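
A hypothetical usage sketch of the batched encoder. The __init__ body beyond model_id is not part of this diff, so the only assumption made here is what the imports and method bodies imply: a naver/splade-v3 tokenizer and masked-LM model held on self.tokenizer / self.model. Module path and example documents are illustrative.

from ask_candid.retrieval.sparse_lexical import SpladeEncoder

encoder = SpladeEncoder()

documents = [
    "Grantmakers supporting food security programs.",
    "Annual report on arts education funding.",
    "Community health initiatives in rural areas.",
]

# Scores every document against the query via the sparse SPLADE dot product.
# forward() now encodes at most SpladeEncoder.batch_size (4) texts at a time,
# which keeps peak memory within CPU-only hardware limits.
scores = encoder.query_reranking("food assistance grants", documents)
ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)

# token_expand maps the query's non-zero SPLADE dimensions to weights,
# e.g. for lexical query expansion against Elasticsearch.
expansion = encoder.token_expand("food assistance grants")

Keeping batch_size as a class attribute makes the knob easy to retune for the Space's hardware without touching any call sites.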