File size: 3,315 Bytes
933c2fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
import chromadb
from llama_index.core import VectorStoreIndex
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.core import StorageContext, Settings
from llama_index.core.schema import TextNode
from openai import OpenAI
from modal_client import ModalClient
from structlog import get_logger
logger = get_logger(__name__)
from typing import Any, List
from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.embeddings import BaseEmbedding
class CustomEmbeddings(BaseEmbedding):
    """LlamaIndex embedding backend for any OpenAI-compatible /embeddings endpoint.

    Wraps an ``openai.OpenAI`` client pointed at a custom ``base_url`` so that
    self-hosted embedding servers speaking the OpenAI API can be plugged into
    ``Settings.embed_model``.
    """

    # BaseEmbedding is a pydantic model: assigning an undeclared attribute in
    # __init__ raises at construction time, so the client must be declared as
    # a private attribute. (PrivateAttr was imported but unused before — the
    # original ``self._client = ...`` assignment would fail.)
    _client: OpenAI = PrivateAttr()

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model_name: str,
        **kwargs: Any,
    ) -> None:
        """Create the backend.

        Args:
            base_url: Root URL of the OpenAI-compatible embeddings server.
            api_key: API key forwarded to the server.
            model_name: Embedding model identifier sent with every request.
            **kwargs: Extra ``BaseEmbedding`` options (e.g. ``embed_batch_size``).
        """
        # ``model_name`` is a declared pydantic field on BaseEmbedding; route
        # it through the constructor rather than assigning it afterwards.
        super().__init__(model_name=model_name, **kwargs)
        self._client = OpenAI(base_url=base_url, api_key=api_key)

    @classmethod
    def class_name(cls) -> str:
        return "custom"

    async def _aget_query_embedding(self, query: str) -> List[float]:
        # No native async client is configured; fall back to the sync call.
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        return self._get_text_embedding(text)

    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed a single query string."""
        return self._embed_batch([query])[0]

    def _get_text_embedding(self, text: str) -> List[float]:
        """Embed a single document chunk."""
        return self._embed_batch([text])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of document chunks in one API call."""
        return self._embed_batch(texts)

    def _embed_batch(self, inputs: List[str]) -> List[List[float]]:
        # Single request path shared by all entry points.
        response = self._client.embeddings.create(
            model=self.model_name,
            input=inputs,
        )
        return [item.embedding for item in response.data]
class EmbeddingService:
    """Builds an in-memory Chroma vector index over pre-parsed nodes and
    serves similarity-search queries against it.

    The Chroma client is ephemeral, so the collection lives only for the
    lifetime of this process.
    """

    def __init__(self, collection_name: str) -> None:
        """Configure the global embed model and an ephemeral Chroma store.

        Args:
            collection_name: Name of the Chroma collection to create.
        """
        config = ModalClient.embedding_config()
        # NOTE(review): ``config.get`` may return None for missing keys and
        # fail later inside the OpenAI client — assumes Modal always supplies
        # api_key/base_url/model; confirm with ModalClient.embedding_config.
        Settings.embed_model = CustomEmbeddings(
            api_key=config.get("api_key"),
            base_url=config.get("base_url"),
            model_name=config.get("model"),
            embed_batch_size=32,
        )
        Settings.chunk_size = 1024
        chroma_client = chromadb.EphemeralClient()
        chroma_collection = chroma_client.create_collection(collection_name)
        vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
        self.storage_context = StorageContext.from_defaults(vector_store=vector_store)

    def prepare_index(self, nodes) -> None:
        """Index already-parsed nodes into the vector store.

        Bug fix: the previous code called ``VectorStoreIndex.from_documents``,
        which expects ``Document`` objects and re-chunks them. For pre-built
        ``TextNode`` objects (the file imports TextNode and the parameter is
        named ``nodes``) the node-based constructor is the correct entry point.
        """
        self.index = VectorStoreIndex(nodes=nodes, storage_context=self.storage_context)

    def infer(self, query: str, top_k: int = 10) -> str:
        """Retrieve the ``top_k`` most similar nodes and format them as text.

        Args:
            query: Natural-language or code query to embed and search with.
            top_k: Number of nearest nodes to return.

        Returns:
            A single string with one delimited section per retrieved node.

        Raises:
            AttributeError: If called before ``prepare_index``.
            KeyError: If a node lacks one of the expected metadata keys
                (name/filename/type/namespace).
        """
        retriever = self.index.as_retriever(similarity_top_k=top_k)
        results = retriever.retrieve(query)
        # Accumulate sections and join once — avoids quadratic ``+=`` growth.
        sections = []
        for result in results:
            sections.append(
                "\n -------------------------- \n"
                f"name = {result.metadata['name']}\n"
                f"filename = {result.metadata['filename']}\n"
                f"type = {result.metadata['type']}\n"
                f"namespace = {result.metadata['namespace']}\n"
                f"content = {result.text}\n"
            )
        return "".join(sections)
|