File size: 11,462 Bytes
a9db6c2
c8ce899
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import cv2
from ultralytics import YOLO
import sqlite3
import gradio as gr
import io
import base64
import pandas as pd
from scipy.spatial.distance import euclidean
from skimage.measure import regionprops

# Load the YOLO segmentation model once at import time.
# The name is bound to None up front so that a failed load leaves a defined
# (but unusable) model; prediction code then fails inside its own error
# handling instead of raising NameError.
yolo_model_glaucoma = None
try:
    yolo_model_glaucoma = YOLO('last.pt')
    print("YOLO model loaded successfully.")
except Exception as e:
    print(f"Error loading YOLO model: {e}")

def calculate_area(mask):
    """Return the number of mask pixels whose value is strictly above 0.5.

    The count is printed before being returned.
    """
    foreground = mask > 0.5
    pixel_count = np.sum(foreground)
    print(f"Calculated area: {pixel_count}")
    return pixel_count

def classify_ddls(rim_to_disc_ratio):
    """Map a rim-to-disc ratio onto a stage number from 0 to 6.

    Bands (lower bound inclusive, upper bound exclusive):
        >= 0.5       -> 0  (non-glaucomatous)
        [0.4, 0.5)   -> 1
        [0.3, 0.4)   -> 2
        [0.2, 0.3)   -> 3
        [0.1, 0.2)   -> 4
        (0.0, 0.1)   -> 5
        otherwise    -> 6  (ratio <= 0.0, or NaN)

    The chosen stage is printed before being returned.
    """
    # (inclusive lower bound, stage), checked from the highest band down so
    # the first match is also the correct upper-exclusive band.
    bands = ((0.5, 0), (0.4, 1), (0.3, 2), (0.2, 3), (0.1, 4))
    stage = next((band_stage for lower, band_stage in bands if rim_to_disc_ratio >= lower), None)
    if stage is None:
        stage = 5 if rim_to_disc_ratio > 0.0 else 6
    print(f"Classified DDLS stage: {stage}")
    return stage

def add_watermark(image, logo_path='image-logo.png', logo_width=100, margin=10):
    """Paste a logo into the bottom-right corner of a PIL image.

    The logo is scaled to ``logo_width`` pixels wide, keeping its aspect
    ratio. It is placed ``margin`` pixels from the right and bottom edges,
    and its own alpha channel is used as the paste mask.

    Args:
        image: PIL image to watermark (not modified).
        logo_path: path of the logo image file.
        logo_width: target logo width in pixels.
        margin: distance in pixels from the right and bottom edges.

    Returns:
        A new RGB image with the logo applied. If anything fails (for
        example the logo file is missing), the error is printed and the
        *unmodified* input ``image`` is returned.
    """
    try:
        # Decode the logo inside ``with`` so its file handle is released.
        with Image.open(logo_path) as logo_file:
            logo = logo_file.convert("RGBA")

        # Keep the aspect ratio, but never let the height collapse to 0
        # (resize() rejects zero-sized targets for very wide logos).
        scale = logo_width / float(logo.width)
        logo_height = max(1, int(logo.height * scale))
        logo = logo.resize((logo_width, logo_height), Image.LANCZOS)

        # convert() returns a new image, so the caller's image is untouched.
        canvas = image.convert("RGBA")
        position = (canvas.width - logo.width - margin, canvas.height - logo.height - margin)
        canvas.paste(logo, position, mask=logo)

        return canvas.convert("RGB")
    except Exception as e:
        print(f"Error adding watermark: {e}")
        return image

def fit_ellipse(mask, threshold=0.5):
    """Fit an ellipse to the largest external contour of a segmentation mask.

    Args:
        mask: 2-D array of mask values. Pixels strictly above ``threshold``
            are foreground; this matches the 0.5 cut-off used by
            ``calculate_area``.
        threshold: foreground cut-off.

    Returns:
        The ``((cx, cy), (width, height), angle)`` tuple from
        ``cv2.fitEllipse``. Returns ``None`` when the mask has no contour,
        or when its largest contour has fewer than the 5 points
        ``fitEllipse`` requires.
    """
    # Threshold explicitly: a bare astype(np.uint8) would truncate soft
    # float masks (e.g. 0.9 -> 0) and silently drop foreground pixels.
    binary = (np.asarray(mask) > threshold).astype(np.uint8)

    # OpenCV 3 returns (image, contours, hierarchy) and OpenCV 4 returns
    # (contours, hierarchy); [-2] selects the contours in both.
    contours = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if len(contours) == 0:
        return None

    largest_contour = max(contours, key=cv2.contourArea)
    if len(largest_contour) < 5:
        return None
    return cv2.fitEllipse(largest_contour)

