# NOTE(review): "Spaces: / Paused / Paused" appears to be Hugging Face Spaces
# status-banner text accidentally captured when this file was pasted;
# commented out so the module parses as Python.
# --- Imports: stdlib / third-party / local ---
import os
import sys

import numpy as np
from dotenv import load_dotenv
from flask import Flask, jsonify, render_template, request
# Keras pieces are used by the eye-disease image model below.
from keras.models import load_model
from keras.preprocessing import image
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_pinecone import PineconeVectorStore
from werkzeug.utils import secure_filename

from src.helpers import download_hugging_face_embeddings

app = Flask(__name__)
load_dotenv()  # pull PINECONE_API_KEY / GITHUB_TOKEN from a local .env file

# --- Configuration for the RAG chatbot ---
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN")
GITHUB_AI_ENDPOINT = "https://models.github.ai/inference"
GITHUB_AI_MODEL_NAME = "Phi-3-small-8k-instruct"

# Fail fast with a clear message instead of the opaque TypeError that
# `os.environ[...] = None` would raise when the key is missing.
if PINECONE_API_KEY is None:
    print("PINECONE_API_KEY is not set; check your .env file.", file=sys.stderr)
    sys.exit(1)
os.environ["PINECONE_API_KEY"] = PINECONE_API_KEY

# --- Configuration for the eye-disease model ---
app.config['UPLOAD_FOLDER'] = 'static/uploads/'
# Absolute path so the model is found regardless of the working directory.
MODEL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'abhi_model.h5')
# --- Load models and set up chains ---
# 1. Chatbot RAG chain: Pinecone retriever + GitHub-hosted chat model.
try:
    print("Initializing chatbot components...")
    embeddings = download_hugging_face_embeddings()
    index_name = "medicalbot"
    docsearch = PineconeVectorStore.from_existing_index(
        index_name=index_name,
        embedding=embeddings,
    )
    # Top-3 similarity search over the medical index.
    retriever = docsearch.as_retriever(search_type="similarity", search_kwargs={"k": 3})
    llm = ChatOpenAI(
        temperature=0.4,
        max_tokens=500,
        model=GITHUB_AI_MODEL_NAME,
        openai_api_key=GITHUB_TOKEN,
        openai_api_base=GITHUB_AI_ENDPOINT,
    )
    system_prompt = (
        "You are a helpful medical assistant. Use the retrieved information to answer the question concisely and accurately. "
        "If you are asked to describe a medical condition, explain what it is, its common symptoms, and general causes in a way that is easy for a non-medical person to understand. "
        "Always include a disclaimer that the user should consult a qualified healthcare professional for a real diagnosis."
        "\n\n"
        "{context}"
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("human", "{input}"),
        ]
    )
    # "Stuff" chain: concatenates retrieved documents into {context}.
    question_answer_chain = create_stuff_documents_chain(llm, prompt)
    rag_chain = create_retrieval_chain(retriever, question_answer_chain)
    print("Chatbot components initialized successfully.")
except Exception as e:
    # The app is unusable without the chatbot, so abort startup.
    print(f"Error initializing chatbot components: {e}", file=sys.stderr)
    sys.exit(1)
# 2. Eye-disease prediction model (Keras .h5 file next to this script).
try:
    print(f"Attempting to load eye disease model from path: {MODEL_PATH}")
    eye_disease_model = load_model(MODEL_PATH)
    print("Successfully loaded eye disease model.")
except Exception as e:
    print(f"Error loading model from path: {MODEL_PATH}", file=sys.stderr)
    print(f"Details: {e}", file=sys.stderr)
    sys.exit(1)  # The image-prediction route cannot work without the model.
