# OcuAI / app.py
# Uploaded by: abhikamuni
# Last commit: e9a3ab1 ("update")
import os
import sys
# --- Third-party imports (NumPy/Keras for the eye disease model; Flask, LangChain, Pinecone for the chatbot) ---
import numpy as np
from dotenv import load_dotenv
from flask import Flask, jsonify, render_template, request
from keras.models import load_model
from keras.preprocessing import image
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_pinecone import PineconeVectorStore
from werkzeug.utils import secure_filename
# --- Local imports (chatbot helpers) ---
from src.helpers import download_hugging_face_embeddings
app = Flask(__name__)
# Load variables from a local .env file (if any) before reading configuration.
load_dotenv()

# --- Configuration for Chatbot (Existing) ---
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN")
GITHUB_AI_ENDPOINT = "https://models.github.ai/inference"
GITHUB_AI_MODEL_NAME = "Phi-3-small-8k-instruct"

# The Pinecone vector store set up below needs this key (presumably read from
# the environment by the Pinecone client). Fail fast with a clear message
# instead of letting a missing key surface later as an obscure error.
if not PINECONE_API_KEY:
    print("Error: PINECONE_API_KEY is not set (environment or .env file).", file=sys.stderr)
    sys.exit(1)

# --- Configuration for Eye Disease Model (Newly Added) ---
# Resolve paths against this file's directory rather than the current working
# directory, so the app finds its model and upload folder wherever it is launched from.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
app.config['UPLOAD_FOLDER'] = os.path.join(BASE_DIR, 'static', 'uploads')
MODEL_PATH = os.path.join(BASE_DIR, 'abhi_model.h5')
# --- Load Models and Setup Chains (Existing and New) ---
# 1. Load Chatbot RAG Chain (Existing)
# Any failure here is fatal: the routes below cannot work without `rag_chain`.
try:
    print("Initializing chatbot components...")
    embeddings = download_hugging_face_embeddings()
    index_name = "medicalbot"
    # Connect to an already-existing Pinecone index (nothing is uploaded here).
    docsearch = PineconeVectorStore.from_existing_index(
        index_name=index_name,
        embedding=embeddings
    )
    # Return the 3 most similar chunks for each query.
    retriever = docsearch.as_retriever(search_type="similarity", search_kwargs={"k": 3})
    # ChatOpenAI is pointed at the GitHub Models endpoint, used here as an
    # OpenAI-compatible API authenticated with GITHUB_TOKEN.
    llm = ChatOpenAI(
        temperature=0.4,
        max_tokens=500,
        model=GITHUB_AI_MODEL_NAME,
        openai_api_key=GITHUB_TOKEN,
        openai_api_base=GITHUB_AI_ENDPOINT
    )
    # "{context}" is filled with the retrieved documents; "{input}" is the user's question.
    system_prompt = (
        "You are a helpful medical assistant. Use the retrieved information to answer the question concisely and accurately. "
        "If you are asked to describe a medical condition, explain what it is, its common symptoms, and general causes in a way that is easy for a non-medical person to understand. "
        "Always include a disclaimer that the user should consult a qualified healthcare professional for a real diagnosis."
        "\n\n"
        "{context}"
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("human", "{input}"),
        ]
    )
    # Combines the retrieved documents into the prompt and queries the LLM.
    question_answer_chain = create_stuff_documents_chain(llm, prompt)
    # Full pipeline: retrieve documents for the input, then answer from them.
    # Invoked with {"input": ...}; the answer is under the "answer" key.
    rag_chain = create_retrieval_chain(retriever, question_answer_chain)
    print("Chatbot components initialized successfully.")
except Exception as e:
    print(f"Error initializing chatbot components: {e}", file=sys.stderr)
    sys.exit(1)
# 2. Load Eye Disease Prediction Model (Newly Added)
# The prediction route is useless without the model, so a load failure
# terminates the process with exit status 1.
try:
    print(f"Attempting to load eye disease model from path: {MODEL_PATH}")
    eye_disease_model = load_model(MODEL_PATH)
except Exception as load_error:
    print(f"Error loading model from path: {MODEL_PATH}", file=sys.stderr)
    print(f"Details: {load_error}", file=sys.stderr)
    sys.exit(1)
else:
    print("Successfully loaded eye disease model.")
# --- Prediction Helper Function (Newly Added) ---
# Labels for the eye disease model's output units, in output order.
# NOTE(review): order must match the label encoding used at training time — confirm.
EYE_DISEASE_CLASSES = ('cataract', 'diabetic_retinopathy', 'glaucoma', 'normal')


