# File size: 650 Bytes
# 09dc9d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from rag_pipeline import get_embeddings, vretrieve, rerank
from utils import load_local

import argparse

def parse_args(argv=None):
    """Parse the retrieval settings from the command line.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``embed_model_name``, ``vectorstore_dir``,
        ``retriever_k``, ``metric``, ``threshold`` and the optional ``query``.
    """
    parser = argparse.ArgumentParser(
        description="Query a vector store and print reranked retrieval results."
    )
    parser.add_argument("--embed_model_name", required=True,
                        help="Model name passed to get_embeddings().")
    parser.add_argument("--vectorstore_dir", required=True,
                        help="Directory passed to load_local().")
    # NOTE(review): retriever_k is assumed to be an integer count and threshold
    # a float cut-off -- confirm against vretrieve()'s signature.
    parser.add_argument("--retriever_k", type=int, required=True,
                        help="Number of results to retrieve (passed to vretrieve()).")
    parser.add_argument("--metric", required=True,
                        help="Metric name passed to vretrieve().")
    parser.add_argument("--threshold", type=float, required=True,
                        help="Threshold passed to vretrieve().")
    parser.add_argument("--query", default=None,
                        help="Run this single query and exit instead of starting the chat loop.")
    return parser.parse_args(argv)


def _load_index(args):
    """Load the embedding model and the vector store described by *args*."""
    embed_model = get_embeddings(args.embed_model_name)
    return load_local(args.vectorstore_dir, embed_model)


def _retrieve(query, vectorstore, docs, args):
    """Retrieve, rerank, print and return the results for one *query*."""
    retrieve_results = vretrieve(query, vectorstore, docs,
                                 args.retriever_k, args.metric, args.threshold)
    retrieve_results = rerank(retrieve_results)
    print(retrieve_results)
    return retrieve_results


def inference(query=None, args=None):
    """Retrieve, rerank and print the results for a single query.

    Args:
        query: Query text. Defaults to ``args.query`` when omitted.
        args: Settings as returned by ``parse_args``. When omitted they are
            parsed from ``sys.argv``.

    Returns:
        The reranked retrieval results, which are also printed.

    Raises:
        ValueError: If no query is given either directly or via ``--query``.
    """
    if args is None:
        args = parse_args()
    if query is None:
        query = args.query
    # Fail early with a clear message rather than passing None to vretrieve().
    if query is None:
        raise ValueError("no query given: pass one or use --query")
    vectorstore, docs = _load_index(args)
    return _retrieve(query, vectorstore, docs, args)


def conversation(args=None):
    """Answer queries typed at the ``User:`` prompt until the user quits.

    Typing ``exit`` or sending end-of-input (Ctrl-D) ends the loop. Blank
    lines are ignored.

    Args:
        args: Settings as returned by ``parse_args``. When omitted they are
            parsed from ``sys.argv``.
    """
    if args is None:
        args = parse_args()
    # Load the embedding model and vector store once, not on every turn.
    vectorstore, docs = _load_index(args)
    while True:
        try:
            query = input("User: ")
        except EOFError:
            break
        query = query.strip()
        if query == "exit":
            break
        if not query:
            continue
        _retrieve(query, vectorstore, docs, args)


if __name__ == '__main__':
    cli_args = parse_args()
    if cli_args.query is not None:
        inference(cli_args.query, cli_args)
    else:
        conversation(cli_args)