|
|
from pymilvus import MilvusClient |
|
|
from tqdm import tqdm |
|
|
import numpy as np |
|
|
import json |
|
|
import os |
|
|
from pymilvus import Collection, connections |
|
|
from openai import OpenAI |
|
|
from typing import List |
|
|
|
|
|
|
|
|
def generate_embedding_api(client: OpenAI, text: str, model: str) -> List[float]:
    """
    Request an embedding for ``text`` from the OpenAI embeddings endpoint.

    Args:
        client (OpenAI): Client whose ``embeddings.create`` endpoint is called.
        text (str): Input to embed; sent as the request's ``input`` field.
        model (str): Name of the embedding model to request.

    Returns:
        list: The ``embedding`` attribute of the first entry in the response's ``data``.
    """
    first_item = client.embeddings.create(input=text, model=model).data[0]
    return first_item.embedding
|
|
|
|
|
|
|
|
def load_embedding(index: int, embedding_path: str) -> np.ndarray:
    """
    Read a single row from a ``.npy`` file opened as a read-only memory map.

    Args:
        index (int): Row position to read (normal NumPy indexing, so negative
            values count from the end and out-of-range values raise IndexError).
        embedding_path (str): Path to the ``.npy`` file holding the embeddings.

    Returns:
        numpy.ndarray: ``memmap[index]`` for the freshly opened memory map.
    """
    return np.lib.format.open_memmap(embedding_path, mode="r")[index]
|
|
|
|
|
|
|
|
def _embedding_path_for(file_name: str) -> str:
    """
    Return the ``.npy`` embedding path paired with a plain-text prompt file.

    ``<dir>/<base>.<ext>`` maps to ``<dir'>/<base>_embeddings.npy``, where ``<dir'>``
    is ``<dir>`` with every occurrence of ``plain_text`` replaced by ``embeddings``.
    """
    directory, name = os.path.split(file_name)
    base_name, _ = os.path.splitext(name)
    new_directory = directory.replace("plain_text", "embeddings")
    return os.path.join(new_directory, base_name + "_embeddings.npy")


def export_collection(
    db_client: MilvusClient,
    collection_name: str,
    file_name: str,
    params: dict,
    batch_size: int = 1000,
):
    """
    Exports data from a text file into a Milvus collection, associating each text entry
    with a pre-generated embedding.

    Line ``i`` (0-based) of ``file_name`` is paired with row ``i`` of the embedding file
    returned by ``_embedding_path_for(file_name)`` and stored with ``id=i``.

    Args:
        db_client (MilvusClient): The Milvus client instance used to interact with the database.
        collection_name (str): The name of the collection in Milvus where data will be stored.
            Any existing collection with this name is dropped and recreated.
        file_name (str): The file containing plain text data (one entry per line) to be
            inserted into the collection.
        params (dict): Configuration parameters, including `embedding_size` and `metric_type`.
        batch_size (int): Maximum number of rows sent per ``insert`` call. Defaults to 1000.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
        IndexError: If the text file has more lines than the embedding file has rows
            (rows from batches already sent remain inserted).

    Steps:
        1. Opens the embedding file once as a read-only memory map.
        2. Drops and recreates the collection in Milvus using the specified parameters.
        3. Inserts ID, vector (embedding) and stripped plain text in batches.
        4. Connects to the database via the ORM API and flushes the collection.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Open the embeddings once (instead of once per line) and before touching the
    # database, so a missing embedding file does not leave the collection dropped.
    embeddings = np.lib.format.open_memmap(_embedding_path_for(file_name), mode="r")

    with open(file_name, "r", encoding="utf-8") as f:
        db_client.drop_collection(collection_name)
        db_client.create_collection(
            collection_name=collection_name,
            dimension=params["embedding_size"],
            metric_type=params["metric_type"],
        )

        batch = []
        for index, line in tqdm(
            enumerate(f),
            desc=f"Inserting data into {collection_name}",
            total=len(embeddings),
        ):
            # .tolist() converts the memmap row into plain Python floats for the client.
            batch.append(
                {
                    "id": index,
                    "vector": embeddings[index].tolist(),
                    "plain_text": line.strip(),
                }
            )
            if len(batch) == batch_size:
                db_client.insert(collection_name=collection_name, data=batch)
                batch = []

        # Send whatever is left after the last full batch.
        if batch:
            db_client.insert(collection_name=collection_name, data=batch)

    # NOTE(review): the ORM connection uses a hard-coded URI (the same path load_db()
    # hands to MilvusClient); a db_client pointed elsewhere would flush a different DB.
    connections.connect(alias="default", uri="db/gars.db")
    collection = Collection(collection_name)
    collection.flush()
|
|
|
|
|
|
|
|
def load_db():
    """
    Initializes the Milvus database by creating collections and populating them with data
    and pre-generated embeddings.

    All paths are relative to the current working directory:
    the config is read from ``../config/db_config.json``, the database file is
    ``db/gars.db`` and each collection's text is read from
    ``../resources/prompt_categories/plain_text/<collection_name>.txt``.

    Steps:
        1. Load configuration parameters for database and collections.
        2. Initialize Milvus client and retrieve the de-duplicated collection names
           (from the config's ``prompt_elements`` list, in first-seen order).
        3. For each collection, call `export_collection` to populate with embeddings.
    """
    # Use a context manager so the config file handle is closed promptly.
    config_path = os.path.join("..", "config", "db_config.json")
    with open(config_path, "r", encoding="utf-8") as config_file:
        params = json.load(config_file)

    db_client = MilvusClient("db/gars.db")

    # dict.fromkeys de-duplicates while preserving the config's order (set() does not),
    # so collections are loaded in a deterministic sequence.
    collection_names = list(dict.fromkeys(params["prompt_elements"]))

    for collection_name in collection_names:
        print(f"Loading {collection_name} to database")
        prompt_file_name = os.path.join(
            "..",
            "resources",
            "prompt_categories",
            "plain_text",
            collection_name + ".txt",
        )
        export_collection(db_client, collection_name, prompt_file_name, params)
|
|
|
|
|
|
|
|
if "__main__" == __name__:
    # Only when run as a script (not on import): (re)build and populate the database.
    load_db()
|
|
|