# NOTE(review): the Hugging Face Spaces page header (status lines, file size,
# commit hashes) was captured into this file by the extraction; it is not part
# of the Python source and has been reduced to this comment so the file parses.
import os
import sys
# --- Imports for Eye Disease Model (Newly Added) ---
import numpy as np
from dotenv import load_dotenv
from flask import Flask, jsonify, render_template, request
from keras.models import load_model
from keras.preprocessing import image
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_pinecone import PineconeVectorStore
from werkzeug.utils import secure_filename
# --- Imports for Chatbot (Existing) ---
from src.helpers import download_hugging_face_embeddings
app = Flask(__name__)
load_dotenv()

# --- Configuration for Chatbot (Existing) ---
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN")
GITHUB_AI_ENDPOINT = "https://models.github.ai/inference"
GITHUB_AI_MODEL_NAME = "Phi-3-small-8k-instruct"

# Fail fast with a clear message. Assigning None into os.environ would raise
# an opaque "TypeError: str expected, not NoneType" at startup.
if not PINECONE_API_KEY:
    print("Error: PINECONE_API_KEY environment variable is not set.", file=sys.stderr)
    sys.exit(1)
os.environ["PINECONE_API_KEY"] = PINECONE_API_KEY

# --- Configuration for Eye Disease Model (Newly Added) ---
app.config['UPLOAD_FOLDER'] = 'static/uploads/'
# --- FIX: Use an absolute path (relative to this script, not the CWD) to
# ensure the model is found regardless of where the app is launched from ---
MODEL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'abhi_model.h5')
# --- Load Models and Setup Chains (Existing and New) ---
# 1. Load Chatbot RAG Chain (Existing)
try:
    print("Initializing chatbot components...")

    # Retrieval half: embeddings + an existing Pinecone index named "medicalbot",
    # returning the 3 most similar chunks per query.
    embeddings = download_hugging_face_embeddings()
    index_name = "medicalbot"
    docsearch = PineconeVectorStore.from_existing_index(
        index_name=index_name, embedding=embeddings
    )
    retriever = docsearch.as_retriever(
        search_type="similarity", search_kwargs={"k": 3}
    )

    # Generation half: a GitHub-hosted model reached through an
    # OpenAI-compatible endpoint.
    llm = ChatOpenAI(
        temperature=0.4,
        max_tokens=500,
        model=GITHUB_AI_MODEL_NAME,
        openai_api_key=GITHUB_TOKEN,
        openai_api_base=GITHUB_AI_ENDPOINT,
    )

    system_prompt = (
        "You are a helpful medical assistant. Use the retrieved information to answer the question concisely and accurately. "
        "If you are asked to describe a medical condition, explain what it is, its common symptoms, and general causes in a way that is easy for a non-medical person to understand. "
        "Always include a disclaimer that the user should consult a qualified healthcare professional for a real diagnosis."
        "\n\n"
        "{context}"
    )
    prompt = ChatPromptTemplate.from_messages(
        [("system", system_prompt), ("human", "{input}")]
    )

    # Stuff the retrieved documents into the prompt, then wrap the document
    # chain with retrieval so `rag_chain.invoke({"input": ...})` does both.
    answer_chain = create_stuff_documents_chain(llm, prompt)
    rag_chain = create_retrieval_chain(retriever, answer_chain)
    print("Chatbot components initialized successfully.")
except Exception as e:
    print(f"Error initializing chatbot components: {e}", file=sys.stderr)
    sys.exit(1)
# 2. Load Eye Disease Prediction Model (Newly Added)
try:
    print(f"Attempting to load eye disease model from path: {MODEL_PATH}")
    # Keras restores both architecture and weights from the .h5 file.
    eye_disease_model = load_model(MODEL_PATH)
    print("Successfully loaded eye disease model.")
except Exception as e:
    # A missing or corrupt model file makes the whole service useless,
    # so report the path and the underlying error, then abort startup.
    print(f"Error loading model from path: {MODEL_PATH}", file=sys.stderr)
    print(f"Details: {e}", file=sys.stderr)
    sys.exit(1) # Exit if the model cannot be loaded