def draw_ellipse(image, ellipse, color, thickness=2):
    """Outline ``ellipse`` on ``image`` in place and return that same image.

    ``ellipse`` is a cv2 rotated-rect tuple (e.g. from ``fit_ellipse``).
    When it is ``None``, nothing is drawn.
    """
    if ellipse is None:
        return image
    cv2.ellipse(image, ellipse, color, thickness)
    return image

def calculate_rim_to_disc_ratio(cup_ellipse, disk_ellipse, image, step_degrees=10):
    """Estimate the rim-to-disc ratio by casting rays from the cup centre.

    One ray is cast every ``step_degrees`` degrees. On each ray, the distance
    from the cup centre to the cup boundary and to the disc boundary is
    measured. The rim width on that ray is the part of the disc radius not
    covered by the cup, clamped at 0 when the cup boundary lies beyond the
    disc boundary. The result is ``mean(rim widths) / mean(disc radii)``.

    The rays are drawn on ``image`` in place for visualisation.

    Args:
        cup_ellipse: cv2 rotated-rect tuple for the cup, or None.
        disk_ellipse: cv2 rotated-rect tuple for the disc, or None.
        image: array passed to ``cv2.line`` (modified in place).
        step_degrees: angular spacing between rays, in degrees.

    Returns:
        The ratio as a float. Returns 0.0 when either ellipse is missing or
        when no ray produced a usable measurement.
        NOTE(review): ``classify_ddls`` maps 0.0 to stage 6, so "could not
        measure" and "no rim" look the same to callers.
    """
    if cup_ellipse is None or disk_ellipse is None:
        return 0.0

    # Rays start at the cup centre, truncated to integers so cv2.line accepts it.
    cup_center = (int(cup_ellipse[0][0]), int(cup_ellipse[0][1]))

    rim_lengths = []
    disc_lengths = []
    for angle in np.arange(0, 360, step_degrees):
        angle_rad = np.deg2rad(angle)
        direction = (np.cos(angle_rad), np.sin(angle_rad))

        disk_point = find_ellipse_intersection(disk_ellipse, cup_center, direction)
        if disk_point is None:
            continue
        cup_point = find_ellipse_intersection(cup_ellipse, cup_center, direction)
        if cup_point is None:
            continue

        disc_length = euclidean(cup_center, disk_point)
        cup_length = euclidean(cup_center, cup_point)
        # Both hits lie on the same ray from cup_center, so the rim is the
        # radial gap between them. A cup reaching past the disc edge leaves
        # no rim on this ray, hence the clamp (|a - b| would count it as rim).
        rim_lengths.append(max(disc_length - cup_length, 0.0))
        disc_lengths.append(disc_length)

        # Visualisation. The colours are RGB triples when image is an RGB
        # array (as built from a PIL image by the prediction code):
        # (0, 255, 0) is green, (255, 0, 0) is red.
        cv2.line(image, cup_center, disk_point, (0, 255, 0), 1)  # centre -> disc edge
        cv2.line(image, cup_center, cup_point, (255, 0, 0), 1)  # centre -> cup edge

    if not disc_lengths:
        return 0.0
    mean_disc = np.mean(disc_lengths)
    if mean_disc <= 0:
        # Degenerate geometry (every disc hit at the ray origin): avoid 0/0.
        return 0.0
    return float(np.mean(rim_lengths) / mean_disc)

