| from typing import Dict, List, Union |
| import torch |
| from transformers import AutoModel |
| from custom_tokenizer import CustomPhobertTokenizer |
|
|
|
|
| def mean_pooling(model_output, attention_mask): |
| token_embeddings = model_output[ |
| 0 |
| ] |
| input_mask_expanded = ( |
| attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() |
| ) |
| return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( |
| input_mask_expanded.sum(1), min=1e-9 |
| ) |
|
|
|
|
| class PreTrainedPipeline: |
| def __init__(self, path="."): |
| self.model = AutoModel.from_pretrained(path) |
| self.tokenizer = CustomPhobertTokenizer.from_pretrained(path) |
|
|
| def __call__(self, inputs: Dict[str, Union[str, List[str]]]) -> List[float]: |
| """ |
| Args: |
| inputs (Dict[str, Union[str, List[str]]]): |
| a dictionary containing a query sentence and a list of key sentences |
| """ |
|
|
| |
| sentences = [inputs["source_sentence"]] + inputs["sentences"] |
|
|
| |
| encoded_input = self.tokenizer( |
| sentences, padding=True, truncation=True, return_tensors="pt" |
| ) |
|
|
| |
| with torch.no_grad(): |
| model_output = self.model(**encoded_input) |
|
|
| |
| sentence_embeddings = mean_pooling( |
| model_output, encoded_input["attention_mask"] |
| ) |
|
|
| |
| query_embedding = sentence_embeddings[0] |
| key_embeddings = sentence_embeddings[1:] |
|
|
| |
| cosine_similarities = torch.nn.functional.cosine_similarity( |
| query_embedding.unsqueeze(0), key_embeddings |
| ) |
|
|
| |
| scores = cosine_similarities.tolist() |
|
|
| return scores |
|
|
|
|
| if __name__ == "__main__": |
| inputs = { |
| "source_sentence": "Anh ấy đang là sinh viên năm cuối", |
| "sentences": [ |
| "Anh ấy học tại Đại học Bách khoa Hà Nội, chuyên ngành Khoa học máy tính", |
| "Anh ấy đang làm việc tại nhà máy sản xuất linh kiện điện tử", |
| "Anh ấy chuẩn bị đi du học nước ngoài", |
| "Anh ấy sắp mở cửa hàng bán mỹ phẩm", |
| "Nhà anh ấy có rất nhiều cây cảnh", |
| ], |
| } |
|
|
| pipeline = PreTrainedPipeline() |
| res = pipeline(inputs) |
|
|