import torch
from transformers import AutoTokenizer, AutoModel
def lorentz_dist(u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """
    Compute the exact hyperbolic (geodesic) distance between batches of
    Lorentz-model vectors.

    Args:
        u: Tensor of shape (..., d+1) with the time-like coordinate first.
        v: Tensor broadcastable against ``u`` with the same layout.

    Returns:
        Tensor of shape (...,) holding the pairwise geodesic distances.
    """
    # Lorentz metric signature (- + + ...): split the time-like component
    # from the space-like components.
    u_0, u_x = u[..., 0:1], u[..., 1:]
    v_0, v_x = v[..., 0:1], v[..., 1:]
    # Minkowski inner product: <u, v>_L = -u_0 * v_0 + <u_x, v_x>
    inner_product = -u_0 * v_0 + (u_x * v_x).sum(dim=-1, keepdim=True)
    # Points on the hyperboloid satisfy <u, v>_L <= -1; clamp guards acosh
    # against values drifting above -1 from floating-point error.
    # clamp(max=...) replaces torch.min with a freshly allocated scalar
    # tensor: no per-call allocation, and it follows u's dtype/device.
    inner_product = torch.clamp(inner_product, max=-1.0)
    # d(u, v) = arccosh(-<u, v>_L); drop the kept reduction dim.
    return torch.acosh(-inner_product).squeeze(-1)
def main():
    """Demo: embed a question and two candidate answers, then check that the
    hyperbolic distance ranks the relevant answer closest to the question."""
    model_id = "YARlabs/v5_Embedding"  # Ensure you have internet connection to fetch the model, or use a local path like "." if running locally

    print(f"Loading {model_id}...")
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
    model.eval()

    # Index 0 is the query; indices 1 and 2 are the candidate contexts.
    texts = [
        "What is the capital of France?",
        "Paris is the capital of France.",
        "Berlin is the capital of Germany.",
    ]

    print("Tokenizing texts...")
    inputs = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt")

    print("Generating Matryoshka Lorentz Embeddings with dimension 64...")
    with torch.no_grad():
        lorentz_vectors = model(**inputs, target_dim=64)
    print(f"Vectors shape: {lorentz_vectors.shape}")

    # Geodesic distances from the query to each candidate.
    query = lorentz_vectors[0]
    dist_correct = lorentz_dist(query, lorentz_vectors[1])
    dist_wrong = lorentz_dist(query, lorentz_vectors[2])

    print(f"\nDistance (Question <-> Correct Answer): {dist_correct.item():.4f}")
    print(f"\nDistance (Question <-> Wrong Answer): {dist_wrong.item():.4f}")
    if dist_correct.item() < dist_wrong.item():
        print("\n✅ Semantic search successfully retrieved the closest context!")
# Script entry point: runs the retrieval demo when executed directly.
if __name__ == "__main__":
    # If testing locally, you can change model_id to "."
    main()