def find_ellipse_intersection(ellipse, center, direction):
    """Return the pixel where a ray leaves an ellipse.

    The ray starts at ``center`` and travels along ``direction``.
    Substituting ``p = center + t * direction`` into the ellipse equation
    gives a quadratic in ``t``. The largest root is the exit point; when
    ``center`` is inside the ellipse it is the only positive root.

    Args:
        ellipse: ``((cx, cy), (width, height), angle_deg)`` as returned by
            ``cv2.fitEllipse``. ``width`` and ``height`` are full axis
            lengths, and ``angle_deg`` rotates the ``width`` axis.
        center: ``(x, y)`` origin of the ray.
        direction: ``(dx, dy)`` ray direction (need not be normalised).

    Returns:
        ``(x, y)`` rounded integer pixel coordinates of the exit point.
        Returns ``None`` if the ellipse has a non-positive axis, the direction
        is zero, or the ray never reaches the ellipse.
    """
    (cx, cy), (width, height), angle = ellipse
    semi_a = width / 2.0
    semi_b = height / 2.0
    if semi_a <= 0 or semi_b <= 0:
        return None

    theta = np.deg2rad(angle)
    cos_t, sin_t = np.cos(theta), np.sin(theta)

    def to_ellipse_frame(vx, vy):
        # Rotate by -angle so the ellipse axes align with the x/y axes.
        return vx * cos_t + vy * sin_t, -vx * sin_t + vy * cos_t

    # Ray origin relative to the ellipse centre, and ray direction, both in
    # the ellipse's own frame.
    ox, oy = to_ellipse_frame(center[0] - cx, center[1] - cy)
    dx, dy = to_ellipse_frame(direction[0], direction[1])

    # (ox + t*dx)^2 / a^2 + (oy + t*dy)^2 / b^2 = 1  ->  A t^2 + B t + C = 0
    a2, b2 = semi_a ** 2, semi_b ** 2
    quad_a = dx * dx / a2 + dy * dy / b2
    quad_b = 2.0 * (ox * dx / a2 + oy * dy / b2)
    quad_c = ox * ox / a2 + oy * oy / b2 - 1.0
    if quad_a == 0:
        return None

    discriminant = quad_b * quad_b - 4.0 * quad_a * quad_c
    if discriminant < 0:
        return None  # the ray's line misses the ellipse entirely

    # quad_a > 0, so this is the larger root (the exit point).
    t = (-quad_b + np.sqrt(discriminant)) / (2.0 * quad_a)
    if t < 0:
        return None  # the ellipse lies behind the ray origin

    x_intersect = int(round(center[0] + direction[0] * t))
    y_intersect = int(round(center[1] + direction[1] * t))
    return (x_intersect, y_intersect)

def _select_cup_and_disk_masks(masks_data, size, threshold):
    """Binarise YOLO masks at ``size`` and return ``(cup_mask, disk_mask)``.

    The mask with the most foreground pixels is taken as the disc and the
    second largest as the cup; any further masks are ignored. Either entry
    is ``None`` when fewer masks are available.
    NOTE(review): the class ids (``result.boxes.cls``) would identify cup and
    disc more reliably, but the model's label order is not visible here.
    """
    binary_masks = []
    for mask_data in masks_data:
        # float32: cv2.resize does not accept float16 (half-precision) masks.
        mask = np.asarray(mask_data.cpu().squeeze().numpy(), dtype=np.float32)
        resized = cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST)
        binary_masks.append((resized > threshold).astype(np.uint8))

    binary_masks.sort(key=lambda m: int(m.sum()), reverse=True)
    disk_mask = binary_masks[0] if len(binary_masks) > 0 else None
    cup_mask = binary_masks[1] if len(binary_masks) > 1 else None
    return cup_mask, disk_mask


def _load_overlay_font(font_size):
    """Load ``font.ttf`` at ``font_size``, falling back to PIL's default font."""
    try:
        return ImageFont.truetype("font.ttf", size=font_size)
    except IOError:
        font = ImageFont.load_default()
        print("Error: cannot open resource, using default font.")
        return font


