Spaces:
Runtime error
Runtime error
File size: 2,089 Bytes
3fe04a3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 | from flask import Flask, request, jsonify
import json
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from your_rag_module import lcpp_llm # your LLM wrapper
# Flask application object; the route decorators further down register on it.
app = Flask(__name__)

# Load prompt templates.
# JSON text is UTF-8, so decode it explicitly instead of relying on the
# platform's locale default (which would garble non-ASCII prompt text on
# systems whose default encoding is not UTF-8).
with open("prompt_config.json", encoding="utf-8") as f:
    prompt_config = json.load(f)
qna_system_message = prompt_config["system_message"]
# Template containing the "{context}" and "{question}" placeholders that
# generate_rag_response fills in.
qna_user_message_template = prompt_config["user_template"]

# Load retriever config and initialize retriever.
with open("retriever_config.json", encoding="utf-8") as f:
    retriever_config = json.load(f)
embedding_function = SentenceTransformerEmbeddings(model_name=retriever_config["embedding_model"])
# Open the persisted Chroma store and expose it through the retriever interface.
retriever = Chroma(
    persist_directory=retriever_config["persist_directory"],
    embedding_function=embedding_function,
).as_retriever()
def generate_rag_response(user_input, k=3, max_tokens=128, temperature=0, top_p=0.95, top_k=50):
    """Answer ``user_input`` using retrieved document chunks as context.

    The module-level ``retriever`` is queried with ``user_input``; the text of
    the returned chunks is joined with ". " and substituted, together with the
    question, into ``qna_user_message_template``. The system message and the
    filled template are then sent to ``lcpp_llm`` as one prompt.

    Args:
        user_input: Question text. Used both as the retrieval query and as
            the value of the ``{question}`` placeholder.
        k: Passed to the retriever as ``k``.
        max_tokens: Forwarded unchanged to ``lcpp_llm``.
        temperature: Forwarded unchanged to ``lcpp_llm``.
        top_p: Forwarded unchanged to ``lcpp_llm``.
        top_k: Forwarded unchanged to ``lcpp_llm``.

    Returns:
        The stripped ``text`` of the first entry in the LLM response's
        ``choices``. If the LLM call (or reading its result) raises, an
        apology string containing the error is returned instead. Errors
        raised during retrieval are not caught here and propagate.
    """
    # Local import keeps this block self-contained; the file's import header
    # does not bring in ``re``.
    import re

    # NOTE(review): whether a per-call ``k`` is honoured by the retriever
    # depends on the installed LangChain version -- confirm, or configure it
    # via ``as_retriever(search_kwargs={"k": ...})``.
    relevant_document_chunks = retriever.get_relevant_documents(query=user_input, k=k)
    context_for_query = ". ".join(d.page_content for d in relevant_document_chunks)

    # Fill both placeholders in a single pass over the template. Chained
    # ``str.replace`` calls would re-scan the already inserted context, so a
    # document that happens to contain the literal "{question}" would have it
    # overwritten with the user's input.
    substitutions = {
        '{context}': context_for_query,
        '{question}': user_input,
    }
    # A callable replacement also avoids backslash-escape processing of the
    # inserted text.
    user_message = re.sub(
        r'\{context\}|\{question\}',
        lambda match: substitutions[match.group(0)],
        qna_user_message_template,
    )

    prompt = qna_system_message + '\n' + user_message

    try:
        response = lcpp_llm(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )
        response = response['choices'][0]['text'].strip()
    except Exception as e:
        # Deliberate best-effort: report the failure as the answer text
        # rather than failing the request.
        response = f'Sorry, I encountered the following error: \n {e}'
    return response
@app.route("/v1/query", methods=["POST"])
def query():
    """Handle ``POST /v1/query``: run the RAG pipeline on the posted question.

    Expects a JSON object body with a string ``"query"`` field; a missing
    field is treated as the empty string.

    Returns:
        ``{"response": <answer>}`` on success, or ``{"error": <message>}``
        with HTTP status 400 when the body is not a JSON object or
        ``"query"`` is not a string.
    """
    # silent=True returns None instead of raising when the body is missing,
    # has the wrong content type, or is not valid JSON.
    payload = request.get_json(silent=True)
    # A valid JSON body may still be a list, number, or null, none of which
    # support ``.get``; reject those explicitly instead of crashing with a 500.
    if not isinstance(payload, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    user_input = payload.get("query", "")
    # Prompt construction uses string operations on the question, so a
    # non-string value would fail later with a TypeError.
    if not isinstance(user_input, str):
        return jsonify({"error": "'query' must be a string."}), 400

    response = generate_rag_response(user_input)
    return jsonify({"response": response})
@app.route("/ping", methods=["GET"])
def health():
    """Liveness probe: reply with a fixed message and HTTP status 200."""
    body = "Backend is alive!"
    status_code = 200
    return body, status_code
if __name__ == "__main__":
    # Run Flask's built-in server when this file is executed as a script:
    # bind to all interfaces on port 7860.
    app.run(port=7860, host="0.0.0.0")
|