| from pymilvus import MilvusClient, WeightedRanker, AnnSearchRequest |
| from langchain_ollama import OllamaEmbeddings |
|
|
class MilvusRetriever:
    """Query-time retriever over a fixed set of Milvus collections.

    ``search`` embeds the query text with an Ollama embedding model and
    dispatches to a per-collection search strategy:

    * ``article_search`` runs a weighted hybrid search over three vector fields.
    * ``qa_search`` runs a filtered title-vector search, deduplicated by title.
    * ``video_search`` runs a filtered title-vector search.
    """

    def __init__(self, uri, embed_model_name="bge-m3"):
        """Connect to Milvus and create the query embedding model.

        Args:
            uri: Milvus server URI, e.g. ``"http://localhost:19530"``.
            embed_model_name: Ollama model used to embed queries.
                Defaults to ``"bge-m3"``.
        """
        self.uri = uri
        self.embed_model = OllamaEmbeddings(model=embed_model_name)
        self.client = MilvusClient(self.uri)

    @staticmethod
    def _search_params():
        """Return a fresh COSINE / ``nprobe=10`` search-parameter dict.

        A new dict is built per request so no request can observe changes
        made to another request's parameters.
        """
        return {"metric_type": "COSINE", "params": {"nprobe": 10}}

    @staticmethod
    def _dedupe_by_title(records):
        """Return ``records`` keeping only the first record seen for each entity title.

        Order is preserved. Assumes ``record["entity"]["title"]`` is hashable
        (a string field in practice; confirm against the collection schema).
        """
        seen_titles = set()
        unique = []
        for record in records:
            title = record["entity"]["title"]
            if title not in seen_titles:
                seen_titles.add(title)
                unique.append(record)
        return unique

    def search(self, query, collection_name, top_k=10):
        """Embed ``query`` and search ``collection_name`` for the closest records.

        Args:
            query: Query string.
            collection_name: One of ``"t_sur_sex_ed_article_spider"``,
                ``"t_sur_sex_ed_question_answer_spider"`` or
                ``"t_sur_sex_ed_youtube_spider"``.
            top_k: Maximum number of hits requested from Milvus. Defaults to 10.

        Returns:
            A sequence of hits, each exposing ``"id"``, ``"distance"`` and
            ``"entity"`` keys, for the single query vector.

        Raises:
            ValueError: If ``collection_name`` is not a supported collection.
        """
        handlers = {
            "t_sur_sex_ed_article_spider": self.article_search,
            "t_sur_sex_ed_question_answer_spider": self.qa_search,
            "t_sur_sex_ed_youtube_spider": self.video_search,
        }
        handler = handlers.get(collection_name)
        if handler is None:
            # Fail fast, before paying for the embedding call, instead of
            # silently returning None and letting the caller crash later.
            raise ValueError(
                f"Unsupported collection {collection_name!r}; "
                f"expected one of {sorted(handlers)}"
            )
        query_embedding = self.embed_model.embed_query(query)
        return handler(query_embedding, collection_name, top_k=top_k)

    def article_search(self, embedding, collection_name, top_k):
        """Weighted hybrid search over the article collection.

        One ANN request is issued per vector field (``chunk_vector``,
        ``title_vector``, ``tags``). The results are fused with a
        ``WeightedRanker`` using weights 0.6 / 0.3 / 0.1 respectively.

        Args:
            embedding: Query embedding vector.
            collection_name: Article collection name.
            top_k: Per-request and final result limit.

        Returns:
            Hits for the single query vector, with ``title``, ``link``,
            ``chunk`` and ``category`` output fields.
        """
        # Keep each weight next to the field it applies to; order matters for
        # WeightedRanker, which pairs weights with requests positionally.
        field_weights = (
            ("chunk_vector", 0.6),
            ("title_vector", 0.3),
            ("tags", 0.1),
        )
        requests = [
            AnnSearchRequest(
                data=[embedding],
                anns_field=field,
                param=self._search_params(),
                limit=top_k,
            )
            for field, _ in field_weights
        ]
        ranker = WeightedRanker(*(weight for _, weight in field_weights))
        return self.client.hybrid_search(
            collection_name=collection_name,
            reqs=requests,
            ranker=ranker,
            limit=top_k,
            output_fields=["title", "link", "chunk", "category"],
        )[0]

    def qa_search(self, embedding, collection_name, top_k):
        """Title-vector search over the Q&A collection, deduplicated by title.

        Only rows matching ``content_type == 'A'`` are searched (presumably
        answer rows; confirm against the data). Deduplication happens after
        Milvus applies ``top_k``, so fewer than ``top_k`` hits may be returned.

        Args:
            embedding: Query embedding vector.
            collection_name: Q&A collection name.
            top_k: Result limit passed to Milvus.

        Returns:
            A list of hits, keeping the first (best-ranked) hit per title.
        """
        hits = self.client.search(
            collection_name=collection_name,
            data=[embedding],
            anns_field="title_vector",
            search_params=self._search_params(),
            limit=top_k,
            filter="content_type == 'A'",
            output_fields=["title", "content", "url", "author", "avatar_url", "likes", "dislikes"],
        )[0]
        return self._dedupe_by_title(hits)

    def video_search(self, embedding, collection_name, top_k):
        """Title-vector search over the video collection, excluding deleted rows.

        Args:
            embedding: Query embedding vector.
            collection_name: Video collection name.
            top_k: Result limit passed to Milvus.

        Returns:
            Hits for the single query vector, restricted to ``delete_status == 0``.
        """
        return self.client.search(
            collection_name=collection_name,
            data=[embedding],
            anns_field="title_vector",
            search_params=self._search_params(),
            filter="delete_status == 0",
            limit=top_k,
            output_fields=["title", "link", "author", "picture", "duration"],
        )[0]

    def porn_search(self, embedding, collection_name, top_k):
        """Placeholder search strategy; not implemented yet.

        Raises:
            NotImplementedError: Always. Raising makes the missing
                implementation explicit instead of silently returning None.
        """
        raise NotImplementedError("porn_search is not implemented")
|
|
if __name__ == "__main__":
    # Demo: query the article collection on a local Milvus instance and print
    # the entities of sufficiently similar hits as JSON.
    import json

    retriever = MilvusRetriever(uri="http://localhost:19530")
    collection_name = "t_sur_sex_ed_article_spider"
    query = "How to build trust?"
    hits = retriever.search(query, collection_name, top_k=10)
    # Keep only hits whose score is above 0.3.
    entities = [hit["entity"] for hit in hits if hit["distance"] > 0.3]
    # ensure_ascii=False prints any non-ASCII text as-is rather than as \u escapes.
    print(json.dumps(entities, ensure_ascii=False, indent=2))
|
|