stashface_onnx / models /data_manager.py
cc1234
Add *.proxima to root LFS tracking, improve collection load logging
bdd2b03
import os
import zvec
from typing import Dict, Any, Optional, List, Tuple
class DataManager:
    """Read-only access to the performer vector collection.

    Opens the zvec collection found at ``collection_path`` and exposes
    top-k queries over its ``facenet`` and ``arc`` vector fields. The fields
    of every document returned by a query are remembered in
    ``_metadata_cache`` so that ``get_performer_info`` can build a result
    record without another collection lookup.
    """

    def __init__(self, collection_path: str = "data/performers.zvec"):
        """Store the path and try to open the collection right away.

        Args:
            collection_path: Path handed to ``zvec.open``.
        """
        self.collection_path = collection_path
        # Document id -> fields of that document; filled in by the queries.
        self._metadata_cache: Dict[str, Dict] = {}
        # Remains None when loading fails; the query methods then return
        # empty results instead of raising.
        self.collection = None
        self._load_collection()

    def _load_collection(self) -> None:
        """Open the collection read-only with mmap enabled.

        This is deliberately best-effort: any exception is reported with
        ``print`` plus ``traceback.print_exc()`` and is not re-raised, and
        ``self.collection`` is not assigned in that case.
        """
        try:
            self.collection = zvec.open(
                path=self.collection_path,
                option=zvec.CollectionOption(read_only=True, enable_mmap=True),
            )
            print(f"Collection loaded OK from {self.collection_path}")
        except Exception as e:
            import traceback

            print(f"FATAL: failed to load collection from {self.collection_path}: {e}")
            traceback.print_exc()

    def get_performer_info(self, stash_id: str, confidence: float, distance: float = 0.0) -> Optional[Dict[str, Any]]:
        """Build the result record for a performer seen in an earlier query.

        Args:
            stash_id: Id used as the key into the metadata cache.
            confidence: Multiplied by 100 and truncated to an ``int``.
            distance: Rounded to four decimal places in the result.

        Returns:
            A dict with the keys ``id``, ``name``, ``confidence``,
            ``distance``, ``image``, ``country``, ``hits`` and
            ``performer_url``; or ``None`` when the cache holds no entry
            (or only an empty one) for ``stash_id``.
        """
        meta = self._metadata_cache.get(stash_id)
        if not meta:
            return None

        return {
            'id': stash_id,
            'name': meta.get("name", ""),
            'confidence': int(confidence * 100),
            'distance': round(distance, 4),
            'image': meta.get("image", ""),
            # A missing or empty country is normalised to None.
            'country': meta.get("country") or None,
            'hits': 1,
            'performer_url': f"https://stashdb.org/performers/{stash_id}"
        }

    def _collect_hits(self, results) -> Tuple[List[str], List[float]]:
        """Split query results into parallel id/score lists.

        Shared by ``_query`` and ``query_multi`` so that both unpack results
        and populate the metadata cache in exactly the same way.

        Args:
            results: Iterable of result documents. Each one is read by
                attribute (``id``, ``score``, ``fields``) when it has that
                attribute, otherwise by key.

        Returns:
            ``(ids, scores)`` in the iteration order of ``results``.
        """
        ids: List[str] = []
        scores: List[float] = []
        for doc in results:
            doc_id = doc.id if hasattr(doc, 'id') else doc['id']
            doc_score = doc.score if hasattr(doc, 'score') else doc['score']
            doc_fields = doc.fields if hasattr(doc, 'fields') else doc.get('fields', {})
            ids.append(doc_id)
            scores.append(doc_score)
            # Remember the fields so get_performer_info can use them later.
            self._metadata_cache[doc_id] = doc_fields
        return ids, scores

    def _query(self, field_name: str, embedding, limit: int) -> Tuple[List[str], List[float]]:
        """Run a top-k query against a single vector field.

        Args:
            field_name: Name of the vector field to search.
            embedding: Query vector; must provide ``tolist()``.
            limit: Passed to the collection as ``topk``.

        Returns:
            ``(ids, scores)`` of the returned documents, or ``([], [])``
            when no collection is loaded.
        """
        if self.collection is None:
            return [], []

        results = self.collection.query(
            vectors=zvec.VectorQuery(field_name=field_name, vector=embedding.tolist()),
            topk=limit,
        )
        return self._collect_hits(results)

    def query_facenet_index(self, embedding, limit: int) -> Tuple[List[str], List[float]]:
        """Query the ``facenet`` vector field; see ``_query``."""
        return self._query("facenet", embedding, limit)

    def query_arc_index(self, embedding, limit: int) -> Tuple[List[str], List[float]]:
        """Query the ``arc`` vector field; see ``_query``."""
        return self._query("arc", embedding, limit)

    def query_multi(self, facenet_emb, arc_emb, limit: int) -> Tuple[List[str], List[float]]:
        """Query both vector fields at once and rerank the combined hits.

        The query is issued with a ``zvec.WeightedReRanker`` configured with
        the cosine metric and equal weights for ``facenet`` and ``arc``.

        Args:
            facenet_emb: Query vector for ``facenet``; must provide
                ``tolist()``.
            arc_emb: Query vector for ``arc``; must provide ``tolist()``.
            limit: Used both as ``topk`` and as the reranker's ``topn``.

        Returns:
            ``(ids, scores)`` of the returned documents, or ``([], [])``
            when no collection is loaded.
        """
        if self.collection is None:
            return [], []

        results = self.collection.query(
            vectors=[
                zvec.VectorQuery(field_name="facenet", vector=facenet_emb.tolist()),
                zvec.VectorQuery(field_name="arc", vector=arc_emb.tolist()),
            ],
            topk=limit,
            reranker=zvec.WeightedReRanker(
                topn=limit,
                metric=zvec.MetricType.COSINE,
                weights={"facenet": 1.0, "arc": 1.0},
            ),
        )
        return self._collect_hits(results)