"""Embedding service backed by an OpenAI-compatible embeddings endpoint and an
in-memory Chroma vector store, wired together through LlamaIndex."""

from typing import Any, List

import chromadb
from llama_index.core import Settings, StorageContext, VectorStoreIndex
from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.schema import TextNode
from llama_index.vector_stores.chroma import ChromaVectorStore
from openai import OpenAI
from structlog import get_logger

from modal_client import ModalClient

logger = get_logger(__name__)


class CustomEmbeddings(BaseEmbedding):
    """LlamaIndex embedding model that delegates to any OpenAI-compatible
    ``/embeddings`` endpoint (selected via ``base_url`` + ``api_key``)."""

    # Declare the private attribute so pydantic (which backs BaseEmbedding)
    # accepts the assignment in __init__. The original imported PrivateAttr
    # but never used it; an undeclared `self._client = ...` fails under
    # pydantic v2.
    _client: OpenAI = PrivateAttr()

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model_name: str,
        **kwargs: Any,
    ) -> None:
        """Create a client for the remote embedding API.

        Args:
            base_url: Root URL of the OpenAI-compatible server.
            api_key: Credential for that server.
            model_name: Embedding model identifier sent with every request.
            **kwargs: Forwarded to BaseEmbedding (e.g. ``embed_batch_size``).
        """
        # model_name is a declared pydantic field on BaseEmbedding: pass it
        # through the constructor instead of assigning after validation.
        super().__init__(model_name=model_name, **kwargs)
        self._client = OpenAI(base_url=base_url, api_key=api_key)

    @classmethod
    def class_name(cls) -> str:
        return "custom"

    async def _aget_query_embedding(self, query: str) -> List[float]:
        # No async client is configured; fall back to the sync path.
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        return self._get_text_embedding(text)

    def _get_query_embedding(self, query: str) -> List[float]:
        # Queries and documents are embedded identically by this endpoint
        # (the original issued byte-for-byte the same request for both).
        return self._get_text_embedding(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        # Same API call as the original single-text path: one-element batch.
        return self._get_text_embeddings([text])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts in a single API round-trip."""
        response = self._client.embeddings.create(
            model=self.model_name, input=texts
        )
        return [item.embedding for item in response.data]


class EmbeddingService:
    """Builds an in-memory vector index over code artifacts and retrieves
    formatted context snippets for a query."""

    def __init__(self, collection_name: str) -> None:
        """Configure the global embed model and an ephemeral Chroma store.

        Args:
            collection_name: Name for the (process-local) Chroma collection.
        """
        config = ModalClient.embedding_config()
        Settings.embed_model = CustomEmbeddings(
            api_key=config.get("api_key"),
            base_url=config.get("base_url"),
            model_name=config.get("model"),
            embed_batch_size=32,
        )
        Settings.chunk_size = 1024
        # EphemeralClient: the collection lives only for this process; a
        # fresh client per service instance means create_collection cannot
        # collide with a previous run.
        chroma_client = chromadb.EphemeralClient()
        chroma_collection = chroma_client.create_collection(collection_name)
        vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
        self.storage_context = StorageContext.from_defaults(
            vector_store=vector_store
        )

    def prepare_index(self, nodes) -> None:
        """Index the given items into the ephemeral vector store.

        NOTE(review): ``from_documents`` expects Document objects and applies
        ``Settings.chunk_size`` splitting, yet the parameter is named
        ``nodes`` — confirm callers pass Documents (Document subclasses
        TextNode). Behavior intentionally left unchanged.
        """
        self.index = VectorStoreIndex.from_documents(
            nodes, storage_context=self.storage_context
        )

    def infer(self, query: str, top_k: int = 10) -> str:
        """Retrieve the ``top_k`` most similar entries and render them as one
        human-readable text block.

        Raises KeyError if a retrieved node lacks any of the expected
        metadata fields (name/filename/type/namespace) — same as the
        original; a crash here signals a mis-built index.
        """
        retriever = self.index.as_retriever(similarity_top_k=top_k)
        results = retriever.retrieve(query)
        # Collect sections and join once instead of quadratic `+=` on str.
        sections = []
        for result in results:
            sections.append(
                "\n -------------------------- \n"
                f"name = {result.metadata['name']}\n"
                f"filename = {result.metadata['filename']}\n"
                f"type = {result.metadata['type']}\n"
                f"namespace = {result.metadata['namespace']}\n"
                f"content = {result.text}\n"
            )
        return "".join(sections)