# --- Prediction helper ---
def predict_eye_disease(img_path):
    """Classify an eye image as one of four conditions.

    Args:
        img_path: Path to an image file readable by ``keras.preprocessing.image``.

    Returns:
        One of ``'cataract'``, ``'diabetic_retinopathy'``, ``'glaucoma'``,
        ``'normal'``.

    Raises:
        Re-raises any exception from loading/preprocessing/prediction so the
        route handler can translate it into an HTTP 500.
    """
    print(f"Starting prediction for image: {img_path}")
    try:
        # Model expects a 256x256 image scaled to [0, 1] with a batch axis.
        img = image.load_img(img_path, target_size=(256, 256))
        img_array = image.img_to_array(img) / 255.0  # normalize
        img_array = np.expand_dims(img_array, axis=0)
        print(f"Input image shape for model: {img_array.shape}")
        # The model takes a second, 512-wide input that is fed all zeros here.
        # NOTE(review): presumably an unused RNN/text branch of the network --
        # confirm against the model's training code.
        rnn_input = np.zeros((img_array.shape[0], 512))
        prediction = eye_disease_model.predict([img_array, rnn_input])
        print("Prediction successful.")
        # NOTE(review): label order must match the training generator's class
        # indices -- confirm against the training notebook.
        class_names = ['cataract', 'diabetic_retinopathy', 'glaucoma', 'normal']
        predicted_class = class_names[np.argmax(prediction)]
        print(f"Predicted class: {predicted_class}")
        return predicted_class
    except Exception as e:
        print(f"Error in predict_eye_disease function: {e}", file=sys.stderr)
        raise  # re-raise so the route handler returns a 500
# --- Flask routes ---
@app.route('/')
# NOTE(review): the route decorator was missing from this paste; '/' is
# inferred from this being the landing page that serves the chat UI -- confirm.
def index():
    """Serve the chat front-end page."""
    return render_template('chat.html')
@app.route('/get', methods=['GET', 'POST'])
# NOTE(review): the route decorator was missing from this paste; '/get' is the
# conventional endpoint for this chat template -- confirm against chat.html.
def chat():
    """Run a text query through the RAG chain and return the answer as text."""
    msg = request.form["msg"]
    print(f"Received message: {msg}")
    response = rag_chain.invoke({"input": msg})
    print(f"Bot response: {response['answer']}")
    return str(response["answer"])
@app.route('/predict_disease', methods=['POST'])
# NOTE(review): the decorator was missing from this paste; the path comes from
# this handler's own error log naming the '/predict_disease' route.
def upload_file():
    """Handle an eye-image upload.

    Saves the file, classifies it with the vision model, asks the RAG chain
    for a lay description of the predicted condition, and returns both as
    JSON. The uploaded file is deleted afterwards regardless of outcome.

    Returns:
        JSON with ``prediction`` and ``description`` on success; a JSON
        ``error`` payload with status 400 (bad request) or 500 (server error).
    """
    if 'file' not in request.files:
        print("Error: No file part in the request.")
        return jsonify({'error': 'No file part in the request'}), 400
    file = request.files['file']
    if file.filename == '':
        print("Error: No file selected for uploading.")
        return jsonify({'error': 'No file selected for uploading'}), 400
    if file:
        filename = secure_filename(file.filename)
        # secure_filename strips unsafe names (e.g. '../..') down to '';
        # reject those instead of trying to save to the directory itself.
        if not filename:
            print("Error: Invalid filename after sanitization.")
            return jsonify({'error': 'Invalid filename'}), 400
        # exist_ok avoids a race when two uploads arrive at the same time.
        os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
        file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        try:
            print(f"Saving uploaded file to: {file_path}")
            file.save(file_path)
            # 1. Vision model -> predicted class name.
            predicted_class_name = predict_eye_disease(file_path)
            # 2. Ask the RAG chain for a human-readable description.
            llm_prompt = f"Please provide a description for the eye condition: {predicted_class_name.replace('_', ' ').title()}."
            print(f"Sending prompt to LLM: {llm_prompt}")
            llm_response = rag_chain.invoke({"input": llm_prompt})
            description = llm_response['answer']
            # 3. JSON payload consumed by the frontend.
            return jsonify({
                'prediction': predicted_class_name.replace('_', ' ').title(),
                'description': description
            })
        except Exception as e:
            print(f"Uncaught exception in '/predict_disease' route: {e}", file=sys.stderr)
            return jsonify({'error': 'Failed to process image due to a server error.'}), 500
        finally:
            # Clean up the uploaded file on success and failure alike.
            if os.path.exists(file_path):
                print(f"Removing temporary file: {file_path}")
                os.remove(file_path)
    # Unreachable in practice (file is truthy after the checks above); kept
    # as a defensive fallback.
    print("Error: Unknown error occurred.")
    return jsonify({'error': 'Unknown error occurred'}), 500
if __name__ == '__main__':
    # Bind on all interfaces; 7860 is the conventional Hugging Face Spaces port.
    app.run(host='0.0.0.0', port=7860)