File size: 7,349 Bytes
e9a3ab1
afcc3c4
c4a671a
afcc3c4
 
 
 
 
 
 
 
 
 
 
 
 
 
e9a3ab1
 
 
afcc3c4
 
 
 
 
 
 
 
 
 
 
 
c4a671a
 
afcc3c4
 
 
e9a3ab1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
afcc3c4
 
 
c4a671a
e9a3ab1
c4a671a
 
 
e9a3ab1
 
c4a671a
afcc3c4
 
 
 
 
e9a3ab1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
afcc3c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e9a3ab1
afcc3c4
 
 
 
e9a3ab1
afcc3c4
 
 
 
e9a3ab1
 
 
 
afcc3c4
e9a3ab1
afcc3c4
e9a3ab1
 
 
afcc3c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e9a3ab1
 
afcc3c4
 
 
e9a3ab1
afcc3c4
 
e9a3ab1
afcc3c4
 
 
 
e9a3ab1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191

# --- Standard library ---
import os
import sys
import traceback
import uuid

# --- Third-party: web app, eye-disease model, RAG chatbot ---
import numpy as np
from dotenv import load_dotenv
from flask import Flask, jsonify, render_template, request
from keras.models import load_model
from keras.preprocessing import image
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_pinecone import PineconeVectorStore
from werkzeug.utils import secure_filename

# --- Local ---
from src.helpers import download_hugging_face_embeddings

app = Flask(__name__)
# Load variables from a .env file (if any) into os.environ before reading them.
load_dotenv()

# --- Configuration for Chatbot (Existing) ---
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN")
GITHUB_AI_ENDPOINT = "https://models.github.ai/inference"
GITHUB_AI_MODEL_NAME = "Phi-3-small-8k-instruct"
# os.environ only accepts str values: assigning None (key not configured) would
# raise TypeError right here with an unhelpful message. Only re-export the key
# when it is set; otherwise warn and let the chatbot initialisation report the
# failure through its own error handling.
if PINECONE_API_KEY:
    os.environ["PINECONE_API_KEY"] = PINECONE_API_KEY
else:
    print(
        "Warning: PINECONE_API_KEY is not set; chatbot initialization is likely to fail.",
        file=sys.stderr,
    )

# --- Configuration for Eye Disease Model (Newly Added) ---
# Relative path: resolved against the current working directory.
app.config['UPLOAD_FOLDER'] = 'static/uploads/'
# Absolute path next to this file so the model is found regardless of the
# working directory the app is started from.
MODEL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'abhi_model.h5')

# --- Load Models and Setup Chains (Existing and New) ---
# 1. Load Chatbot RAG Chain (Existing)
# Builds `rag_chain`: a retriever over the existing Pinecone index "medicalbot"
# (top-3 similarity search with Hugging Face embeddings) feeding a chat model
# served from the GitHub AI inference endpoint. Any failure aborts start-up,
# since the request handlers call `rag_chain` directly.
try:
    print("Initializing chatbot components...")
    embeddings = download_hugging_face_embeddings()
    index_name = "medicalbot"
    docsearch = PineconeVectorStore.from_existing_index(
        index_name=index_name,
        embedding=embeddings
    )
    retriever = docsearch.as_retriever(search_type="similarity", search_kwargs={"k": 3})
    # The GitHub model endpoint is OpenAI-compatible, so ChatOpenAI is pointed
    # at it with the GitHub token as the API key.
    llm = ChatOpenAI(
        temperature=0.4,
        max_tokens=500,
        model=GITHUB_AI_MODEL_NAME,
        openai_api_key=GITHUB_TOKEN,
        openai_api_base=GITHUB_AI_ENDPOINT
    )
    # "{context}" is filled with the retrieved documents by the stuff chain.
    system_prompt = (
        "You are a helpful medical assistant. Use the retrieved information to answer the question concisely and accurately. "
        "If you are asked to describe a medical condition, explain what it is, its common symptoms, and general causes in a way that is easy for a non-medical person to understand. "
        "Always include a disclaimer that the user should consult a qualified healthcare professional for a real diagnosis."
        "\n\n"
        "{context}"
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("human", "{input}"),
        ]
    )
    # Despite its name, this is the question-answering chain that stuffs the
    # retrieved documents into `prompt` and calls `llm`.
    Youtube_chain = create_stuff_documents_chain(llm, prompt)
    rag_chain = create_retrieval_chain(retriever, Youtube_chain)
    print("Chatbot components initialized successfully.")
except Exception as e:
    print(f"Error initializing chatbot components: {e}", file=sys.stderr)
    # str(e) alone is often empty or too terse to locate a credential, network
    # or index problem; print the full traceback before exiting.
    traceback.print_exc()
    sys.exit(1)


# 2. Load Eye Disease Prediction Model (Newly Added)
# Loaded once at import time into the module-level `eye_disease_model`, which
# predict_eye_disease() uses for every request. Start-up is aborted if loading
# fails, because the prediction route cannot work without it.
try:
    print(f"Attempting to load eye disease model from path: {MODEL_PATH}")
    eye_disease_model = load_model(MODEL_PATH)
    print("Successfully loaded eye disease model.")
except Exception as e:
    print(f"Error loading model from path: {MODEL_PATH}", file=sys.stderr)
    print(f"Details: {e}", file=sys.stderr)
    # Keep the full traceback: it distinguishes a missing file from an
    # incompatible/corrupt model far better than the message alone.
    traceback.print_exc()
    sys.exit(1)  # Exit if the model cannot be loaded


