Spaces:
Runtime error
Runtime error
| from overrides import EnforceOverrides, override | |
| from typing import List, Optional, Sequence | |
| from chromadb.config import System | |
| from chromadb.proto.convert import ( | |
| from_proto_vector_embedding_record, | |
| from_proto_vector_query_result, | |
| to_proto_vector, | |
| ) | |
| from chromadb.segment import VectorReader | |
| from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams | |
| from chromadb.telemetry.opentelemetry import ( | |
| OpenTelemetryClient, | |
| OpenTelemetryGranularity, | |
| trace_method, | |
| ) | |
| from chromadb.types import ( | |
| Metadata, | |
| ScalarEncoding, | |
| Segment, | |
| VectorEmbeddingRecord, | |
| VectorQuery, | |
| VectorQueryResult, | |
| ) | |
| from chromadb.proto.chroma_pb2_grpc import VectorReaderStub | |
| from chromadb.proto.chroma_pb2 import ( | |
| GetVectorsRequest, | |
| GetVectorsResponse, | |
| QueryVectorsRequest, | |
| QueryVectorsResponse, | |
| ) | |
| import grpc | |
class GrpcVectorSegment(VectorReader, EnforceOverrides):
    """Vector reader that forwards its read calls to a gRPC ``VectorReader`` service.

    The service address is taken from the ``grpc_url`` entry of the segment's
    metadata. ``get_vectors`` and ``query_vectors`` are served remotely;
    ``count`` and ``delete`` are not implemented.
    """

    _vector_reader_stub: VectorReaderStub
    _segment: Segment
    _opentelemetry_client: OpenTelemetryClient

    def __init__(self, system: System, segment: Segment):
        """Create the gRPC stub for the service named in the segment metadata.

        Args:
            system: System from which the ``OpenTelemetryClient`` is required.
            segment: Segment record; its metadata must contain ``grpc_url``.

        Raises:
            Exception: If the segment has no metadata, or the metadata has no
                (or a ``None``) ``grpc_url`` entry.
        """
        # TODO: move to start() method
        # TODO: close channel in stop() method
        metadata = segment["metadata"]
        # Use .get() so an absent key produces the descriptive error below
        # instead of a bare KeyError.
        grpc_url = None if metadata is None else metadata.get("grpc_url")
        if grpc_url is None:
            raise Exception("Missing grpc_url in segment metadata")

        channel = grpc.insecure_channel(grpc_url)
        self._vector_reader_stub = VectorReaderStub(channel)  # type: ignore
        self._segment = segment
        self._opentelemetry_client = system.require(OpenTelemetryClient)

    @override
    def get_vectors(
        self, ids: Optional[Sequence[str]] = None
    ) -> Sequence[VectorEmbeddingRecord]:
        """Fetch embedding records for this segment from the remote service.

        Args:
            ids: Ids to request. ``None`` is passed to the request as-is
                (presumably meaning "no id filter" -- confirm against the
                service implementation).

        Returns:
            The response records, each converted from its protobuf form.
        """
        request = GetVectorsRequest(ids=ids, segment_id=self._segment["id"].hex)
        response: GetVectorsResponse = self._vector_reader_stub.GetVectors(request)
        return [
            from_proto_vector_embedding_record(record) for record in response.records
        ]

    @override
    def query_vectors(
        self, query: VectorQuery
    ) -> Sequence[Sequence[VectorQueryResult]]:
        """Run a vector query against the remote service.

        Args:
            query: Query whose ``vectors``, ``k``, ``allowed_ids`` and
                ``include_embeddings`` entries are copied into the request.
                Query vectors are encoded as ``ScalarEncoding.FLOAT32``.

        Returns:
            One list per entry of ``response.results``, each holding that
            entry's results converted from their protobuf form.
        """
        request = QueryVectorsRequest(
            vectors=[
                to_proto_vector(vector=v, encoding=ScalarEncoding.FLOAT32)
                for v in query["vectors"]
            ],
            k=query["k"],
            allowed_ids=query["allowed_ids"],
            include_embeddings=query["include_embeddings"],
            segment_id=self._segment["id"].hex,
        )
        response: QueryVectorsResponse = self._vector_reader_stub.QueryVectors(request)
        return [
            [from_proto_vector_query_result(hit) for hit in per_query.results]
            for per_query in response.results
        ]

    @override
    def count(self) -> int:
        """Not supported by this segment; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    @override
    def max_seqid(self) -> int:
        """Return ``0``; this segment does not report a real sequence id."""
        return 0

    # Declared static: the function takes no ``self``, so without the decorator
    # an instance call would bind the instance as ``metadata``.
    @staticmethod
    @override
    def propagate_collection_metadata(metadata: Metadata) -> Optional[Metadata]:
        """Return the result of ``PersistentHnswParams.extract`` for *metadata*."""
        # Great example of why language sharing is nice.
        segment_metadata = PersistentHnswParams.extract(metadata)
        return segment_metadata

    @override
    def delete(self) -> None:
        """Not supported by this segment; always raises ``NotImplementedError``."""
        raise NotImplementedError()