Spaces:
Sleeping
Sleeping
| from typing import Union, List | |
| from langchain.embeddings import HuggingFaceEmbeddings | |
class EmbeddingModel:
    """
    A flexible embedding model wrapper supporting multiple backend models.

    Only the ``"huggingface"`` backend can currently be constructed: any other
    ``model_type`` is rejected by ``_load_model`` while ``__init__`` runs.
    """

    def __init__(self, model_type: str = "huggingface", model_name: str = "all-MiniLM-L6-v2"):
        """
        Store the configuration and eagerly load the backend model.

        :param model_type: Backend identifier used to pick the loader.
        :param model_name: Model name handed to the backend loader.
        :raises ValueError: If ``model_type`` is not a supported backend.
        """
        self.model_type = model_type
        self.model_name = model_name
        self.model = self._load_model()

    def _load_model(self):
        """
        Build the backend model object for ``self.model_type``.

        :return: A ``HuggingFaceEmbeddings`` instance for the "huggingface" type.
        :raises ValueError: For any other model type.
        """
        if self.model_type == "huggingface":
            return HuggingFaceEmbeddings(model_name=self.model_name)
        # Implementation for other model types can be added here
        raise ValueError(f"Unsupported model type: {self.model_type}")

    def embed(self, text: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
        """
        Generate embeddings for the given text.

        :param text: A string or list of strings.
        :return: One embedding (a list of floats) when ``text`` is a single
            string; a list of embeddings, one per item and in input order,
            when ``text`` is a list (or tuple) of strings.
        :raises NotImplementedError: If ``self.model_type`` has no embedding
            implementation.
        """
        if self.model_type == "huggingface":
            if isinstance(text, (list, tuple)):
                if not text:
                    # Nothing to embed; skip the model call entirely.
                    return []
                # Encode the whole batch in one call rather than issuing one
                # embed_query call per item.
                return self.model.embed_documents(list(text))
            return self.model.embed_query(text)

        if self.model_type == "sentence_transformers":
            # NOTE(review): _load_model does not build this backend, so this
            # branch is only reachable if model_type/model are set externally
            # after construction. Presumably self.model is then a
            # SentenceTransformer-like object exposing encode() - confirm.
            return self.model.encode(text, convert_to_tensor=True).tolist()

        raise NotImplementedError(f"Embedding for {self.model_type} is not implemented.")