"""Flask app combining a RAG medical chatbot with an eye-disease image classifier.

Routes:
    ``/``                -- serves the chat UI (``chat.html``).
    ``/get``             -- answers a text question through the RAG chain.
    ``/predict_disease`` -- classifies an uploaded eye image, then asks the RAG
                            chain to describe the predicted condition.

The chatbot components and the Keras model are both initialised at import
time; the process exits with status 1 if either fails or if a required
environment variable is missing.
"""

import os
import sys
import traceback
import uuid

import numpy as np
from dotenv import load_dotenv
from flask import Flask, jsonify, render_template, request
from keras.models import load_model
from keras.preprocessing import image
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_pinecone import PineconeVectorStore
from werkzeug.utils import secure_filename

from src.helpers import download_hugging_face_embeddings

app = Flask(__name__)
load_dotenv()

# Directory containing this file, so file paths do not depend on the CWD.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# --- Configuration for the chatbot ---
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN")
GITHUB_AI_ENDPOINT = "https://models.github.ai/inference"
GITHUB_AI_MODEL_NAME = "Phi-3-small-8k-instruct"

# Fail fast with a readable message. Values were read from os.environ, so
# they are already present there when set; never write a possibly-None value
# back into os.environ (it only accepts str and raises TypeError otherwise).
_missing_env = [
    name
    for name, value in (
        ("PINECONE_API_KEY", PINECONE_API_KEY),
        ("GITHUB_TOKEN", GITHUB_TOKEN),
    )
    if not value
]
if _missing_env:
    print(
        f"Missing required environment variable(s): {', '.join(_missing_env)}",
        file=sys.stderr,
    )
    sys.exit(1)

# --- Configuration for the eye-disease model ---
# Uploaded images are written here temporarily and deleted after prediction.
app.config['UPLOAD_FOLDER'] = os.path.join(BASE_DIR, 'static', 'uploads')
# Absolute path so the model is found regardless of the working directory.
MODEL_PATH = os.path.join(BASE_DIR, 'abhi_model.h5')

# Size the uploaded image is resized to before being fed to the model.
IMAGE_SIZE = (256, 256)
# Index i of the model's output maps to EYE_DISEASE_CLASSES[i].
EYE_DISEASE_CLASSES = ['cataract', 'diabetic_retinopathy', 'glaucoma', 'normal']
# Width of the model's second input, which is fed zeros (see predict_eye_disease).
RNN_INPUT_DIM = 512

# --- Load models and set up chains ---

# 1. Chatbot RAG chain: Pinecone retriever + GitHub-hosted chat model.
try:
    print("Initializing chatbot components...")
    embeddings = download_hugging_face_embeddings()
    index_name = "medicalbot"
    docsearch = PineconeVectorStore.from_existing_index(
        index_name=index_name,
        embedding=embeddings,
    )
    retriever = docsearch.as_retriever(search_type="similarity", search_kwargs={"k": 3})
    llm = ChatOpenAI(
        temperature=0.4,
        max_tokens=500,
        model=GITHUB_AI_MODEL_NAME,
        openai_api_key=GITHUB_TOKEN,
        openai_api_base=GITHUB_AI_ENDPOINT,
    )
    system_prompt = (
        "You are a helpful medical assistant. Use the retrieved information to answer the question concisely and accurately. "
        "If you are asked to describe a medical condition, explain what it is, its common symptoms, and general causes in a way that is easy for a non-medical person to understand. "
        "Always include a disclaimer that the user should consult a qualified healthcare professional for a real diagnosis."
        "\n\n"
        "{context}"
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("human", "{input}"),
        ]
    )
    question_answer_chain = create_stuff_documents_chain(llm, prompt)
    rag_chain = create_retrieval_chain(retriever, question_answer_chain)
    print("Chatbot components initialized successfully.")
except Exception as e:
    print(f"Error initializing chatbot components: {e}", file=sys.stderr)
    sys.exit(1)

# 2. Eye-disease prediction model.
try:
    print(f"Attempting to load eye disease model from path: {MODEL_PATH}")
    eye_disease_model = load_model(MODEL_PATH)
    print("Successfully loaded eye disease model.")
except Exception as e:
    print(f"Error loading model from path: {MODEL_PATH}", file=sys.stderr)
    print(f"Details: {e}", file=sys.stderr)
    sys.exit(1)  # The app is useless without the model.


# --- Helpers ---