# --- Prediction Helper Function (Newly Added) ---
def predict_eye_disease(img_path):
    """Classify the eye image stored at ``img_path``.

    The image is loaded at 256x256, scaled by 1/255 and wrapped in a batch of
    one. The model receives two inputs: that batch and an all-zeros array with
    512 columns and one row per image.

    Args:
        img_path: Filesystem path of the image to classify.

    Returns:
        The label at the argmax of the model output, one of
        ``'cataract'``, ``'diabetic_retinopathy'``, ``'glaucoma'``, ``'normal'``.

    Raises:
        Exception: any error while loading or predicting is logged to stderr
        and re-raised unchanged so the calling route can handle it.
    """
    print(f"Starting prediction for image: {img_path}")
    try:
        pixels = image.img_to_array(image.load_img(img_path, target_size=(256, 256)))
        batch = np.expand_dims(pixels / 255.0, axis=0)

        # Log the batch shape actually fed to the model.
        print(f"Input image shape for model: {batch.shape}")

        # Second model input, fed with zeros (one row per image).
        # NOTE(review): presumably a placeholder for an RNN branch - confirm.
        auxiliary_input = np.zeros((batch.shape[0], 512))

        scores = eye_disease_model.predict([batch, auxiliary_input])
        print("Prediction successful.")

        labels = ('cataract', 'diabetic_retinopathy', 'glaucoma', 'normal')
        best_label = labels[np.argmax(scores)]

        print(f"Predicted class: {best_label}")
        return best_label
    except Exception as err:
        print(f"Error in predict_eye_disease function: {err}", file=sys.stderr)
        raise  # Let the route handler turn this into an HTTP error.


# --- Flask Routes ---

@app.route("/")
def index():
    """Render and return the ``chat.html`` template (the chat UI)."""
    page = "chat.html"
    return render_template(page)


@app.route("/get", methods=["POST"])
def chat():
    """Answer a text question from the chat UI using the RAG chain.

    Reads the question from the ``msg`` form field and passes it to
    ``rag_chain`` as ``{"input": msg}``.

    Returns:
        The chain's ``answer`` as a string (HTTP 200); a short plain-text
        message with HTTP 400 when ``msg`` is missing or blank; or a short
        plain-text message with HTTP 500 when the RAG chain raises.
    """
    msg = request.form.get("msg", "")
    # Don't spend an LLM call on an empty question.
    if not msg.strip():
        print("Error: Empty message received.")
        return "Please enter a message.", 400

    print(f"Received message: {msg}")
    try:
        response = rag_chain.invoke({"input": msg})
    except Exception as e:
        # Mirror /predict_disease: log the full traceback, return a short error
        # instead of Flask's default HTML 500 page.
        print(f"Uncaught exception in '/get' route: {e}", file=sys.stderr)
        traceback.print_exc()
        return "Sorry, something went wrong while generating a response.", 500

    answer = str(response["answer"])
    print(f"Bot response: {answer}")
    return answer


@app.route('/predict_disease', methods=['POST'])
def upload_file():
    """
    Handle an image upload: classify it with the eye-disease model, then ask
    the RAG chain for a description of the predicted condition.

    Expects the image in the ``file`` field of a multipart request. The upload
    is written to a uniquely named temporary file inside
    ``app.config['UPLOAD_FOLDER']`` and always deleted afterwards.

    Returns:
        JSON ``{'prediction': ..., 'description': ...}`` on success;
        JSON ``{'error': ...}`` with HTTP 400 when no file was provided;
        JSON ``{'error': ...}`` with HTTP 500 when saving, prediction or the
        LLM call fails.
    """
    if 'file' not in request.files:
        print("Error: No file part in the request.")
        return jsonify({'error': 'No file part in the request'}), 400

    file = request.files['file']
    # Covers both an empty and a missing (None) filename.
    if not file.filename:
        print("Error: No file selected for uploading.")
        return jsonify({'error': 'No file selected for uploading'}), 400

    upload_dir = app.config['UPLOAD_FOLDER']
    # A random prefix keeps concurrent uploads with the same name from
    # overwriting (or deleting) each other's file. It also guarantees the path
    # never collapses to the upload directory itself when secure_filename()
    # strips every character (e.g. "..").
    safe_name = secure_filename(file.filename)
    token = uuid.uuid4().hex
    unique_name = f"{token}_{safe_name}" if safe_name else token
    file_path = os.path.join(upload_dir, unique_name)

    try:
        # exist_ok avoids the check-then-create race between requests.
        os.makedirs(upload_dir, exist_ok=True)
        print(f"Saving uploaded file to: {file_path}")
        file.save(file_path)

        # 1. Get prediction from the vision model
        predicted_class_name = predict_eye_disease(file_path)
        display_name = predicted_class_name.replace('_', ' ').title()

        # 2. Create a prompt for the LLM based on the prediction
        llm_prompt = f"Please provide a description for the eye condition: {display_name}."

        # 3. Call the RAG chain (LLM) to get a dynamic description
        print(f"Sending prompt to LLM: {llm_prompt}")
        llm_response = rag_chain.invoke({"input": llm_prompt})
        description = llm_response['answer']

        # 4. Return JSON response to the frontend
        return jsonify({
            'prediction': display_name,
            'description': description
        })
    except Exception as e:
        print(f"Uncaught exception in '/predict_disease' route: {e}", file=sys.stderr)
        traceback.print_exc()
        return jsonify({'error': 'Failed to process image due to a server error.'}), 500
    finally:
        # Best-effort cleanup: a failure here must not replace the response
        # (or the error) produced above.
        if os.path.exists(file_path):
            print(f"Removing temporary file: {file_path}")
            try:
                os.remove(file_path)
            except OSError as cleanup_err:
                print(f"Warning: could not remove temporary file {file_path}: {cleanup_err}",
                      file=sys.stderr)


if __name__ == '__main__':
    # Direct-run entry point: Flask's built-in server on all interfaces, port 7860.
    app.run(port=7860, host='0.0.0.0')