| import os
|
| import torch
|
| import json
|
| from typing import Dict, List, Union, Optional, Any
|
| from PIL import Image
|
| from transformers import AutoConfig, AutoTokenizer
|
| from custom_st import Transformer
|
|
|
class InferenceEmbeddings:
    """Batched inference wrapper around the project's ``Transformer`` embedding model.

    The model is loaded once, switched to eval mode and moved to CUDA when
    available (otherwise CPU). ``encode_text`` / ``encode_image`` run it in
    batches and return all embeddings stacked into a single CPU tensor.
    """

    # Prefix prepended to every text for each supported ``prompt_name``.
    _PROMPT_PREFIXES = {"query": "Query: ", "passage": "Passage: "}

    def __init__(self, model_path: str):
        """
        Initialize the embedding model for inference.

        Args:
            model_path: Path to the model directory (forwarded to ``Transformer``
                as ``model_name_or_path``).
        """
        self.model_path = model_path
        self.model = Transformer(
            model_name_or_path=model_path,
            model_args={"default_task": "retrieval", "trust_remote_code": True},
            trust_remote_code=True,
        )
        # This wrapper never trains, so put the model in evaluation mode.
        self.model.eval()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def encode_text(self,
                    texts: List[str],
                    task: str = "retrieval",
                    prompt_name: Optional[str] = None,
                    truncate_dim: Optional[int] = None,
                    return_multivector: bool = False,
                    max_length: Optional[int] = None,
                    batch_size: int = 32) -> torch.Tensor:
        """
        Encode text inputs to embeddings.

        Args:
            texts: List of text inputs to encode.
            task: Task for which to generate embeddings (retrieval, text-matching, code);
                forwarded to the model.
            prompt_name: Optional prompt type. ``"query"`` prefixes every text with
                ``"Query: "``, ``"passage"`` with ``"Passage: "``; any other value
                leaves the texts unchanged.
            truncate_dim: Optional dimension to truncate embeddings to (forwarded
                to the model).
            return_multivector: Not implemented by this wrapper; a ``UserWarning``
                is emitted when set and single-vector embeddings are returned.
            max_length: Not implemented by this wrapper; a ``UserWarning`` is
                emitted when set.
            batch_size: Number of texts per forward pass; must be >= 1.

        Returns:
            CPU tensor of embeddings, one row per input text, in input order.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
            RuntimeError: If ``texts`` is empty, or the model fails to return a
                ``"sentence_embedding"`` with one row per input for some batch.
        """
        self._warn_ignored(
            return_multivector=return_multivector,
            max_length=max_length is not None,
        )
        prefix = self._PROMPT_PREFIXES.get(prompt_name, "")
        if prefix:
            texts = [f"{prefix}{text}" for text in texts]
        return self._encode_batches(texts, task, truncate_dim, batch_size)

    def encode_image(self,
                     images: List[Union[str, "Image.Image"]],
                     task: str = "retrieval",
                     truncate_dim: Optional[int] = None,
                     return_multivector: bool = False,
                     max_pixels: Optional[int] = None,
                     batch_size: int = 8) -> torch.Tensor:
        """
        Encode image inputs to embeddings.

        Args:
            images: List of image inputs (file paths, URLs, or PIL Images); how
                strings are resolved is up to ``self.model.tokenize``.
            task: Task for which to generate embeddings; forwarded to the model.
            truncate_dim: Optional dimension to truncate embeddings to (forwarded
                to the model).
            return_multivector: Not implemented by this wrapper; a ``UserWarning``
                is emitted when set and single-vector embeddings are returned.
            max_pixels: Not implemented by this wrapper; a ``UserWarning`` is
                emitted when set.
            batch_size: Number of images per forward pass; must be >= 1.

        Returns:
            CPU tensor of embeddings, one row per input image, in input order.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
            RuntimeError: If ``images`` is empty, or the model fails to return a
                ``"sentence_embedding"`` with one row per input for some batch.
        """
        self._warn_ignored(
            return_multivector=return_multivector,
            max_pixels=max_pixels is not None,
        )
        return self._encode_batches(images, task, truncate_dim, batch_size)

    @staticmethod
    def _warn_ignored(**options: bool) -> None:
        """Emit a ``UserWarning`` for each truthy option this wrapper does not implement."""
        import warnings

        for name, requested in options.items():
            if requested:
                # stacklevel=3 points the warning at the caller of encode_text/encode_image.
                warnings.warn(
                    f"'{name}' is not implemented by InferenceEmbeddings and will be ignored",
                    UserWarning,
                    stacklevel=3,
                )

    def _encode_batches(self,
                        inputs: List[Union[str, "Image.Image"]],
                        task: str,
                        truncate_dim: Optional[int],
                        batch_size: int) -> torch.Tensor:
        """Run tokenize + forward over ``inputs`` in chunks and stack the embeddings on CPU."""
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        if not inputs:
            raise RuntimeError("Failed to generate embeddings: no inputs were provided")

        chunks = []
        for start in range(0, len(inputs), batch_size):
            batch = inputs[start:start + batch_size]
            features = self.model.tokenize(batch)
            # Move only tensor-valued features to the model's device; other entries
            # are passed through untouched.
            for key, value in features.items():
                if isinstance(value, torch.Tensor):
                    features[key] = value.to(self.device)

            with torch.no_grad():
                outputs = self.model.forward(features, task=task, truncate_dim=truncate_dim)
            batch_embeddings = outputs.get("sentence_embedding", None)

            # A missing or short batch would shift every later row out of alignment
            # with its input, so fail loudly instead of silently dropping it.
            end = start + len(batch)
            if batch_embeddings is None:
                raise RuntimeError(
                    f"Failed to generate embeddings for inputs [{start}:{end}]"
                )
            if batch_embeddings.shape[0] != len(batch):
                raise RuntimeError(
                    f"Failed to generate embeddings for inputs [{start}:{end}]: "
                    f"expected {len(batch)} rows, got {batch_embeddings.shape[0]}"
                )
            chunks.append(batch_embeddings.cpu())

        return torch.cat(chunks, dim=0)
|
|
|
def load_model(model_path: str):
    """Construct the inference embedding wrapper for a model directory.

    Args:
        model_path: Location of the model directory; handed unchanged to
            ``InferenceEmbeddings``.

    Returns:
        A newly constructed ``InferenceEmbeddings`` instance.
    """
    embedder = InferenceEmbeddings(model_path=model_path)
    return embedder
|
|
|