"""Flask backend for a RAG-based medical assistant.

On import, this module loads three components (each guarded so a failure
leaves the component as ``None`` rather than crashing the server):

1. a SentenceTransformer embedding model,
2. a persisted Chroma vector store (built from Merck Manual chunks),
3. a quantized Mistral-7B-Instruct GGUF model via llama-cpp.

The ``/predict`` route accepts ``{"query": "..."}`` JSON and returns the
generated answer, grounded in the retrieved document chunks.
"""

import os
import warnings
import logging

from flask import Flask, request, jsonify
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma

# Suppress DeprecationWarnings (langchain emits many on import/use)
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Silence llama_cpp's informational logging; keep warnings and errors
logging.getLogger('llama_cpp').setLevel(logging.WARNING)

app = Flask(__name__)

# --- Configuration ---
# Directory where the vector database is persisted (relative to backend_files)
VECTOR_DB_DIRECTORY = "merck_manuals"

# Embedding model and LLM identifiers
EMBEDDING_MODEL_NAME = "thenlper/gte-large"
LLM_MODEL_NAME_OR_PATH = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
LLM_MODEL_BASENAME = "mistral-7b-instruct-v0.2.Q2_K.gguf"

LLM_N_CTX = 2300         # Context window size; adjust to the LLM and prompt length
LLM_N_GPU_LAYERS = -1    # -1 offloads all layers to GPU if available, 0 forces CPU
LLM_N_BATCH = 512
LLM_N_THREADS = 4

# RAG parameters (from Fine-Tuning Version 2 - Context-Rich + Safer LLM)
RAG_K = 5                # number of chunks retrieved per query
RAG_TEMPERATURE = 0.3
RAG_TOP_P = 0.95
RAG_TOP_K = 60
RAG_MAX_TOKENS = 200

# --- Load RAG Components ---

# Load the embedding model; on failure keep None so later stages can detect it.
try:
    embedding_model = SentenceTransformerEmbeddings(model_name=EMBEDDING_MODEL_NAME)
    print("Embedding model loaded successfully.")
except Exception as e:
    print(f"Error loading embedding model: {e}")
    embedding_model = None

# Load the vector database. Requires both an existing persist directory and a
# usable embedding model (fix: the original passed embedding_function=None to
# Chroma when the embedding model failed to load, deferring a confusing error).
vectorstore = None
try:
    if not os.path.exists(VECTOR_DB_DIRECTORY):
        print(f"Error: Vector database directory '{VECTOR_DB_DIRECTORY}' not found.")
    elif embedding_model is None:
        print("Error: Cannot load vector database without an embedding model.")
    else:
        vectorstore = Chroma(persist_directory=VECTOR_DB_DIRECTORY,
                             embedding_function=embedding_model)
        print("Vector database loaded successfully.")
except Exception as e:
    print(f"Error loading vector database: {e}")
    vectorstore = None

# Create the retriever instance (fix: only report success when a retriever
# actually exists; the original printed "Retriever created" unconditionally).
retriever = vectorstore.as_retriever(search_type="similarity",
                                     search_kwargs={"k": RAG_K}) if vectorstore else None
if retriever is not None:
    print(f"Retriever created with k={RAG_K}.")
else:
    print("Retriever not created: vector database unavailable.")

# Load the LLM: download the GGUF file from the Hub (cached locally by
# hf_hub_download) and load it with llama-cpp.
llm = None
try:
    print(f"Attempting to download and load LLM: {LLM_MODEL_BASENAME}")
    model_path = hf_hub_download(repo_id=LLM_MODEL_NAME_OR_PATH,
                                 filename=LLM_MODEL_BASENAME)
    llm = Llama(
        model_path=model_path,
        n_ctx=LLM_N_CTX,
        n_gpu_layers=LLM_N_GPU_LAYERS,
        n_batch=LLM_N_BATCH,
        n_threads=LLM_N_THREADS,
        verbose=False  # Set to False for deployment
    )
    print("LLM loaded successfully.")
except Exception as e:
    print(f"Error loading LLM: {e}")
    llm = None  # Ensure llm is None if loading fails

# --- RAG Logic (Adapted from notebook) ---

qna_system_message = """
You are a knowledgeable and reliable medical assistant. Your role is to provide accurate, concise, and up-to-date medical information based on trusted medical manuals and references.
User will have the context required by you to answer the question.
This context will always begin with the token ###Context
This context will contain specific portions of the relevant document to answer the question.
Always clarify if a condition requires a medical professional’s attention.
If you don't know the answer, just say that you don't know. Don't try to make up an answer.
Answer only based on your knowledge from medical manuals. Do not hallucinate or make up facts."
"""

qna_user_message_template = """
### Context
Here are some documents that are relevant to the question mentioned below:
{context}

### Question
{question}"""


def generate_rag_response(user_input):
    """Answer *user_input* with retrieval-augmented generation.

    Retrieves the top-k relevant chunks from the vector store, builds the
    prompt from the system message and user template, and generates a
    completion with the LLM.

    Returns the generated answer string, or a human-readable error string
    when a component is missing or a stage fails (callers always get text).
    """
    if llm is None or retriever is None:
        return "Error: RAG components not loaded correctly."

    # Retrieve relevant document chunks (k is configured on the retriever)
    try:
        relevant_document_chunks = retriever.get_relevant_documents(query=user_input)
        context_list = [d.page_content for d in relevant_document_chunks]
        context_for_query = ". ".join(context_list)
    except Exception as e:
        print(f"Error during document retrieval: {e}")
        return "Error: Could not retrieve relevant information."

    # str.replace (not str.format) so braces in retrieved text are harmless
    user_message = qna_user_message_template.replace('{context}', context_for_query)
    user_message = user_message.replace('{question}', user_input)
    prompt = qna_system_message + '\n' + user_message

    # Generate the response
    try:
        response = llm(
            prompt=prompt,
            max_tokens=RAG_MAX_TOKENS,
            temperature=RAG_TEMPERATURE,
            top_p=RAG_TOP_P,
            top_k=RAG_TOP_K,
            stop=['### Context']  # Stop if the model starts echoing a new context block
        )
        # Extract the model's text, stripping leading/trailing whitespace
        return response['choices'][0]['text'].strip()
    except Exception as e:
        print(f"Error during LLM generation: {e}")
        return f'Sorry, I encountered an error generating the response.'


# --- Flask Routes ---

@app.route('/')
def home():
    """Simple liveness check."""
    return "RAG Medical Assistant Backend is running!"


@app.route('/predict', methods=['POST'])
def predict():
    """Answer a medical question posed as JSON: {"query": "..."}.

    Returns 400 with a JSON error body for malformed requests; otherwise
    returns {"response": <generated answer>}.
    """
    # Fix: use get_json(silent=True) so a wrong Content-Type or invalid JSON
    # yields our JSON error response instead of a Werkzeug-raised 415/400.
    payload = request.get_json(silent=True)
    if not payload or 'query' not in payload:
        return jsonify({"error": "Invalid request. Please provide a JSON object with a 'query' field."}), 400

    user_query = payload['query']
    print(f"Received query: {user_query}")  # Log received queries

    response_text = generate_rag_response(user_query)
    return jsonify({"response": response_text})


# --- Run the App ---
if __name__ == '__main__':
    # In a production environment, use a production-ready server like Gunicorn.
    # Fix: the original ran app.run(debug=True), which enables the Werkzeug
    # interactive debugger (arbitrary code execution) — never safe when deployed.
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 5000)))