Spaces:
Runtime error
Runtime error
| import modal | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from .app import app | |
| from .image import image | |
| from .volume import volume | |
class TaskModelRetrieverModalApp:
    """Retrieve candidate models for a vision task via FAISS similarity search.

    ``setup`` loads one FAISS vector store per supported task from
    ``/volume/vector_store/<task>`` into ``self.vector_stores``; ``forward``
    and the task-specific wrappers then query those stores.
    """

    def setup(self):
        """Load the FAISS index of every supported task into ``self.vector_stores``.

        Must be called before ``forward`` or any of the search wrappers, since
        they read ``self.vector_stores``.
        """
        tasks = ("object-detection", "image-segmentation")
        self.vector_stores = {}

        # Build the embedding model once and share it between all stores:
        # the configuration is identical for every task, so constructing it
        # per task would only load the same model repeatedly.
        embeddings = HuggingFaceEmbeddings(
            model_name="all-MiniLM-L6-v2",
            model_kwargs={"device": "cuda"},
            encode_kwargs={"normalize_embeddings": True},
            show_progress=True,
        )

        for task in tasks:
            # NOTE(review): allow_dangerous_deserialization=True opts in to
            # loading pickled data from disk. That is only safe when the
            # files under /volume/vector_store are produced by a trusted
            # pipeline -- confirm before pointing this at other data.
            self.vector_stores[task] = FAISS.load_local(
                folder_path=f"/volume/vector_store/{task}",
                embeddings=embeddings,
                index_name="faiss_index",
                allow_dangerous_deserialization=True,
            )

    def forward(self, task: str, query: str, k: int = 7) -> dict:
        """Return the models whose indexed entries best match ``query``.

        Args:
            task: Key of the vector store to search (a task name loaded by
                ``setup``).
            query: Free-text query passed to the store's similarity search.
            k: Number of documents requested from the similarity search.
                Defaults to 7.

        Returns:
            A dict mapping each hit's ``metadata["model_id"]`` to its
            ``metadata["model_labels"]``. If several hits share a model id,
            the last one wins, so the dict may hold fewer than ``k`` entries.

        Raises:
            KeyError: If no vector store is loaded for ``task``, or a
                returned document lacks one of the expected metadata keys.
        """
        try:
            store = self.vector_stores[task]
        except KeyError:
            # Same exception type as a plain lookup, but with a message that
            # tells the caller which tasks are actually available.
            raise KeyError(
                f"Unknown task {task!r}; available tasks: "
                f"{sorted(self.vector_stores)}"
            ) from None

        docs = store.similarity_search(query, k=k)
        return {
            doc.metadata["model_id"]: doc.metadata["model_labels"]
            for doc in docs
        }

    def object_detection_search(self, query: str) -> dict:
        """Search the ``"object-detection"`` store; see ``forward``."""
        return self.forward("object-detection", query)

    def image_segmentation_search(self, query: str) -> dict:
        """Search the ``"image-segmentation"`` store; see ``forward``."""
        return self.forward("image-segmentation", query)