Hugging Face Spaces deployment note: this Space previously failed to start, with the
Spaces build log reporting "Runtime error" (logged twice). The source below is the
backend as pasted from that deployment.
| import os | |
| from flask import Flask, request, jsonify | |
| from huggingface_hub import hf_hub_download | |
| from llama_cpp import Llama | |
| from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings | |
| from langchain_community.vectorstores import Chroma | |
| import warnings | |
| import logging | |
# Suppress DeprecationWarnings so they don't flood the deployment logs
# (the langchain/llama-cpp stack emits several on import).
warnings.filterwarnings("ignore", category=DeprecationWarning)
# Set the logging level for the 'llama_cpp' logger to suppress informational messages
logging.getLogger('llama_cpp').setLevel(logging.WARNING)
# Flask application instance serving the routes defined further down.
app = Flask(__name__)
# --- Configuration ---
# Define the directory where the vector database is persisted
VECTOR_DB_DIRECTORY = "merck_manuals"  # Relative path from backend_files
# Define the names/paths of the embedding model and LLM
EMBEDDING_MODEL_NAME = "thenlper/gte-large"
LLM_MODEL_NAME_OR_PATH = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"  # HF repo id holding the GGUF weights
LLM_MODEL_BASENAME = "mistral-7b-instruct-v0.2.Q2_K.gguf"  # Q2_K quantization: smallest memory footprint
LLM_N_CTX = 2300  # Context window size, adjust based on your LLM and needs
LLM_N_GPU_LAYERS = -1  # -1 to offload all layers to GPU if available, 0 for CPU
LLM_N_BATCH = 512  # Prompt-processing batch size passed to llama.cpp
LLM_N_THREADS = 4  # CPU threads used by llama.cpp
# RAG Parameters (from Fine-Tuning Version 2 - Context-Rich + Safer LLM)
RAG_K = 5  # Number of document chunks retrieved per query
RAG_TEMPERATURE = 0.3  # Low temperature for more deterministic medical answers
RAG_TOP_P = 0.95  # Nucleus-sampling cutoff
RAG_TOP_K = 60  # Top-k sampling cutoff
RAG_MAX_TOKENS = 200  # Cap on generated answer length
# --- Load RAG Components ---
# Load the embedding model used for query-time similarity search.
# On failure, fall back to None so the app can still start and report the
# problem per-request instead of crashing at import time.
try:
    embedding_model = SentenceTransformerEmbeddings(model_name=EMBEDDING_MODEL_NAME)
    print("Embedding model loaded successfully.")
except Exception as e:
    print(f"Error loading embedding model: {e}")
    embedding_model = None
# Load the vector database (a Chroma collection persisted on disk).
# NOTE(review): if `embedding_model` failed to load above, Chroma is still
# constructed here with embedding_function=None — queries would then fail at
# retrieval time rather than here.
try:
    # Ensure the vector database directory exists and is accessible
    if not os.path.exists(VECTOR_DB_DIRECTORY):
        print(f"Error: Vector database directory '{VECTOR_DB_DIRECTORY}' not found.")
        vectorstore = None
    else:
        vectorstore = Chroma(persist_directory=VECTOR_DB_DIRECTORY, embedding_function=embedding_model)
        print("Vector database loaded successfully.")
except Exception as e:
    print(f"Error loading vector database: {e}")
    vectorstore = None
# Create the retriever instance: a similarity search over the persisted
# Chroma collection returning the RAG_K most relevant chunks per query.
# Guarded so a missing vector store yields retriever = None instead of an
# AttributeError at import time.
if vectorstore:
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": RAG_K})
    # Fix: the success message previously printed unconditionally, even when
    # the vector store (and hence the retriever) failed to load.
    print(f"Retriever created with k={RAG_K}.")
else:
    retriever = None
    print("Retriever not created: vector database unavailable.")
# Load the LLM. `hf_hub_download` fetches (or reuses a cached copy of) the
# GGUF weights file and returns its local path, which llama.cpp then loads.
llm = None
try:
    print(f"Attempting to download and load LLM: {LLM_MODEL_BASENAME}")
    model_path = hf_hub_download(repo_id=LLM_MODEL_NAME_OR_PATH, filename=LLM_MODEL_BASENAME)
    llm = Llama(
        model_path=model_path,
        n_ctx=LLM_N_CTX,
        n_gpu_layers=LLM_N_GPU_LAYERS,
        n_batch=LLM_N_BATCH,
        n_threads=LLM_N_THREADS,
        verbose=False  # Set to False for deployment
    )
    print("LLM loaded successfully.")
