Spaces:
Running
Running
| import os | |
| import zvec | |
| from typing import Dict, Any, Optional, List, Tuple | |
class DataManager:
    """Read-only access to a zvec collection of performer embeddings.

    The collection is queried on the vector fields ``"facenet"`` and ``"arc"``.
    Every query method returns two parallel lists (document ids and the raw
    per-hit scores reported by zvec). It also records each hit's stored
    fields in ``_metadata_cache``, so that :meth:`get_performer_info` can
    later build a result record for that id.
    """

    def __init__(self, collection_path: str = "data/performers.zvec"):
        self.collection_path = collection_path
        # Stored fields of every document returned by a query, keyed by document id.
        self._metadata_cache: Dict[str, Dict] = {}
        # Remains None when opening the collection fails; queries then return no hits.
        self.collection = None
        self._load_collection()

    def _load_collection(self):
        """Open the collection read-only with mmap enabled.

        Failure is deliberately non-fatal: the error and its traceback are
        printed and ``self.collection`` is left as ``None``.
        """
        try:
            self.collection = zvec.open(
                path=self.collection_path,
                option=zvec.CollectionOption(read_only=True, enable_mmap=True),
            )
        except Exception as e:
            import traceback
            print(f"FATAL: failed to load collection from {self.collection_path}: {e}")
            traceback.print_exc()
        else:
            print(f"Collection loaded OK from {self.collection_path}")

    def get_performer_info(
        self, stash_id: str, confidence: float, distance: float = 0.0
    ) -> Optional[Dict[str, Any]]:
        """Build a result record for a performer returned by an earlier query.

        Args:
            stash_id: Document id previously returned by one of the query methods.
            confidence: Match confidence. It is multiplied by 100 and rounded to
                the nearest int (a fraction in [0, 1] becomes a percentage;
                the range is assumed, not checked).
            distance: Score to report, rounded to 4 decimal places.

        Returns:
            A dict with id, name, confidence, distance, image, country, hits and
            performer_url. Returns None if no non-empty metadata is cached for
            ``stash_id``.
        """
        meta = self._metadata_cache.get(stash_id)
        if not meta:
            return None
        return {
            'id': stash_id,
            'name': meta.get("name", ""),
            # round() instead of plain int(): int(0.29 * 100) == 28 because of
            # binary floating-point error, which under-reports the percentage.
            'confidence': int(round(confidence * 100)),
            'distance': round(distance, 4),
            'image': meta.get("image", ""),
            # Empty strings are normalised to None.
            'country': meta.get("country") or None,
            'hits': 1,
            'performer_url': f"https://stashdb.org/performers/{stash_id}"
        }

    @staticmethod
    def _as_vector(embedding) -> List[float]:
        """Return ``embedding`` as a plain list for zvec.

        Array-likes exposing ``tolist()`` (e.g. NumPy arrays) use it. Any
        other iterable (list, tuple, ...) is copied with ``list()``.
        """
        if hasattr(embedding, "tolist"):
            return embedding.tolist()
        return list(embedding)

    def _collect_results(self, results) -> Tuple[List[str], List[float]]:
        """Split query hits into (ids, scores) and cache each hit's fields.

        A hit may be an object exposing ``id`` / ``score`` / ``fields``
        attributes, or a mapping with the same keys. For mappings, a missing
        ``fields`` key caches ``{}``.
        """
        ids: List[str] = []
        scores: List[float] = []
        for doc in results:
            doc_id = doc.id if hasattr(doc, 'id') else doc['id']
            doc_score = doc.score if hasattr(doc, 'score') else doc['score']
            doc_fields = doc.fields if hasattr(doc, 'fields') else doc.get('fields', {})
            ids.append(doc_id)
            scores.append(doc_score)
            self._metadata_cache[doc_id] = doc_fields
        return ids, scores

    def _query(self, field_name: str, embedding, limit: int) -> Tuple[List[str], List[float]]:
        """Run a top-``limit`` vector search on a single vector field.

        Returns:
            Parallel lists of document ids and raw zvec scores. Whether a score
            is a distance or a similarity depends on the field's index metric,
            which is configured outside this file. Both lists are empty when
            the collection is not loaded.
        """
        if self.collection is None:
            return [], []
        results = self.collection.query(
            vectors=zvec.VectorQuery(field_name=field_name, vector=self._as_vector(embedding)),
            topk=limit,
        )
        return self._collect_results(results)

    def query_facenet_index(self, embedding, limit: int) -> Tuple[List[str], List[float]]:
        """Search the ``"facenet"`` vector field; see :meth:`_query`."""
        return self._query("facenet", embedding, limit)

    def query_arc_index(self, embedding, limit: int) -> Tuple[List[str], List[float]]:
        """Search the ``"arc"`` vector field; see :meth:`_query`."""
        return self._query("arc", embedding, limit)

    def query_multi(self, facenet_emb, arc_emb, limit: int) -> Tuple[List[str], List[float]]:
        """Search both vector fields at once and fuse the hits.

        zvec's WeightedReRanker combines the two hit lists with equal weights
        (1.0 each) and the COSINE metric, keeping the top ``limit`` results.

        Returns:
            Parallel lists of document ids and fused scores. Both lists are
            empty when the collection is not loaded.
        """
        if self.collection is None:
            return [], []
        results = self.collection.query(
            vectors=[
                zvec.VectorQuery(field_name="facenet", vector=self._as_vector(facenet_emb)),
                zvec.VectorQuery(field_name="arc", vector=self._as_vector(arc_emb)),
            ],
            topk=limit,
            reranker=zvec.WeightedReRanker(
                topn=limit,
                metric=zvec.MetricType.COSINE,
                weights={"facenet": 1.0, "arc": 1.0},
            ),
        )
        return self._collect_results(results)