# --- Prediction Helper Function (Newly Added) ---
def predict_eye_disease(img_path):
    """Run the eye-disease model on one image and return its class name.

    Returns one of 'cataract', 'diabetic_retinopathy', 'glaucoma', 'normal'.
    Re-raises any failure so the calling route can report it.
    """
    print(f"Starting prediction for image: {img_path}")
    try:
        # Preprocess: resize to the model's 256x256 input, scale pixel values
        # into [0, 1], and prepend a batch dimension.
        pil_img = image.load_img(img_path, target_size=(256, 256))
        batch = np.expand_dims(image.img_to_array(pil_img) / 255.0, axis=0)
        print(f"Input image shape for model: {batch.shape}")
        # The model takes a second input of width 512 per sample, fed all
        # zeros here — presumably a placeholder recurrent state; confirm
        # against the model's training code.
        rnn_state = np.zeros((batch.shape[0], 512))
        scores = eye_disease_model.predict([batch, rnn_state])
        print("Prediction successful.")
        labels = ['cataract', 'diabetic_retinopathy', 'glaucoma', 'normal']
        result = labels[np.argmax(scores)]
        print(f"Predicted class: {result}")
        return result
    except Exception as e:
        print(f"Error in predict_eye_disease function: {e}", file=sys.stderr)
        raise  # Re-raise so the route handler returns a 500 to the client.
# --- Flask Routes ---
@app.route("/")
def index():
    """Serve the main chat page."""
    return render_template('chat.html')
@app.route("/get", methods=["POST"])
def chat():
"""Handles text-based queries for the RAG chatbot."""
msg = request.form["msg"]
print(f"Received message: {msg}")
response = rag_chain.invoke({"input": msg})
print(f"Bot response: {response['answer']}")
return str(response["answer"])
@app.route('/predict_disease', methods=['POST'])
def upload_file():
    """
    Handles image uploads, predicts the disease, and then uses the LLM
    to generate a description for the predicted disease.

    Returns JSON {'prediction': ..., 'description': ...} on success, or
    {'error': ...} with a 400/500 status on failure. The uploaded file is
    always deleted afterwards.
    """
    if 'file' not in request.files:
        print("Error: No file part in the request.")
        return jsonify({'error': 'No file part in the request'}), 400
    file = request.files['file']
    if file.filename == '':
        print("Error: No file selected for uploading.")
        return jsonify({'error': 'No file selected for uploading'}), 400
    if file:
        filename = secure_filename(file.filename)
        # secure_filename() can reduce a hostile name (e.g. "../..") to an
        # empty string; without this guard file_path would be the upload
        # directory itself and file.save() would fail confusingly.
        if not filename:
            print("Error: Invalid filename.")
            return jsonify({'error': 'Invalid filename'}), 400
        # Create the uploads directory if it doesn't exist. exist_ok avoids
        # the check-then-create race between concurrent requests.
        os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
        file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        try:
            print(f"Saving uploaded file to: {file_path}")
            file.save(file_path)
            # 1. Get prediction from the vision model
            predicted_class_name = predict_eye_disease(file_path)
            # 2. Create a prompt for the LLM based on the prediction
            llm_prompt = f"Please provide a description for the eye condition: {predicted_class_name.replace('_', ' ').title()}."
            # 3. Call the RAG chain (LLM) to get a dynamic description
            print(f"Sending prompt to LLM: {llm_prompt}")
            llm_response = rag_chain.invoke({"input": llm_prompt})
            description = llm_response['answer']
            # 4. Return JSON response to the frontend
            return jsonify({
                'prediction': predicted_class_name.replace('_', ' ').title(),
                'description': description
            })
        except Exception as e:
            print(f"Uncaught exception in '/predict_disease' route: {e}", file=sys.stderr)
            return jsonify({'error': 'Failed to process image due to a server error.'}), 500
        finally:
            # Clean up the uploaded file (runs even when returning above).
            if os.path.exists(file_path):
                print(f"Removing temporary file: {file_path}")
                os.remove(file_path)
    print("Error: Unknown error occurred.")
    return jsonify({'error': 'Unknown error occurred'}), 500
if __name__ == '__main__':
    # Bind to all interfaces; 7860 is the conventional Hugging Face Spaces
    # port (the page header above suggests this app is deployed there).
    app.run(host='0.0.0.0', port=7860)