def predict_and_visualize_glaucoma(image, mask_threshold=0.5):
    """Segment cup and disc, measure the rim-to-disc ratio and annotate the image.

    Args:
        image: input image as a numpy array (what ``Image.fromarray`` accepts),
            or ``None``.
        mask_threshold: cut-off used to binarise the predicted masks.

    Returns:
        ``(annotated_image, rim_to_disc_ratio, ddls_stage)``. On success the
        stage is the integer from ``classify_ddls``. Otherwise the image is a
        black placeholder, the ratio is 0.0, and the stage is a message string:
        "No image provided", "No detected regions", or the exception text.
    """
    try:
        if image is None:
            return np.zeros((1, 1, 3), dtype=np.uint8), 0.0, "No image provided"

        pil_image = Image.fromarray(image)
        orig_size = pil_image.size  # (width, height): the order cv2.resize expects
        results = yolo_model_glaucoma(pil_image)
        print(f"YOLO results: {results}")

        cup_mask, disk_mask = None, None
        if len(results) > 0:
            masks = getattr(results[0], 'masks', None)
            if masks is not None and len(masks) > 0:
                # NOTE(review): masks.data is at the model's inference
                # resolution; if that includes letterbox padding, a direct
                # resize can misalign slightly. Confirm for the ultralytics
                # version in use.
                cup_mask, disk_mask = _select_cup_and_disk_masks(masks.data, orig_size, mask_threshold)

        if cup_mask is None or disk_mask is None:
            print("No detected regions")
            return np.zeros_like(image), 0.0, "No detected regions"

        cup_ellipse = fit_ellipse(cup_mask)
        disk_ellipse = fit_ellipse(disk_mask)

        # combined_image comes from a PIL image, so a 3-channel array is RGB:
        # (0, 0, 255) draws blue and (255, 0, 0) draws red.
        combined_image = np.array(pil_image)
        combined_image = draw_ellipse(combined_image, cup_ellipse, (0, 0, 255), 2)  # cup outline
        combined_image = draw_ellipse(combined_image, disk_ellipse, (255, 0, 0), 2)  # disc outline

        # Also draws the measurement rays onto combined_image.
        rim_to_disc_ratio = calculate_rim_to_disc_ratio(cup_ellipse, disk_ellipse, combined_image)
        ddls_stage = classify_ddls(rim_to_disc_ratio)

        combined_pil_image = Image.fromarray(combined_image)
        draw = ImageDraw.Draw(combined_pil_image)
        font = _load_overlay_font(48)
        text = f"Rim to disc ratio: {rim_to_disc_ratio:.2f}\nDDLS stage: {ddls_stage}"
        draw.text((20, 40), text, fill=(255, 255, 255, 255), font=font)

        combined_pil_image = add_watermark(combined_pil_image)
        return np.array(combined_pil_image), rim_to_disc_ratio, ddls_stage
    except Exception as e:
        print("Error:", e)
        return np.zeros_like(image), 0.0, str(e)

def combined_prediction_glaucoma(image):
    """Run the glaucoma pipeline on ``image``, log the outcome, and return it.

    Returns the ``(segmented_image, rim_to_disc_ratio, ddls_stage)`` values
    produced by ``predict_and_visualize_glaucoma``, unchanged.
    """
    prediction = predict_and_visualize_glaucoma(image)
    segmented_image, rim_to_disc_ratio, ddls_stage = prediction

    print(f"Segmented image: {segmented_image.shape}")
    print(f"Rim to disc ratio: {rim_to_disc_ratio}, DDLS stage: {ddls_stage}")
    return segmented_image, rim_to_disc_ratio, ddls_stage

def save_prediction_to_db(image, rim_to_disc_ratio, ddls_stage, db_path='glaucoma_predictions.db'):
    """Store one prediction (metrics plus the PNG-encoded image) in SQLite.

    Creates the ``predictions`` table on first use.

    Args:
        image: object with a PIL-style ``save(fp, format=...)`` method
            (a PIL image in this app).
        rim_to_disc_ratio: numeric ratio, stored as REAL (numpy scalars are
            converted, since sqlite3 cannot bind e.g. ``np.float32``).
        ddls_stage: stage value stored in the ``ddls_stage`` column as given.
            SQLite keeps non-integer values, such as message strings, as TEXT.
        db_path: SQLite database file.

    Returns:
        ``(status_message, "")``. Errors are printed and reported in the
        status message rather than raised.
    """
    conn = None
    try:
        # Encode the image first so a failing image never touches the DB.
        image_io = io.BytesIO()
        image.save(image_io, format='PNG')
        image_binary = image_io.getvalue()

        ratio_value = None if rim_to_disc_ratio is None else float(rim_to_disc_ratio)
        stage_value = int(ddls_stage) if isinstance(ddls_stage, np.integer) else ddls_stage

        conn = sqlite3.connect(db_path)
        with conn:  # commits on success, rolls back on error (does NOT close)
            conn.execute('''
            CREATE TABLE IF NOT EXISTS predictions (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                rim_to_disc_ratio REAL,
                ddls_stage INTEGER,
                image BLOB
            )
            ''')
            conn.execute('''
            INSERT INTO predictions (rim_to_disc_ratio, ddls_stage, image)
            VALUES (?, ?, ?)
            ''', (ratio_value, stage_value, image_binary))
        return "Values successfully saved to database.", ""
    except Exception as e:
        print(f"Error saving to database: {e}")
        return f"Error saving to database: {e}", ""
    finally:
        # Close on every path; the original leaked the connection on errors.
        if conn is not None:
            conn.close()