except Exception as e:
    print(f"Error loading LLM: {e}")
    llm = None  # Ensure llm is None if loading fails
# --- RAG Logic (Adapted from notebook) ---
# System prompt establishing the assistant persona and grounding rules.
# NOTE(review): the stray trailing double quote after "facts." below looks
# like a typo carried over from the notebook — kept verbatim here since the
# string is part of the live prompt; confirm before changing it.
qna_system_message = """
You are a knowledgeable and reliable medical assistant.
Your role is to provide accurate, concise, and up-to-date medical information
based on trusted medical manuals and references.
User will have the context required by you to answer the question.
This context will always begin with the token ###Context
This context will contain specific portions of the relevant document to answer the question.
Always clarify if a condition requires a medical professional’s attention.
If you don't know the answer, just say that you don't know.
Don't try to make up an answer.
Answer only based on your knowledge from medical manuals. Do not hallucinate or make up facts."
"""
# Per-query user message template; {context} and {question} are filled in by
# generate_rag_response via str.replace (not str.format, so literal braces in
# retrieved text are safe).
qna_user_message_template = """
### Context
Here are some documents that are relevant to the question mentioned below:
{context}
### Question
{question}"""
def generate_rag_response(user_input):
    """Answer a user question via retrieval-augmented generation.

    Retrieves the most relevant document chunks for `user_input`, builds a
    prompt from the module-level templates, and completes it with the local
    llama.cpp model. Returns the generated answer as a stripped string, or a
    human-readable error message if any stage fails.
    """
    # Bail out early if either RAG component failed to initialize at startup.
    if llm is None or retriever is None:
        return "Error: RAG components not loaded correctly."

    # Fetch the top-k relevant chunks and join their text into one context
    # string (k is configured on the retriever itself).
    try:
        chunks = retriever.get_relevant_documents(query=user_input)
        context_for_query = ". ".join(chunk.page_content for chunk in chunks)
    except Exception as e:
        print(f"Error during document retrieval: {e}")
        return "Error: Could not retrieve relevant information."

    # Fill the template with str.replace so literal braces in retrieved text
    # cannot break formatting; context first, then the question.
    filled_message = qna_user_message_template.replace('{context}', context_for_query)
    filled_message = filled_message.replace('{question}', user_input)
    prompt = '\n'.join([qna_system_message, filled_message])

    # Run the completion and pull the text out of the first choice.
    try:
        completion = llm(
            prompt=prompt,
            max_tokens=RAG_MAX_TOKENS,
            temperature=RAG_TEMPERATURE,
            top_p=RAG_TOP_P,
            top_k=RAG_TOP_K,
            stop=['### Context'],  # Stop if the model starts echoing a new context block
        )
        first_choice = completion['choices'][0]
        return first_choice['text'].strip()
    except Exception as e:
        print(f"Error during LLM generation: {e}")
        return 'Sorry, I encountered an error generating the response.'
# --- Flask Routes ---
# Fix: neither view function was registered with Flask (no @app.route
# decorators), so every request would 404. Decorators added below.
@app.route('/')
def home():
    """Liveness check: confirms the backend process is up."""
    return "RAG Medical Assistant Backend is running!"


@app.route('/predict', methods=['POST'])
def predict():
    """Answer a medical question.

    Expects a JSON body with a 'query' field; returns {"response": <answer>}
    on success or {"error": <message>} with HTTP 400 on a malformed request.
    """
    if not request.json or 'query' not in request.json:
        return jsonify({"error": "Invalid request. Please provide a JSON object with a 'query' field."}), 400
    user_query = request.json['query']
    print(f"Received query: {user_query}")  # Log received queries
    response_text = generate_rag_response(user_query)
    return jsonify({"response": response_text})
# --- Run the App ---
if __name__ == '__main__':
    # In a production environment, use a production-ready server like Gunicorn.
    # Fixes: the old comment claimed app.run was "commented out" while it
    # actually ran with debug=True — the Werkzeug debugger must never be
    # enabled in deployment (it allows arbitrary code execution). Bind to all
    # interfaces and honor the platform-provided PORT, defaulting to 5000.
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 5000)))