# bge-base-en-v1.5 / handler.py
# pashaa's picture
# Upload folder using huggingface_hub
# 057446f verified
"""
Custom handler for BGE dense retrieval on HuggingFace Inference Endpoints.
Returns dense embeddings for queries and passages.
Key difference from ANCE: BGE requires an instruction prefix on queries
for retrieval tasks. Passages are encoded without any prefix.
"""
from typing import Any, Dict, List, Union
import torch
from transformers import AutoModel, AutoTokenizer
# BGE query instruction prefix
QUERY_INSTRUCTION = "Represent this sentence for searching relevant passages: "
class EndpointHandler:
    """Handler for BGE embedding generation on Inference Endpoints.

    Queries get the BGE instruction prefix (``QUERY_INSTRUCTION``);
    passages are encoded verbatim. All embeddings are L2-normalized
    CLS-token vectors, so cosine similarity reduces to a dot product.
    """

    def __init__(self, path: str = ""):
        """Load model + tokenizer from *path* and move the model to GPU if available."""
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModel.from_pretrained(path)
        self.model.eval()  # inference only: disable dropout / training behavior
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        print(f"BGE loaded on {self.device}")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process an inference request.

        Accepts:
            - {"inputs": "text"}                    single text, encoded as a passage
            - {"inputs": ["t1", "t2", ...]}         batch of passages
            - {"inputs": ..., "is_query": true}     same, with the query prefix added
            - {"query": "...", "passages": [...]}   returns cosine similarity scores

        Returns:
            {"embeddings": [...]} for embedding requests,
            {"scores": [...]} for query+passages requests, or
            {"error": "..."} when no usable input was provided.
        """
        inputs = data.get("inputs")
        is_query = bool(data.get("is_query", False))
        query = data.get("query")
        passages = data.get("passages")

        # Mode 1: query + passages -> similarity scores.
        if query is not None and passages is not None:
            # Query gets the BGE instruction prefix; passages do not.
            query_emb = self._encode([f"{QUERY_INSTRUCTION}{query}"])[0]
            # Guard the empty batch: the tokenizer rejects an empty list.
            passage_embs = self._encode(list(passages)) if passages else []
            scores = [self._cosine_similarity(query_emb, p_emb) for p_emb in passage_embs]
            return {"scores": scores}

        # Mode 2: plain embedding request.
        if inputs is None:
            return {"error": "No inputs provided. Use 'inputs' or 'query'+'passages'."}
        if isinstance(inputs, str):
            inputs = [inputs]
        if not inputs:
            # Empty batch: nothing to encode (the tokenizer would raise on []).
            return {"embeddings": []}
        if is_query:
            inputs = [f"{QUERY_INSTRUCTION}{text}" for text in inputs]
        return {"embeddings": self._encode(inputs)}

    def _encode(self, texts: List[str], max_length: int = 512) -> List[List[float]]:
        """Tokenize *texts* and return L2-normalized CLS embeddings as nested lists.

        Args:
            texts: non-empty batch of strings (prefixes already applied by caller).
            max_length: truncation length in tokens (BGE's usual 512 default).
        """
        encoded = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**encoded)
        # BGE convention: the [CLS] token (position 0) is the sentence embedding.
        embeddings = outputs.last_hidden_state[:, 0, :]
        # L2-normalize so downstream dot products equal cosine similarity.
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return embeddings.cpu().tolist()

    def _cosine_similarity(self, emb1: List[float], emb2: List[float]) -> float:
        """Return the cosine similarity of two vectors (0.0 if either is zero)."""
        dot = sum(a * b for a, b in zip(emb1, emb2))
        norm1 = sum(a * a for a in emb1) ** 0.5
        norm2 = sum(b * b for b in emb2) ** 0.5
        if norm1 == 0.0 or norm2 == 0.0:
            # Degenerate vector: similarity is undefined; report no similarity.
            return 0.0
        return dot / (norm1 * norm2)