def predict_eye_disease(img_path):
    """Classify an eye image and return the predicted class name.

    Args:
        img_path: Path to an image file readable by
            ``keras.preprocessing.image.load_img``.

    Returns:
        One of ``EYE_DISEASE_CLASSES`` (e.g. ``'glaucoma'``).

    Raises:
        Exception: any error from loading the image or running the model is
            logged to stderr and re-raised for the route handler.
    """
    print(f"Starting prediction for image: {img_path}")
    try:
        img = image.load_img(img_path, target_size=IMAGE_SIZE)
        img_array = image.img_to_array(img) / 255.0  # scale pixel values to [0, 1]
        img_array = np.expand_dims(img_array, axis=0)  # add a batch dimension of 1
        print(f"Input image shape for model: {img_array.shape}")

        # The model takes a second input of width RNN_INPUT_DIM which is fed
        # zeros here. NOTE(review): presumably an auxiliary branch not needed
        # at inference time -- confirm against the model definition.
        rnn_input = np.zeros((img_array.shape[0], RNN_INPUT_DIM))

        prediction = eye_disease_model.predict([img_array, rnn_input])
        print("Prediction successful.")

        predicted_class = EYE_DISEASE_CLASSES[int(np.argmax(prediction))]
        print(f"Predicted class: {predicted_class}")
        return predicted_class
    except Exception as e:
        print(f"Error in predict_eye_disease function: {e}", file=sys.stderr)
        raise  # Let the route handler turn this into an HTTP 500.


def _format_class_name(class_name):
    """Turn a class id like ``'diabetic_retinopathy'`` into ``'Diabetic Retinopathy'``."""
    return class_name.replace('_', ' ').title()


def _remove_quietly(path):
    """Delete a temporary regular file if present; log, never raise, on OS errors."""
    try:
        if os.path.isfile(path):
            print(f"Removing temporary file: {path}")
            os.remove(path)
    except OSError as e:
        print(f"Could not remove temporary file {path}: {e}", file=sys.stderr)


# --- Flask routes ---

@app.route("/")
def index():
    """Serve the chat UI."""
    return render_template('chat.html')


@app.route("/get", methods=["POST"])
def chat():
    """Answer a text question (form field ``msg``) with the RAG chain.

    Returns the answer as plain text. If ``msg`` is missing,
    ``request.form["msg"]`` raises Werkzeug's ``BadRequestKeyError``, which
    Flask turns into a 400 response.
    """
    msg = request.form["msg"]
    print(f"Received message: {msg}")
    response = rag_chain.invoke({"input": msg})
    answer = str(response["answer"])
    print(f"Bot response: {answer}")
    return answer


@app.route('/predict_disease', methods=['POST'])
def upload_file():
    """Classify an uploaded eye image and describe the predicted condition.

    Expects a multipart form with a ``file`` field. The image is saved under
    a unique temporary name, classified with ``predict_eye_disease``, and the
    RAG chain is asked to describe the result. The temporary file is always
    removed afterwards.

    Returns:
        200 JSON ``{'prediction': str, 'description': str}`` on success;
        400 JSON ``{'error': ...}`` for a missing/unusable file;
        500 JSON ``{'error': ...}`` if prediction or the LLM call fails.
    """
    if 'file' not in request.files:
        print("Error: No file part in the request.")
        return jsonify({'error': 'No file part in the request'}), 400

    file = request.files['file']
    if not file.filename:
        print("Error: No file selected for uploading.")
        return jsonify({'error': 'No file selected for uploading'}), 400

    # secure_filename strips path components and can return "" (e.g. for
    # ".."); an empty name would make file_path point at the directory itself.
    filename = secure_filename(file.filename)
    if not filename:
        print(f"Error: Unusable file name: {file.filename!r}")
        return jsonify({'error': 'Invalid file name'}), 400

    upload_dir = app.config['UPLOAD_FOLDER']
    os.makedirs(upload_dir, exist_ok=True)
    # Random prefix so concurrent uploads with the same name neither overwrite
    # nor delete each other's temporary file.
    file_path = os.path.join(upload_dir, f"{uuid.uuid4().hex}_{filename}")

    try:
        print(f"Saving uploaded file to: {file_path}")
        file.save(file_path)

        # 1. Classify the image with the vision model.
        predicted_class_name = predict_eye_disease(file_path)
        display_name = _format_class_name(predicted_class_name)

        # 2. Ask the RAG chain (LLM) for a description of the condition.
        llm_prompt = f"Please provide a description for the eye condition: {display_name}."
        print(f"Sending prompt to LLM: {llm_prompt}")
        llm_response = rag_chain.invoke({"input": llm_prompt})
        description = llm_response['answer']

        # 3. Return the result to the frontend.
        return jsonify({
            'prediction': display_name,
            'description': description,
        })
    except Exception as e:
        print(f"Uncaught exception in '/predict_disease' route: {e}", file=sys.stderr)
        traceback.print_exc()
        return jsonify({'error': 'Failed to process image due to a server error.'}), 500
    finally:
        _remove_quietly(file_path)


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)