Spaces:
Running
Running
| import dataclasses | |
| import math | |
| from typing import List, Optional | |
| import torch | |
| from pymilvus import MilvusClient, connections | |
| from transformers import AutoModel, AutoTokenizer | |
| from credentials import get_token | |
@dataclasses.dataclass
class MilvusParams:
    """Connection settings for the Milvus collection searched by ``ProteinSearchEngine``."""

    uri: str  # Milvus server / cloud endpoint URI
    token: str  # authentication token passed to the Milvus client
    db_name: str  # database that holds the collection
    collection_name: str  # collection queried by search/query calls
class ProteinSearchEngine:
    """Embed query sequences with a Hugging Face model and search a Milvus collection.

    Query vectors come from the model's ``forward1`` method; hits are filtered
    on the ``is_peptide`` and (optionally) ``organism`` scalar fields.
    """

    # Embedding dimensionality; used to derive the score normalisation bound.
    n_dims = 128
    # Not read by this class; kept because callers may inspect it.
    dist_metric = "euclidean"
    # Tokenizer max lengths. Only index 0 is used here (search_by_sequence).
    # NOTE(review): presumably (peptide, protein) lengths — confirm.
    max_lengths = (30, 300)

    def __init__(self, milvus_params: "MilvusParams", model_repo: str):
        """Connect to Milvus and load the tokenizer and model.

        Args:
            milvus_params: Connection settings and the collection to search.
            model_repo: Hugging Face repository id for tokenizer and model.
        """
        self.model_repo = model_repo
        self.milvus_params = milvus_params
        connections.connect(
            "default",
            uri=milvus_params.uri,
            token=milvus_params.token,
            db_name=milvus_params.db_name,
        )
        # MilvusClient opens its own connection, so db_name must be passed here
        # too; otherwise search/query run against the server's default database.
        self.client = MilvusClient(
            uri=milvus_params.uri,
            token=milvus_params.token,
            db_name=milvus_params.db_name,
        )
        hf_token = get_token()  # fetched once, reused for tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_repo, use_auth_token=hf_token
        )
        self.model = AutoModel.from_pretrained(
            self.model_repo, use_auth_token=hf_token, trust_remote_code=True
        )
        self.model.eval()  # inference mode

    def search_by_sequence(self, sequence: str, n: int, organism: Optional[str] = None):
        """Embed ``sequence`` and return up to ``n`` nearest non-peptide entries.

        Args:
            sequence: Query sequence; tokenized with ``max_lengths[0]``.
            n: Maximum number of hits.
            organism: If given, only entries whose ``organism`` equals it.

        Returns:
            List of entity dicts, each augmented with ``dist`` and ``score``.
        """
        vec = self._embed_sequence(self.max_lengths[0], sequence)
        response = self.search(vec, n_results=n, is_peptide=False, organism=organism)
        return self._format_search_results(response)

    def _embed_sequence(self, max_length, sequence):
        """Tokenize ``sequence`` and embed it with the model's ``forward1``.

        The input is truncated / padded to exactly ``max_length`` tokens
        (special tokens included) and run without gradient tracking.

        Returns:
            The model output as a CPU NumPy array with size-1 dimensions
            squeezed out. NOTE(review): presumably a 1-D vector of length
            ``n_dims`` — confirm against the model.
        """
        encoded = self.tokenizer.encode_plus(
            sequence,
            add_special_tokens=True,
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="pt",
        )
        with torch.no_grad():
            vec = (
                self.model.forward1(encoded.to(self.model.device))
                .squeeze()
                .cpu()
                .numpy()
            )
        return vec

    def _format_search_results(self, response):
        """Attach ``dist`` and ``score`` to each hit's entity dict (mutated in place).

        ``res["distance"]`` is square-rooted, i.e. treated as a squared L2
        distance. ``score = (max_dist - dist) / max_dist`` with
        ``max_dist = sqrt(2 * n_dims)``, so a zero distance scores 1.0 and the
        score falls below 0 if ``dist`` exceeds ``max_dist``.

        Returns:
            The entity dicts, in the order of ``response``.
        """
        max_dist = math.sqrt(2 * self.n_dims)
        search_results = []
        for res in response:
            entry = res["entity"]
            # Clamp: float round-off can produce a tiny negative squared
            # distance, on which math.sqrt would raise ValueError.
            dist = math.sqrt(max(res["distance"], 0.0))
            entry["dist"] = dist
            entry["score"] = (max_dist - dist) / max_dist
            search_results.append(entry)
        return search_results

    @staticmethod
    def _build_filter(is_peptide, organism: Optional[str] = None) -> str:
        """Build the Milvus boolean filter expression used by ``search``.

        Args:
            is_peptide: Coerced to bool and rendered as ``True``/``False``.
            organism: If not None, adds an exact-match clause on ``organism``.
        """
        filter_str = f"is_peptide == {bool(is_peptide)}"
        if organism is not None:
            # Escape backslashes, then single quotes, so a value such as
            # "Baker's yeast" cannot close the string literal early and
            # break (or inject into) the filter expression.
            escaped = organism.replace("\\", "\\\\").replace("'", "\\'")
            filter_str += f" and organism == '{escaped}'"
        return filter_str

    def search(
        self,
        vec: List[float],
        n_results: int,
        is_peptide: bool,
        organism: Optional[str] = None,
    ):
        """Run a filtered vector search for a single query vector.

        Args:
            vec: Query embedding.
            n_results: Maximum number of hits (Milvus ``limit``).
            is_peptide: Value the ``is_peptide`` field must equal.
            organism: If given, value the ``organism`` field must equal.

        Returns:
            The hit list for ``vec`` (first element of the client's result).
        """
        results = self.client.search(
            collection_name=self.milvus_params.collection_name,
            data=[vec],
            limit=n_results,
            output_fields=[
                "genes",
                "uniprot_id",
                "pdb_name",
                "chain_id",
                "is_peptide",
                "organism",
            ],
            filter=self._build_filter(is_peptide, organism),
        )
        # Exactly one query vector was sent, so take its hit list.
        return results[0]

    def get_organisms(self):
        """Return raw query rows with only the ``organism`` field requested.

        Rows are selected with ``entry_id > 0`` and returned as-is (not
        de-duplicated). NOTE(review): no ``limit`` is passed; Milvus may cap
        unbounded query results — confirm all organisms are covered.
        """
        res = self.client.query(
            collection_name=self.milvus_params.collection_name,
            output_fields=["organism"],
            filter="entry_id > 0",
        )
        return res