|
|
from typing import List |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from transformers import AutoModel, AutoTokenizer |
|
|
|
|
|
class EmbeddingScorer: |
|
|
""" |
|
|
A class for performing semantic search using embeddings. |
|
|
Uses the gte-multilingual-base model from Alibaba-NLP. |
|
|
""" |
|
|
|
|
|
def __init__(self, model_name='Alibaba-NLP/gte-multilingual-base'): |
|
|
""" |
|
|
Initialize the EmbeddingScorer with the specified model. |
|
|
|
|
|
Args: |
|
|
model_name (str): Name of the model to use. |
|
|
""" |
|
|
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
|
self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True) |
|
|
self.dimension = 768 |
|
|
|
|
|
def score_method(self, query: str, methods: List[dict]) -> List[dict]: |
|
|
""" |
|
|
Calculate similarity between a query and a list of methods. |
|
|
|
|
|
Args: |
|
|
query (str): The query sentence. |
|
|
methods (list): List of method dictionaries to compare against the query. |
|
|
|
|
|
Returns: |
|
|
list: List of similarity scores between the query and each method. |
|
|
""" |
|
|
|
|
|
sentences = [f"{method['method']}: {method.get('description', '')}" for method in methods] |
|
|
texts = [query] + sentences |
|
|
|
|
|
|
|
|
batch_dict = self.tokenizer(texts, max_length=8192, padding=True, truncation=True, return_tensors='pt') |
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
outputs = self.model(**batch_dict) |
|
|
|
|
|
|
|
|
embeddings = outputs.last_hidden_state[:, 0][:self.dimension] |
|
|
|
|
|
|
|
|
embeddings = F.normalize(embeddings, p=2, dim=1) |
|
|
|
|
|
|
|
|
query_embedding = embeddings[0].unsqueeze(0) |
|
|
method_embeddings = embeddings[1:] |
|
|
|
|
|
|
|
|
similarities = (query_embedding @ method_embeddings.T) * 100 |
|
|
similarities = similarities.squeeze().tolist() |
|
|
|
|
|
|
|
|
if not isinstance(similarities, list): |
|
|
similarities = [similarities] |
|
|
|
|
|
|
|
|
result = [] |
|
|
for i, similarity in enumerate(similarities, start=1): |
|
|
result.append({ |
|
|
"method_index": i, |
|
|
"score": float(similarity) |
|
|
}) |
|
|
|
|
|
return result |
|
|
|
|
|
if __name__ == "__main__": |
|
|
es = EmbeddingScorer() |
|
|
print(es.score_method("How to solve the problem of the user", [{"method": "Method 1", "description": "Description 1"}, {"method": "Method 2", "description": "Description 2"}])) |
|
|
|