# Uploaded by Varun6299 via huggingface_hub (commit c27c0dc, verified)
import os
from flask import Flask, request, jsonify
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
import warnings
import logging
# Silence DeprecationWarnings emitted by the third-party stack (langchain, llama_cpp)
warnings.filterwarnings("ignore", category=DeprecationWarning)
# Raise the 'llama_cpp' logger threshold so its INFO-level chatter is hidden
logging.getLogger('llama_cpp').setLevel(logging.WARNING)
app = Flask(__name__)

# --- Configuration ---
# Directory where the Chroma vector database is persisted
VECTOR_DB_DIRECTORY = "merck_manuals"  # Relative path from backend_files

# Embedding model (sentence-transformers id) and LLM (HF repo + GGUF filename)
EMBEDDING_MODEL_NAME = "thenlper/gte-large"
LLM_MODEL_NAME_OR_PATH = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"  # HF repo id
LLM_MODEL_BASENAME = "mistral-7b-instruct-v0.2.Q2_K.gguf"  # Q2_K quant: smallest/fastest, lowest fidelity

# llama.cpp runtime parameters
LLM_N_CTX = 2300  # Context window size, adjust based on your LLM and needs
LLM_N_GPU_LAYERS = -1  # -1 to offload all layers to GPU if available, 0 for CPU
LLM_N_BATCH = 512  # Prompt-processing batch size
LLM_N_THREADS = 4  # CPU threads for inference

# RAG Parameters (from Fine-Tuning Version 2 - Context-Rich + Safer LLM)
RAG_K = 5  # Number of chunks retrieved per query
RAG_TEMPERATURE = 0.3  # Low temperature for factual, less creative answers
RAG_TOP_P = 0.95
RAG_TOP_K = 60
RAG_MAX_TOKENS = 200  # Cap on generated answer length
# --- Load RAG Components ---
# Embedding model: maps text to vectors for Chroma similarity search.
# On failure it is set to None; Chroma below is then built without a working
# embedding function and retrieval will not be usable.
try:
    embedding_model = SentenceTransformerEmbeddings(model_name=EMBEDDING_MODEL_NAME)
    print("Embedding model loaded successfully.")
except Exception as e:
    print(f"Error loading embedding model: {e}")
    embedding_model = None
# Load the vector database (pre-built Chroma index persisted on disk).
# vectorstore stays None when the directory is missing or loading fails,
# which downstream code treats as "retrieval unavailable".
try:
    # Ensure the vector database directory exists and is accessible
    if not os.path.exists(VECTOR_DB_DIRECTORY):
        print(f"Error: Vector database directory '{VECTOR_DB_DIRECTORY}' not found.")
        vectorstore = None
    else:
        vectorstore = Chroma(persist_directory=VECTOR_DB_DIRECTORY, embedding_function=embedding_model)
        print("Vector database loaded successfully.")
except Exception as e:
    print(f"Error loading vector database: {e}")
    vectorstore = None
# Create the retriever instance (only when the vector store actually loaded).
# Bug fix: the success message used to print unconditionally, claiming a
# retriever existed even when the vector store had failed to load.
if vectorstore is not None:
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": RAG_K})
    print(f"Retriever created with k={RAG_K}.")
else:
    retriever = None
    print("Retriever not created: vector store is unavailable.")
# Load the LLM: download the GGUF weights from the Hugging Face Hub (cached
# locally by hf_hub_download) and initialise llama.cpp with them.
llm = None
try:
    print(f"Attempting to download and load LLM: {LLM_MODEL_BASENAME}")
    model_path = hf_hub_download(repo_id=LLM_MODEL_NAME_OR_PATH, filename=LLM_MODEL_BASENAME)
    llm = Llama(
        model_path=model_path,
        n_ctx=LLM_N_CTX,
        n_gpu_layers=LLM_N_GPU_LAYERS,
        n_batch=LLM_N_BATCH,
        n_threads=LLM_N_THREADS,
        verbose=False  # Set to False for deployment
    )
    print("LLM loaded successfully.")
except Exception as e:
    print(f"Error loading LLM: {e}")
    llm = None  # Ensure llm is None if loading fails
