from typing import List, Optional

import torch
from transformers import AutoTokenizer, AutoModel
from langchain_core.embeddings import Embeddings


class YarHyperbolicEmbeddings(Embeddings):
    """
    LangChain wrapper for the YAR.INK V5 Embedding Model, natively
    operating in Hyperbolic space.

    WARNING: Embeddings returned are Lorentz vectors. Standard Euclidean
    distances (L2, Cosine Similarity) will yield mathematically incorrect
    results. You MUST use a Vector Store capable of supporting custom
    distance metrics (such as Lorentz Distance) or wrap the distance
    function inside your retrieval logic.
    """

    def __init__(
        self,
        model_id: str = "YARlabs/v5_Embedding",
        target_dim: int = 64,
        device: Optional[str] = None,
        batch_size: int = 32,
    ):
        """Load the tokenizer and model onto the chosen device.

        Args:
            model_id: HuggingFace model id or local path.
            target_dim: Output dimensionality requested from the model's
                forward pass (passed through as ``target_dim``).
            device: Explicit torch device string; when ``None``, picks
                CUDA, then Apple MPS, then CPU.
            batch_size: Maximum number of texts tokenized and embedded
                per forward pass in :meth:`embed_documents`.
        """
        if device is None:
            # Prefer CUDA, then Apple MPS, falling back to CPU.
            self.device = "cuda" if torch.cuda.is_available() else (
                "mps" if torch.backends.mps.is_available() else "cpu"
            )
        else:
            self.device = device
        print(f"Loading Yar.Ink model onto {self.device}...")
        # trust_remote_code: the model ships custom modeling code on the Hub.
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(self.device)
        self.model.eval()
        self.target_dim = target_dim
        self.batch_size = batch_size

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents into Hyperbolic space.

        Texts are processed in chunks of ``self.batch_size`` so that
        arbitrarily long lists do not exhaust device memory. An empty
        input list safely returns an empty list.

        Returns:
            One Lorentz vector (as a list of floats) per input text.
        """
        vectors: List[List[float]] = []
        for start in range(0, len(texts), self.batch_size):
            chunk = texts[start:start + self.batch_size]
            inputs = self.tokenizer(
                chunk,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            ).to(self.device)
            with torch.no_grad():
                vecs = self.model(**inputs, target_dim=self.target_dim)
            vectors.extend(vecs.cpu().tolist())
        return vectors

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query into Hyperbolic space."""
        return self.embed_documents([text])[0]


# Example Usage
if __name__ == "__main__":
    # Change "." to "YARlabs/v5_Embedding" for remote fetch.
    embeddings = YarHyperbolicEmbeddings(model_id=".", target_dim=64)

    doc_vectors = embeddings.embed_documents([
        "LangChain is a framework for developing applications powered by language models.",
        "Hyperbolic embeddings can represent hierarchical data effectively.",
    ])
    query_vector = embeddings.embed_query("What is LangChain?")

    print(f"Generated {len(doc_vectors)} document vectors of dimension {len(doc_vectors[0])} (t + spatial)")
    print(f"Generated query vector of dimension {len(query_vector)}")