File size: 6,172 Bytes
e708e53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c27c0dc
e708e53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import os
from flask import Flask, request, jsonify
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
import warnings
import logging

# Suppress DeprecationWarnings (mostly emitted by the LangChain community imports)
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Silence llama_cpp's informational log spam; warnings and errors still show
logging.getLogger('llama_cpp').setLevel(logging.WARNING)

app = Flask(__name__)

# --- Configuration ---
# Directory where the Chroma vector database is persisted
VECTOR_DB_DIRECTORY = "merck_manuals" # Relative path from backend_files

# Embedding model (must match the one used to build the vector DB) and LLM
EMBEDDING_MODEL_NAME = "thenlper/gte-large"
LLM_MODEL_NAME_OR_PATH = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"  # HF repo id
LLM_MODEL_BASENAME = "mistral-7b-instruct-v0.2.Q2_K.gguf"  # quantized GGUF file within the repo
LLM_N_CTX = 2300 # Context window size, adjust based on your LLM and needs
LLM_N_GPU_LAYERS = -1 # -1 to offload all layers to GPU if available, 0 for CPU
LLM_N_BATCH = 512  # prompt-processing batch size
LLM_N_THREADS = 4  # CPU threads for llama.cpp

# RAG Parameters (from Fine-Tuning Version 2 - Context-Rich + Safer LLM)
RAG_K = 5            # number of chunks retrieved per query
RAG_TEMPERATURE = 0.3
RAG_TOP_P = 0.95
RAG_TOP_K = 60
RAG_MAX_TOKENS = 200  # cap on generated answer length

# --- Load RAG Components ---
# Each component is loaded once at import time and set to None on failure so
# that request handlers can degrade gracefully instead of crashing at import.

# Load the embedding model (must match the model used to build the vector DB).
embedding_model = None
try:
    embedding_model = SentenceTransformerEmbeddings(model_name=EMBEDDING_MODEL_NAME)
    print("Embedding model loaded successfully.")
except Exception as e:
    print(f"Error loading embedding model: {e}")

# Load the persisted Chroma vector database.
vectorstore = None
try:
    # Ensure the vector database directory exists and is accessible
    if not os.path.exists(VECTOR_DB_DIRECTORY):
        print(f"Error: Vector database directory '{VECTOR_DB_DIRECTORY}' not found.")
    else:
        vectorstore = Chroma(persist_directory=VECTOR_DB_DIRECTORY, embedding_function=embedding_model)
        print("Vector database loaded successfully.")
except Exception as e:
    print(f"Error loading vector database: {e}")

# Create the retriever instance (top-RAG_K similarity search).
# Bug fix: the success message was previously printed unconditionally, even
# when the vector store had failed to load and no retriever existed.
if vectorstore is not None:
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": RAG_K})
    print(f"Retriever created with k={RAG_K}.")
else:
    retriever = None
    print("Retriever not created: vector database is unavailable.")

# Download (if not cached) and load the GGUF LLM via the llama.cpp bindings.
llm = None
try:
    print(f"Attempting to download and load LLM: {LLM_MODEL_BASENAME}")
    model_path = hf_hub_download(repo_id=LLM_MODEL_NAME_OR_PATH, filename=LLM_MODEL_BASENAME)
    llm = Llama(
        model_path=model_path,
        n_ctx=LLM_N_CTX,
        n_gpu_layers=LLM_N_GPU_LAYERS,
        n_batch=LLM_N_BATCH,
        n_threads=LLM_N_THREADS,
        verbose=False # Set to False for deployment
    )
    print("LLM loaded successfully.")
except Exception as e:
    print(f"Error loading LLM: {e}")

