from typing import Union, List

try:
    from langchain.embeddings import HuggingFaceEmbeddings
except ImportError:
    # Keep this module importable when langchain is not installed; the
    # failure is reported by ``EmbeddingModel._load_model`` only when the
    # "huggingface" backend is actually requested.
    HuggingFaceEmbeddings = None


class EmbeddingModel:
    """
    A flexible embedding model wrapper supporting multiple backend models.

    The backend is selected by ``model_type`` and instantiated once, in the
    constructor, via ``_load_model``. Only the "huggingface" backend can be
    loaded by this class as written.
    """

    def __init__(self, model_type: str = "huggingface", model_name: str = "all-MiniLM-L6-v2"):
        """
        Store the configuration and eagerly load the backend model.

        :param model_type: Backend identifier (``"huggingface"`` by default).
        :param model_name: Model name passed to the backend constructor.
        :raises ValueError: If ``model_type`` is not supported by ``_load_model``.
        :raises ImportError: If the "huggingface" backend is requested but
            langchain could not be imported.
        """
        self.model_type = model_type
        self.model_name = model_name
        self.model = self._load_model()

    def _load_model(self):
        """
        Instantiate the backend object for ``self.model_type``.

        :return: A ``HuggingFaceEmbeddings`` instance for the "huggingface" type.
        :raises ValueError: For any other model type.
        :raises ImportError: If langchain is not installed.
        """
        if self.model_type == "huggingface":
            if HuggingFaceEmbeddings is None:
                raise ImportError(
                    "The 'huggingface' model type requires langchain "
                    "(langchain.embeddings.HuggingFaceEmbeddings)."
                )
            return HuggingFaceEmbeddings(model_name=self.model_name)
        # Implementation for other model types can be added here
        raise ValueError(f"Unsupported model type: {self.model_type}")

    def embed(self, text: Union[str, List[str]]):
        """
        Generate embeddings for the given text.

        :param text: A string or list of strings.
        :return: For the "huggingface" backend, the result of ``embed_query``
            for a single string, or one embedding per input string (in input
            order) for a list. For "sentence_transformers", the encoded
            result converted with ``tolist()``.
        :raises NotImplementedError: If ``self.model_type`` has no embedding
            implementation.
        """
        if self.model_type == "huggingface":
            if not isinstance(text, list):
                return self.model.embed_query(text)
            if not text:
                # Nothing to embed; avoid calling the backend with an empty batch.
                return []
            # Embed the whole list in one backend call instead of one
            # ``embed_query`` call per string.
            return self.model.embed_documents(text)

        if self.model_type == "sentence_transformers":
            # NOTE: ``_load_model`` does not build this backend, so this branch
            # is only reachable when ``model_type``/``model`` are set by other
            # means (e.g. a subclass overriding ``_load_model``).
            return self.model.encode(text, convert_to_tensor=True).tolist()

        raise NotImplementedError(f"Embedding for {self.model_type} is not implemented.")