def view_predictions_from_db(db_path='glaucoma_predictions.db'):
    """Load all stored predictions, ordered by id, as a DataFrame.

    Args:
        db_path: SQLite database file.

    Returns:
        A DataFrame with columns "ID", "Rim to Disc Ratio", "DDLS Stage" and
        "Image". The "Image" column holds ``data:image/png;base64,...`` URLs,
        or "" when no image was stored. On any error (including a missing
        ``predictions`` table) the error message string is returned instead.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        # Explicit columns: independent of the table's physical column order.
        predictions = conn.execute(
            "SELECT id, rim_to_disc_ratio, ddls_stage, image FROM predictions ORDER BY id"
        ).fetchall()

        df = pd.DataFrame(predictions, columns=["ID", "Rim to Disc Ratio", "DDLS Stage", "Image"])

        # Convert binary image data to a data URL; NULL blobs become "" instead
        # of crashing b64encode.
        df['Image'] = df['Image'].apply(
            lambda blob: "" if blob is None
            else "data:image/png;base64," + base64.b64encode(blob).decode("utf-8")
        )
        return df
    except Exception as e:
        print(f"Error viewing database: {e}")
        return f"Error viewing database: {e}"
    finally:
        if conn is not None:
            conn.close()

def display_predictions():
    """Render the stored predictions as an HTML table with image thumbnails.

    If loading fails, the error message from ``view_predictions_from_db`` is
    returned as-is.
    """
    predictions = view_predictions_from_db()

    def _thumbnail(src):
        return f'<img src="{src}" width="100">'

    if not isinstance(predictions, str):
        # escape=False so the <img> tags produced by the formatter render.
        return predictions.to_html(escape=False, formatters={"Image": _thumbnail})
    return predictions

def process_and_save_image(image):
    """Run the prediction for ``image`` and persist successful results.

    Returns:
        ``(segmented_image, rim_to_disc_ratio, ddls_stage, status, error)``
        for the five Gradio outputs. A successful run yields an integer stage
        (from ``classify_ddls``), and the result is saved to the database.
        Otherwise, e.g. no detection, missing input or an internal error, the
        stage is a message string. In that case nothing is saved and the
        message is also surfaced in ``error``.
    """
    segmented_image, rim_to_disc_ratio, ddls_stage = combined_prediction_glaucoma(image)

    if not isinstance(ddls_stage, (int, np.integer)):
        # Don't store placeholder rows (black image, ratio 0.0, text stage).
        # A non-image placeholder (fewer than 2 dims) is not displayable, so
        # send None to clear the output instead.
        display_image = segmented_image if getattr(segmented_image, "ndim", 0) >= 2 else None
        return display_image, rim_to_disc_ratio, ddls_stage, "Prediction failed; nothing saved.", str(ddls_stage)

    pil_segmented_image = Image.fromarray(segmented_image)
    status, error = save_prediction_to_db(pil_segmented_image, rim_to_disc_ratio, ddls_stage)
    return segmented_image, rim_to_disc_ratio, ddls_stage, status, error

# Gradio UI: one tab runs prediction and saves the result, the other lists
# stored predictions.
with gr.Blocks() as demo:
    with gr.Tabs():
        with gr.TabItem("Predict and Save"):
            with gr.Row():
                input_image = gr.Image(label="Upload Fundus Image")
                output_image = gr.Image(label="Segmented Image")
            with gr.Row():
                rim_to_disc_ratio_output = gr.Textbox(label="Rim to Disc Ratio")
                ddls_stage_output = gr.Textbox(label="DDLS Stage")
            with gr.Row():
                status_output = gr.Textbox(label="Status")
                error_output = gr.Textbox(label="Error")

            predict_and_save = gr.Button("Predict and Save")
            predict_and_save.click(
                process_and_save_image,
                inputs=[input_image],
                outputs=[
                    output_image, rim_to_disc_ratio_output, ddls_stage_output, status_output, error_output
                ]
            )

        with gr.TabItem("View Predictions"):
            view_button = gr.Button("View Predictions")
            predictions_output = gr.HTML()
            view_button.click(
                fn=display_predictions,
                inputs=None,
                outputs=predictions_output
            )

# Launch only when run as a script, so importing this module (tests, tooling,
# `gradio` reload mode) does not start a blocking server.
if __name__ == "__main__":
    demo.launch()