umag_backend / backend /embeddings.py
amyxcao's picture
Deploy FastAPI backend
9c8703c
raw
history blame
2.67 kB
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor, AutoTokenizer
from transformers import AutoModelForMaskedLM
from backend.interfaces import BaseEmbeddingModel
class DefaultDenseEmbeddingModel(BaseEmbeddingModel):
def __init__(self, config):
super().__init__(config)
# Initialize your dense embedding model here
self.model_name = "BAAI/BGE-VL-base" # or "BAAI/BGE-VL-large"
self.model = AutoModel.from_pretrained(
self.model_name, trust_remote_code=True
).to(self.device)
self.preprocessor = AutoProcessor.from_pretrained(
self.model_name, trust_remote_code=True
)
def encode_text(self, texts: list[str]) -> list[list[float]]:
if not texts:
return []
inputs = self.preprocessor(
text=texts, return_tensors="pt", truncation=True, padding=True
).to(self.device)
return self.model.get_text_features(**inputs).cpu().tolist()
def encode_image(self, images: list[str] | list[Image.Image]) -> list[float]:
if not images:
return []
if isinstance(images[0], str):
images = [Image.open(image_path).convert("RGB") for image_path in images]
inputs = self.preprocessor(images=images, return_tensors="pt").to(self.device)
return self.model.get_image_features(**inputs).cpu().tolist()
class DefaultSparseEmbeddingModel(BaseEmbeddingModel):
def __init__(self, config):
super().__init__(config)
# Initialize your sparse embedding model here
self.model_name = "naver/splade-v3"
self.model = AutoModelForMaskedLM.from_pretrained(self.model_name).to(
self.device
)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
def encode_text(self, texts: list[str]) -> list[dict]:
if not texts:
return []
tokens = self.tokenizer(
texts, return_tensors="pt", truncation=True, padding=True
).to(self.device)
outputs = self.model(**tokens)
sparse_embedding = (
torch.max(
torch.log(1 + torch.relu(outputs.logits))
* tokens.attention_mask.unsqueeze(-1),
dim=1,
)[0]
.detach()
.cpu()
)
# convert to pinecone sparse format
res = []
for i in range(len(sparse_embedding)):
indices = sparse_embedding[i].nonzero().squeeze().tolist()
values = sparse_embedding[i, indices].tolist()
res.append({"indices": indices, "values": values})
return res