def predict_eye_disease(img_path):
    """Predict the eye disease shown in the image at ``img_path``.

    The image is resized to 256x256, scaled to [0, 1], given a batch
    dimension and fed to ``eye_disease_model`` together with an all-zero
    auxiliary input of width 512.

    Args:
        img_path: Path of an image file readable by Keras' ``load_img``.

    Returns:
        The predicted label, one of ``EYE_DISEASE_CLASSES``.

    Raises:
        ValueError: if the model's highest-scoring index has no label.
        Any exception from loading the image or running the model is logged
        to stderr and re-raised for the route handler to deal with.
    """
    print(f"Starting prediction for image: {img_path}")
    try:
        img = image.load_img(img_path, target_size=(256, 256))
        img_array = image.img_to_array(img) / 255.0  # Normalize pixels to [0, 1]
        img_array = np.expand_dims(img_array, axis=0)  # Add a batch dimension of 1
        # Check the input shape before prediction
        print(f"Input image shape for model: {img_array.shape}")
        # The model takes a second input, which is fed zeros here.
        # NOTE(review): presumably a placeholder for an unused sequence/RNN branch — confirm.
        rnn_input = np.zeros((img_array.shape[0], 512))
        # verbose=0 keeps Keras' per-call progress bar out of the server logs.
        prediction = eye_disease_model.predict([img_array, rnn_input], verbose=0)
        print("Prediction successful.")
        # The batch holds exactly one image: pick the best class from its scores.
        predicted_class_index = int(np.argmax(prediction[0]))
        if predicted_class_index >= len(EYE_DISEASE_CLASSES):
            raise ValueError(
                f"Model predicted class index {predicted_class_index}, but only "
                f"{len(EYE_DISEASE_CLASSES)} class labels are defined."
            )
        predicted_class = EYE_DISEASE_CLASSES[predicted_class_index]
        print(f"Predicted class: {predicted_class}")
        return predicted_class
    except Exception as e:
        print(f"Error in predict_eye_disease function: {e}", file=sys.stderr)
        raise  # Re-raise the exception to be caught in the route handler
# --- Flask Routes ---
@app.route("/", methods=["GET"])
def index():
    """Serve the chat page (the app's landing page)."""
    template_name = 'chat.html'
    return render_template(template_name)
@app.route("/get", methods=["POST"])
def chat():
    """Answer a text query for the RAG chatbot.

    Expects a form field ``msg`` and returns the chain's answer as plain text.

    Returns:
        200 with the answer text on success,
        400 with a short message when ``msg`` is missing or blank,
        500 with a short message when the RAG chain fails.
    """
    msg = request.form.get("msg", "").strip()
    if not msg:
        # Don't spend an LLM call on an empty query.
        print("Error: Empty or missing 'msg' field.")
        return "Please enter a message.", 400
    print(f"Received message: {msg}")
    try:
        response = rag_chain.invoke({"input": msg})
    except Exception as e:
        print(f"Error in '/get' route: {e}", file=sys.stderr)
        return "Sorry, something went wrong while generating a response.", 500
    answer = str(response["answer"])
    print(f"Bot response: {answer}")
    return answer
def _display_name(class_name):
    """Turn a model label such as 'diabetic_retinopathy' into 'Diabetic Retinopathy'."""
    return class_name.replace('_', ' ').title()


def _unique_upload_path(original_filename):
    """Return a collision-free path inside UPLOAD_FOLDER for an uploaded file.

    A random prefix keeps concurrent uploads that share a client-side filename
    from overwriting (and then deleting) each other's files. It also guarantees
    a non-empty name when ``secure_filename`` strips everything (e.g. '..').
    Creates the upload folder if needed.
    """
    import uuid

    upload_dir = app.config['UPLOAD_FOLDER']
    os.makedirs(upload_dir, exist_ok=True)
    token = uuid.uuid4().hex
    safe_name = secure_filename(original_filename)
    unique_name = f"{token}_{safe_name}" if safe_name else token
    return os.path.join(upload_dir, unique_name)


@app.route('/predict_disease', methods=['POST'])
def upload_file():
    """Classify an uploaded eye image and describe the predicted condition.

    Expects a multipart form with a ``file`` field. The image is saved under a
    unique name in UPLOAD_FOLDER and classified with ``predict_eye_disease``.
    The predicted condition is then passed to the RAG chain for a description.
    The saved file is always removed afterwards.

    Returns:
        200 JSON ``{'prediction': ..., 'description': ...}`` on success,
        400 JSON ``{'error': ...}`` when no file was sent or selected,
        500 JSON ``{'error': ...}`` when saving, prediction or the LLM call fails.
    """
    if 'file' not in request.files:
        print("Error: No file part in the request.")
        return jsonify({'error': 'No file part in the request'}), 400
    file = request.files['file']
    if not file.filename:
        print("Error: No file selected for uploading.")
        return jsonify({'error': 'No file selected for uploading'}), 400

    file_path = _unique_upload_path(file.filename)
    try:
        print(f"Saving uploaded file to: {file_path}")
        file.save(file_path)
        # 1. Get prediction from the vision model
        predicted_class_name = predict_eye_disease(file_path)
        display_name = _display_name(predicted_class_name)
        # 2. Create a prompt for the LLM based on the prediction
        llm_prompt = f"Please provide a description for the eye condition: {display_name}."
        # 3. Call the RAG chain (LLM) to get a dynamic description
        print(f"Sending prompt to LLM: {llm_prompt}")
        llm_response = rag_chain.invoke({"input": llm_prompt})
        description = llm_response['answer']
        # 4. Return JSON response to the frontend
        return jsonify({
            'prediction': display_name,
            'description': description
        })
    except Exception as e:
        print(f"Uncaught exception in '/predict_disease' route: {e}", file=sys.stderr)
        return jsonify({'error': 'Failed to process image due to a server error.'}), 500
    finally:
        # Best-effort cleanup: a failure here must not replace the response above.
        if os.path.isfile(file_path):
            print(f"Removing temporary file: {file_path}")
            try:
                os.remove(file_path)
            except OSError as cleanup_error:
                print(f"Could not remove temporary file {file_path}: {cleanup_error}", file=sys.stderr)
if __name__ == '__main__':
    # Development server: listen on all network interfaces, port 7860.
    run_kwargs = {'host': '0.0.0.0', 'port': 7860}
    app.run(**run_kwargs)