# --- RAG Logic (Adapted from notebook) ---
# System prompt: fixed instructions prepended to every LLM request.
# Bug fix: removed a stray trailing double-quote that was embedded in the
# prompt text after "make up facts.".
qna_system_message = """
You are a knowledgeable and reliable medical assistant.
Your role is to provide accurate, concise, and up-to-date medical information
based on trusted medical manuals and references.
User will have the context required by you to answer the question.
This context will always begin with the token ###Context
This context will contain specific portions of the relevant document to answer the question.
Always clarify if a condition requires a medical professional’s attention.
If you don't know the answer, just say that you don't know.
Don't try to make up an answer.
Answer only based on your knowledge from medical manuals. Do not hallucinate or make up facts.
"""

# User message template. The {context} and {question} markers are filled via
# str.replace() in generate_rag_response (not str.format, so braces appearing
# inside retrieved documents are harmless).
qna_user_message_template = """
### Context
Here are some documents that are relevant to the question mentioned below:
{context}
### Question
{question}"""
def generate_rag_response(user_input):
    """Answer *user_input* with retrieval-augmented generation.

    Retrieves the top-k chunks from the vector store, splices them into the
    prompt templates, and queries the local GGUF model.

    Args:
        user_input: The user's question as a plain string.

    Returns:
        The generated answer (stripped), or a human-readable error string if
        the components are unloaded or any stage fails.
    """
    if llm is None or retriever is None:
        return "Error: RAG components not loaded correctly."

    # --- Retrieval ---
    try:
        # k is baked into the retriever (RAG_K). get_relevant_documents is the
        # langchain_community API this file targets (deprecated in newer
        # LangChain in favour of .invoke — TODO confirm installed version).
        relevant_document_chunks = retriever.get_relevant_documents(query=user_input)
        context_for_query = ". ".join(d.page_content for d in relevant_document_chunks)
    except Exception as e:
        print(f"Error during document retrieval: {e}")
        return "Error: Could not retrieve relevant information."

    # str.replace (not str.format) so stray braces inside the retrieved
    # documents or the user's question cannot raise a formatting error.
    user_message = qna_user_message_template.replace('{context}', context_for_query)
    user_message = user_message.replace('{question}', user_input)
    prompt = qna_system_message + '\n' + user_message

    # --- Generation ---
    try:
        response = llm(
            prompt=prompt,
            max_tokens=RAG_MAX_TOKENS,
            temperature=RAG_TEMPERATURE,
            top_p=RAG_TOP_P,
            top_k=RAG_TOP_K,
            stop=['### Context']  # stop if the model starts echoing a new context block
        )
        # Extract and return the model's response, stripping surrounding whitespace
        return response['choices'][0]['text'].strip()
    except Exception as e:
        print(f"Error during LLM generation: {e}")
        # Fixed: was a pointless f-string with no placeholders
        return 'Sorry, I encountered an error generating the response.'
# --- Flask Routes ---
@app.route('/')
def home():
    """Health-check endpoint: confirms the backend process is up."""
    status_message = "RAG Medical Assistant Backend is running!"
    return status_message
@app.route('/predict', methods=['POST'])
def predict():
    """Handle a RAG query.

    Expects a JSON body of the form ``{"query": "..."}`` and returns
    ``{"response": "..."}``. Replies with HTTP 400 and an error payload when
    the body is missing, unparseable, or lacks a 'query' field.
    """
    # get_json(silent=True) yields None for a missing/invalid JSON body instead
    # of raising (request.json would abort with 415/400 before our check ran),
    # so malformed requests reach the intended error response below.
    payload = request.get_json(silent=True)
    if not payload or 'query' not in payload:
        return jsonify({"error": "Invalid request. Please provide a JSON object with a 'query' field."}), 400
    user_query = payload['query']
    print(f"Received query: {user_query}")  # Log received queries
    response_text = generate_rag_response(user_query)
    return jsonify({"response": response_text})
# --- Run the App ---
if __name__ == '__main__':
    # In a production environment, use a production-ready WSGI server (Gunicorn).
    # To bind externally / honour a platform-assigned port, use instead:
    # app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 5000)))
    # NOTE(review): debug=True enables the Werkzeug debugger and reloader —
    # do not ship this to production.
    app.run(debug=True)