# --- RAG Logic (Adapted from notebook) ---
# System prompt establishing the assistant's role and safety guardrails.
# Bug fix: removed a stray trailing '"' that was inside the string literal and
# therefore leaked a literal quote character into every prompt sent to the LLM.
qna_system_message = """
    You are a knowledgeable and reliable medical assistant.
    Your role is to provide accurate, concise, and up-to-date medical information
    based on trusted medical manuals and references.
    User will have the context required by you to answer the question.
    This context will always begin with the token ###Context
    This context will contain specific portions of the relevant document to answer the question.
    Always clarify if a condition requires a medical professional’s attention.
    If you don't know the answer, just say that you don't know.
    Don't try to make up an answer.
    Answer only based on your knowledge from medical manuals. Do not hallucinate or make up facts.
"""

# Per-query template; {context} and {question} are substituted by
# generate_rag_response using str.replace (not str.format).
qna_user_message_template = """
    ### Context
    Here are some documents that are relevant to the question mentioned below:
    {context}
    ### Question
    {question}"""


def generate_rag_response(user_input):
    """Run the full RAG pipeline for a single user query.

    Retrieves the top-k relevant chunks from the vector store, builds a
    prompt from the module-level system/user templates, and generates an
    answer with the local LLM.

    Args:
        user_input: The user's question as plain text.

    Returns:
        The generated answer string, or a human-readable error message if
        any RAG component is unavailable or a pipeline step fails.
    """
    if llm is None or retriever is None:
        return "Error: RAG components not loaded correctly."

    # Retrieve relevant document chunks (k is configured on the retriever).
    try:
        relevant_document_chunks = retriever.get_relevant_documents(query=user_input)
        context_for_query = ". ".join(d.page_content for d in relevant_document_chunks)
    except Exception as e:
        print(f"Error during document retrieval: {e}")
        return "Error: Could not retrieve relevant information."

    # Fill the template with str.replace rather than str.format so curly
    # braces in retrieved text cannot raise KeyError. The untrusted user
    # input is substituted LAST so it cannot inject into the context slot.
    user_message = qna_user_message_template.replace('{context}', context_for_query)
    user_message = user_message.replace('{question}', user_input)

    prompt = qna_system_message + '\n' + user_message

    # Generate the response.
    try:
        response = llm(
                  prompt=prompt,
                  max_tokens=RAG_MAX_TOKENS,
                  temperature=RAG_TEMPERATURE,
                  top_p=RAG_TOP_P,
                  top_k=RAG_TOP_K,
                  stop=['### Context'] # Prevent the model from fabricating a new context block
                  )

        # Extract and return the model's text, stripping surrounding whitespace.
        return response['choices'][0]['text'].strip()
    except Exception as e:
        print(f"Error during LLM generation: {e}")
        # Fixed: was an f-string with no placeholders.
        return 'Sorry, I encountered an error generating the response.'


# --- Flask Routes ---
@app.route('/')
def home():
    """Health-check endpoint confirming the backend is up."""
    status_message = "RAG Medical Assistant Backend is running!"
    return status_message

@app.route('/predict', methods=['POST'])
def predict():
    """Answer a POSTed JSON query via the RAG pipeline.

    Expects a JSON body of the form ``{"query": "..."}`` and returns
    ``{"response": "..."}``, or a JSON error with HTTP 400 on bad input.
    """
    # Bug fix: request.json raises an HTML-bodied 415/400 error page when the
    # Content-Type is missing or wrong; get_json(silent=True) returns None
    # instead, so every malformed request gets the same JSON 400 below.
    payload = request.get_json(silent=True)
    if not payload or 'query' not in payload:
        return jsonify({"error": "Invalid request. Please provide a JSON object with a 'query' field."}), 400

    user_query = payload['query']
    print(f"Received query: {user_query}") # Log received queries

    response_text = generate_rag_response(user_query)

    return jsonify({"response": response_text})

# --- Run the App ---
if __name__ == '__main__':
    # In a production environment, use a production-ready server like Gunicorn
    # and bind to the host/port from the environment, e.g.:
    #   app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 5000)))
    # NOTE(review): debug=True enables Flask's interactive Werkzeug debugger,
    # which allows arbitrary code execution — it must not ship in a deployed
    # service. Confirm this entry point is only used for local development.
    app.run(debug=True)