Spaces:
Sleeping
Sleeping
| from pymilvus import MilvusClient | |
| from sentence_transformers import SentenceTransformer | |
| from typing import List, Dict, Any, Optional, Union | |
| import logging | |
| from app.config import MILVUS_DB_URL, MILVUS_DB_TOKEN, EMBEDDING_MODEL, DATASET_ID | |
# Configure process-wide logging (timestamp, level, message) at import time
# and create the logger used throughout this module.
logging.basicConfig(
    format='%(asctime)s %(levelname)s %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
class Database:
    """Data-access layer for the sticker collection stored in Milvus.

    Wraps a ``MilvusClient`` connection and a ``SentenceTransformer`` model:
    sticker descriptions are embedded with the model and stored/searched as
    vectors in the ``stickers`` collection.
    """

    # Vector size used when the embedding model does not report its own.
    DEFAULT_DIMENSION = 768

    def __init__(self):
        """Connect to Milvus and load the embedding model."""
        self.client = MilvusClient(uri=MILVUS_DB_URL, token=MILVUS_DB_TOKEN)
        # trust_remote_code=True lets the model repository execute its own
        # Python code, so EMBEDDING_MODEL must point at a trusted model.
        self.model = SentenceTransformer(EMBEDDING_MODEL, trust_remote_code=True)
        logger.info("初始化模型完成 %s", self.model)
        self.collection_name = "stickers"

    def _embedding_dimension(self) -> int:
        """Return the model's embedding size, or DEFAULT_DIMENSION if unknown."""
        get_dimension = getattr(self.model, "get_sentence_embedding_dimension", None)
        dimension = get_dimension() if callable(get_dimension) else None
        return int(dimension) if dimension else self.DEFAULT_DIMENSION

    @staticmethod
    def _normalize_tags(tags: Union[str, List[str], None]) -> List[str]:
        """Turn a comma-separated string (or list) of tags into a clean list."""
        if tags is None:
            return []
        if isinstance(tags, str):
            # Strip surrounding whitespace and drop empty pieces so that
            # "a, b,," becomes ["a", "b"] instead of ["a", " b", "", ""].
            return [piece.strip() for piece in tags.split(",") if piece.strip()]
        return list(tags)

    def _build_record(
        self,
        title: str,
        description: str,
        tags: Union[str, List[str], None],
        file_path: str,
        image_hash: Optional[str],
    ) -> Dict[str, Any]:
        """Build one Milvus row for a sticker, embedding its description."""
        return {
            "vector": self.encode_text(description),
            "title": title,
            "description": description,
            "tags": self._normalize_tags(tags),
            # The file path is stored under the "file_name" key.
            "file_name": file_path,
            "image_hash": image_hash,
        }

    def init_collection(self) -> bool:
        """Create the sticker collection if it does not exist yet.

        Returns:
            True if the collection exists after the call, False if any
            Milvus operation raised (the error is logged).
        """
        try:
            logger.info("初始化 Milvus 数据库 %s", self.client.list_collections())
            # Check for *this* collection; the mere presence of unrelated
            # collections must not prevent "stickers" from being created.
            if not self.client.has_collection(self.collection_name):
                self.client.create_collection(
                    collection_name=self.collection_name,
                    dimension=self._embedding_dimension(),
                    primary_field_name="id",
                    # Keep the auto-built index on the same metric that
                    # search_stickers() requests.
                    metric_type="COSINE",
                    auto_id=True,
                )
                # NOTE(review): this call passes an empty index_params plus
                # loose index_type/metric_type/params keywords. create_index
                # normally expects an IndexParams object built with
                # prepare_index_params(); confirm whether this actually
                # builds the intended IVF_SQ8 index on the deployed pymilvus.
                self.client.create_index(
                    collection_name=self.collection_name,
                    index_type="IVF_SQ8",
                    metric_type="COSINE",
                    params={"nlist": 128},
                    index_params={},
                )
            logger.info(f"Collection initialized: {self.collection_name}")
            logger.info("初始化 Milvus 数据库成功 %s", self.client.list_collections())
            return True
        except Exception as e:
            logger.exception("Collection initialization failed: %s", e)
            return False

    def encode_text(self, text: str) -> List[float]:
        """Encode text into an embedding vector (as a plain list of floats)."""
        return self.model.encode(text).tolist()

    def store_sticker(
        self,
        title: str,
        description: str,
        tags: Union[str, List[str]],
        file_path: str,
        image_hash: Optional[str] = None,
    ) -> bool:
        """Store a single sticker in Milvus.

        Args:
            title: Sticker title.
            description: Text that is embedded and used for similarity search.
            tags: List of tags, or a comma-separated string of tags.
            file_path: Location of the sticker file (stored as ``file_name``).
            image_hash: Optional content hash used for duplicate detection.

        Returns:
            True on success, False if embedding or insertion raised.
        """
        try:
            record = self._build_record(title, description, tags, file_path, image_hash)
            logger.info(
                f"Storing to Milvus - title: {title}, description: {description}, "
                f"file_path: {file_path}, tags: {record['tags']}, image_hash: {image_hash}"
            )
            self.client.insert(collection_name=self.collection_name, data=[record])
            logger.info("Storing to Milvus Success ✅")
            return True
        except Exception as e:
            logger.exception("Failed to store sticker: %s", e)
            return False

    def search_stickers(self, description: str, limit: int = 2) -> List[Dict[str, Any]]:
        """Search for the stickers whose descriptions best match the query.

        Args:
            description: Free-text query; an empty query returns ``[]``.
            limit: Maximum number of hits to return.

        Returns:
            The hits for the query, or ``[]`` when there are none or the
            search failed (the error is logged).
        """
        if not description:
            return []
        try:
            text_vector = self.encode_text(description)
            logger.info(f"Searching Milvus - query: {description}, limit: {limit}")
            results = self.client.search(
                collection_name=self.collection_name,
                data=[text_vector],
                limit=limit,
                search_params={"metric_type": "COSINE"},
                output_fields=["title", "description", "tags", "file_name"],
            )
            logger.info(f"Search Result: {results}")
            # One query vector was sent, so only the first result set matters.
            if not results:
                return []
            return results[0]
        except Exception as e:
            logger.exception("Search failed: %s", e)
            return []

    def get_all_stickers(self, limit: int = 1000) -> List[Dict[str, Any]]:
        """Return up to ``limit`` stored stickers, or ``[]`` on failure."""
        try:
            results = self.client.query(
                collection_name=self.collection_name,
                filter="",
                limit=limit,
                output_fields=["title", "description", "tags", "file_name", "image_hash"],
            )
            logger.info(f"Query All Stickers - limit: {limit}, results count: {len(results)}")
            return results
        except Exception as e:
            logger.exception("Failed to get all stickers: %s", e)
            return []

    def check_image_exists(self, image_hash: str) -> bool:
        """Check whether a sticker with the given image hash is already stored.

        Returns:
            True if at least one row has this ``image_hash``. False if none
            does, if the hash is empty, or if the query failed (logged).
        """
        if not image_hash:
            return False
        try:
            # Escape backslashes and quotes so the value cannot break out of
            # the string literal in the filter expression.
            escaped = str(image_hash).replace("\\", "\\\\").replace("'", "\\'")
            results = self.client.query(
                collection_name=self.collection_name,
                filter=f"image_hash == '{escaped}'",
                limit=1,
                output_fields=["file_name", "image_hash"],
            )
            return len(results) > 0
        except Exception as e:
            logger.exception("Failed to check file exists: %s", e)
            return False

    def delete_sticker(self, sticker_id: int) -> str:
        """Delete a sticker by primary key.

        Returns:
            A human-readable status message for both success and failure.
        """
        try:
            logger.info(f"Deleting sticker - id: {sticker_id}")
            res = self.client.delete(
                collection_name=self.collection_name,
                ids=[sticker_id],
            )
            logger.info("Deleted sticker - id: %s, result: %s", sticker_id, res)
            return f"Sticker with ID {sticker_id} deleted successfully"
        except Exception as e:
            logger.exception("Failed to delete sticker: %s", e)
            return f"Failed to delete sticker: {str(e)}"

    def batch_store_stickers(self, stickers: List[Dict[str, Any]], batch_size: int = 100) -> bool:
        """Store many stickers in Milvus, inserting them in batches.

        Args:
            stickers: Sticker dicts with the keys ``title``, ``description``,
                ``tags`` (list or comma-separated string), ``file_path`` and
                optionally ``image_hash``. Missing keys default to empty
                values.
            batch_size: Number of stickers per insert call; must be >= 1.

        Returns:
            True if every batch was inserted (or there was nothing to store).
            False if ``batch_size`` is invalid or any step raised; batches
            inserted before the failure are not rolled back.
        """
        if not stickers:
            logger.warning("No stickers to store")
            return True
        # A non-positive step would either raise in range() or silently skip
        # every sticker while still reporting success.
        if batch_size < 1:
            logger.error(f"Failed to batch store stickers: invalid batch_size {batch_size}")
            return False
        try:
            total_stickers = len(stickers)
            logger.info(f"Starting batch store of {total_stickers} stickers")
            for start in range(0, total_stickers, batch_size):
                batch_data = [
                    self._build_record(
                        sticker.get("title", ""),
                        sticker.get("description", ""),
                        sticker.get("tags", []),
                        sticker.get("file_path", ""),
                        sticker.get("image_hash"),
                    )
                    for sticker in stickers[start:start + batch_size]
                ]
                self.client.insert(
                    collection_name=self.collection_name,
                    data=batch_data,
                )
                logger.info(
                    f"Batch {start // batch_size + 1} stored successfully - {len(batch_data)} stickers"
                )
            logger.info("All stickers stored successfully ✅")
            return True
        except Exception as e:
            logger.exception("Failed to batch store stickers: %s", e)
            return False
# Create the shared module-level Database instance at import time; its
# constructor builds the Milvus client and loads the embedding model.
# NOTE(review): init_collection() is not called here — presumably a caller
# runs it before first use; confirm.
db = Database()