# RAG Flask backend.
# (Non-code artifacts from a code-viewer paste — file-size banner, commit
# hash, and a line-number gutter — were removed from the top of this file.)
from flask import Flask, request, jsonify
import json
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from your_rag_module import lcpp_llm  # your LLM wrapper

# Flask application instance; routes below are registered against it.
app = Flask(__name__)

# Load prompt templates
# Expected keys: "system_message" and "user_template"; the user template
# contains literal "{context}" / "{question}" placeholders that are filled
# with str.replace in generate_rag_response.
with open("prompt_config.json") as f:
    prompt_config = json.load(f)

qna_system_message = prompt_config["system_message"]
qna_user_message_template = prompt_config["user_template"]

# Load retriever config and initialize retriever
# Expected keys: "embedding_model" (sentence-transformers model name) and
# "persist_directory" (path to an existing Chroma store).
# NOTE(review): the embedding model here must match the one used to build
# the persisted index, or similarity search will be meaningless — confirm.
with open("retriever_config.json") as f:
    retriever_config = json.load(f)

embedding_function = SentenceTransformerEmbeddings(model_name=retriever_config["embedding_model"])
retriever = Chroma(
    persist_directory=retriever_config["persist_directory"],
    embedding_function=embedding_function
).as_retriever()

def generate_rag_response(user_input, k=3, max_tokens=128, temperature=0, top_p=0.95, top_k=50):
    """Answer ``user_input`` with retrieval-augmented generation.

    Retrieves the top-``k`` chunks from the Chroma store, splices them and
    the question into the configured prompt templates, and asks the local
    LLM to complete the prompt.

    Parameters
    ----------
    user_input : str
        The user's question.
    k : int
        Number of document chunks to retrieve.
    max_tokens, temperature, top_p, top_k
        Sampling parameters forwarded verbatim to ``lcpp_llm``.

    Returns
    -------
    str
        The model's answer, or a human-readable error message if retrieval
        or generation fails. This function never raises.
    """
    try:
        # NOTE(review): passing ``k`` here relies on the retriever forwarding
        # extra kwargs to the vector store; several langchain versions expect
        # it in ``search_kwargs`` at construction time instead — confirm
        # against the pinned langchain_community version.
        relevant_document_chunks = retriever.get_relevant_documents(query=user_input, k=k)
        context_list = [d.page_content for d in relevant_document_chunks]
        context_for_query = ". ".join(context_list)

        # str.replace (not str.format) so stray braces inside the retrieved
        # context or the question cannot break template substitution.
        user_message = qna_user_message_template.replace('{context}', context_for_query)
        user_message = user_message.replace('{question}', user_input)
        prompt = qna_system_message + '\n' + user_message

        response = lcpp_llm(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k
        )
        response = response['choices'][0]['text'].strip()
    except Exception as e:
        # Fix: previously only the LLM call sat inside the try, so a
        # retrieval failure crashed the endpoint with an unhandled 500 while
        # a generation failure returned this friendly message. Both failure
        # modes now degrade to the same (unchanged) message.
        response = f'Sorry, I encountered the following error: \n {e}'

    return response

@app.route("/v1/query", methods=["POST"])
def query():
    """POST /v1/query — answer a question via the RAG pipeline.

    Expects a JSON body ``{"query": "..."}`` and responds with
    ``{"response": "..."}``.
    """
    # Fix: ``request.json`` raises (or yields None) for a missing or
    # non-JSON body, so ``.get`` blew up with an AttributeError -> 500.
    # ``get_json(silent=True)`` degrades to an empty query instead.
    payload = request.get_json(silent=True) or {}
    user_input = payload.get("query", "")
    response = generate_rag_response(user_input)
    return jsonify({"response": response})

@app.route("/ping", methods=["GET"])
def health():
    """Liveness probe: GET /ping returns a plain-text 200."""
    status_code = 200
    return "Backend is alive!", status_code

if __name__ == "__main__":
    # Bind on all interfaces for container use; 7860 matches the Hugging
    # Face Spaces / Gradio default port — presumably deployed there, verify.
    app.run(host="0.0.0.0", port=7860)