from typing import List

import torch
from transformers import AutoTokenizer, AutoModel
from langchain_core.embeddings import Embeddings
|
|
class YarHyperbolicEmbeddings(Embeddings):
    """
    LangChain wrapper for the YAR.INK V5 embedding model, which operates
    natively in hyperbolic (Lorentz) space.

    WARNING: Embeddings returned are Lorentz vectors. Standard Euclidean distances
    (L2, Cosine Similarity) will yield mathematically incorrect results.
    You MUST use a Vector Store capable of supporting custom distance metrics
    (such as Lorentz Distance) or wrap the distance function inside your retrieval logic.
    """

    def __init__(self, model_id: str = "YARlabs/v5_Embedding", target_dim: int = 64, device: str = None):
        """Load tokenizer and model.

        Args:
            model_id: Hugging Face model id or local path.
            target_dim: Output dimensionality requested from the model's forward
                pass (passed through as ``target_dim``).
            device: Explicit torch device string; when ``None``, picks
                cuda > mps > cpu based on availability.
        """
        if device is None:
            if torch.cuda.is_available():
                self.device = "cuda"
            # getattr guard: older / CPU-only torch builds have no
            # torch.backends.mps attribute, which would raise AttributeError.
            elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"
        else:
            self.device = device

        print(f"Loading Yar.Ink model onto {self.device}...")
        # trust_remote_code is required: the model ships custom modeling code.
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(self.device)
        self.model.eval()
        self.target_dim = target_dim

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents into Hyperbolic space.

        Returns one Lorentz vector per input text; an empty input yields [].
        """
        if not texts:
            # The tokenizer raises on an empty batch; honor the Embeddings
            # contract by returning an empty list instead.
            return []

        inputs = self.tokenizer(
            texts, padding=True, truncation=True, max_length=512, return_tensors="pt"
        ).to(self.device)

        with torch.no_grad():
            # NOTE(review): assumes the custom forward returns a plain tensor
            # of shape (batch, target_dim + 1) — confirm against model card.
            vecs = self.model(**inputs, target_dim=self.target_dim)

        return vecs.cpu().tolist()

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query into Hyperbolic space."""
        return self.embed_documents([text])[0]
|
|
| |
| if __name__ == "__main__": |
| embeddings = YarHyperbolicEmbeddings(model_id=".", target_dim=64) |
| |
| doc_vectors = embeddings.embed_documents([ |
| "LangChain is a framework for developing applications powered by language models.", |
| "Hyperbolic embeddings can represent hierarchical data effectively." |
| ]) |
| query_vector = embeddings.embed_query("What is LangChain?") |
| |
| print(f"Generated {len(doc_vectors)} document vectors of dimension {len(doc_vectors[0])} (t + spatial)") |
| print(f"Generated query vector of dimension {len(query_vector)}") |
|
|