Gokul Soumya
feat: Implement binary shield as a library
8972ad7
from sentence_transformers import SentenceTransformer
from dataclasses import dataclass
from binary_shield.comparison import compute_similarity, hamming_distance
from binary_shield.embedding import extract_embedding
from binary_shield.privacy import apply_randomized_response
from binary_shield.quantization import BinaryPackedEmbedding, binary_quantize
@dataclass()
class BinaryFingerprint:
    """Binary fingerprint of a piece of text plus the privacy setting used.

    Instances are built by ``BinaryShield.generate_fingerprint``, which fills
    ``fingerprint`` with the binary-quantized embedding (after randomized
    response when an epsilon is configured) and ``epsilon`` with the shield's
    epsilon, or ``None`` when no randomized response was applied.
    """

    # Binary-quantized embedding, possibly perturbed by randomized response.
    fingerprint: BinaryPackedEmbedding
    # Privacy parameter used for randomized response; None means none applied.
    epsilon: float | None
@dataclass()
class ComparisonResult:
    """Outcome of comparing two fingerprints via ``BinaryShield.compare``."""

    # Hamming distance between the two fingerprints, as computed by the
    # ``hamming_distance`` helper.
    hamming_distance: int
    # Similarity score reported by the ``compute_similarity`` helper.
    similarity: float
    # True when ``similarity`` is at least the comparison threshold.
    is_match: bool
class BinaryShield:
    """Generate and compare privacy-preserving binary fingerprints of text.

    Text is embedded with a ``SentenceTransformer`` model, binary-quantized,
    and — when ``epsilon`` is set — perturbed with randomized response before
    being wrapped in a ``BinaryFingerprint``.
    """

    def __init__(
        self,
        model_name: str = "all-MiniLM-L6-v2",
        epsilon: float | None = None,
    ) -> None:
        """Load the embedding model and record the privacy parameter.

        Args:
            model_name: Model name passed to ``SentenceTransformer``.
            epsilon: Privacy parameter forwarded to
                ``apply_randomized_response``. ``None`` disables the
                randomized-response step entirely.

        Raises:
            ValueError: If ``epsilon`` is not ``None`` and is not a positive
                number (zero, negative, or NaN).
        """
        # Validate before the (potentially slow) model load so bad arguments
        # fail fast. Under standard randomized response a non-positive epsilon
        # is either invalid or yields output uncorrelated with the input, and
        # epsilon=0 is easily confused with "no noise" (which is None).
        # `not epsilon > 0` also rejects NaN, since NaN comparisons are False.
        if epsilon is not None and not epsilon > 0:
            raise ValueError(
                f"epsilon must be a positive number or None, got {epsilon!r}"
            )
        self.model = SentenceTransformer(model_name)
        self.epsilon = epsilon

    def generate_fingerprint(self, text: str) -> BinaryFingerprint:
        """Embed *text*, binarize it, and optionally apply randomized response.

        Args:
            text: Input text to fingerprint.

        Returns:
            A ``BinaryFingerprint`` holding the binary-quantized embedding
            (perturbed when ``self.epsilon`` is set) and the epsilon used.
        """
        embedding = extract_embedding(text, self.model)
        bin_embedding = binary_quantize(embedding)
        if self.epsilon is not None:
            bin_embedding = apply_randomized_response(bin_embedding, self.epsilon)
        return BinaryFingerprint(
            fingerprint=bin_embedding,
            epsilon=self.epsilon,
        )

    @staticmethod
    def compare(
        fp1: BinaryFingerprint,
        fp2: BinaryFingerprint,
        threshold: float = 0.8,
    ) -> ComparisonResult:
        """Compare two fingerprints.

        Args:
            fp1: First fingerprint.
            fp2: Second fingerprint.
            threshold: Minimum similarity (inclusive) for the pair to be
                reported as a match.

        Returns:
            A ``ComparisonResult`` with the Hamming distance, the similarity
            score, and whether ``similarity >= threshold``.
        """
        dist = hamming_distance(fp1.fingerprint, fp2.fingerprint)
        sim = compute_similarity(fp1.fingerprint, fp2.fingerprint)
        # Coerce to builtin types so the result matches ComparisonResult's
        # declared field types even if the helpers return NumPy scalars
        # (their return types are not visible here — assumption). Builtins
        # keep `is_match is True` working and the result JSON-serializable;
        # for values that are already builtins this is a no-op.
        return ComparisonResult(
            hamming_distance=int(dist),
            similarity=float(sim),
            is_match=bool(sim >= threshold),
        )