import gradio as gr
import numpy as np
import keras
from PIL import Image

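# Load the trained Keras model from disk
def load_model(model_path):
    # compile=False skips restoring the training loss and metrics. They are not
    # needed for inference, and if the model was compiled with custom functions
    # (such as a custom Dice loss) they would otherwise have to be supplied
    # through custom_objects.
    model = keras.models.load_model(model_path, compile=False)
    return model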


# Preprocess the input image into the (1, 128, 128, 2) batch the model expects
def preprocess_image(image):
    # Convert the input (a NumPy array from Gradio) to a single-channel grayscale image
    image = Image.fromarray(np.asarray(image)).convert('L')

    # Resize to the model's input resolution
    image = image.resize((128, 128))

    # Convert to a NumPy array and scale pixel values to [0, 1]
    image = np.array(image, dtype=np.float32) / 255.0

    # Add a channel dimension: (128, 128) -> (128, 128, 1)
    image = np.expand_dims(image, axis=-1)

    # Repeat the grayscale channel to build the 2-channel input: (128, 128, 2)
    image = np.repeat(image, 2, axis=-1)

    # Add a batch dimension: (1, 128, 128, 2)
    image = np.expand_dims(image, axis=0)

    return image

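# Run the model and binarize its output into a black-and-white mask
def predict_segmentation(model, image):
    # Take the first (and only) item of the batch
    segmentation_map = model.predict(image)[0]

    # PIL cannot build an image from an (H, W, 1) array, so drop a trailing
    # singleton channel if the model outputs one
    if segmentation_map.ndim == 3 and segmentation_map.shape[-1] == 1:
        segmentation_map = np.squeeze(segmentation_map, axis=-1)

    # Pixels whose predicted score exceeds the threshold are marked as tumor
    threshold = 0.5
    segmented_image = (segmentation_map > threshold).astype(np.uint8)

    # Scale {0, 1} to {0, 255} so the mask is visible
    segmented_image = Image.fromarray(segmented_image * 255)
    return segmented_image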


# Build and launch the Gradio interface
def gradio_interface(model_path):
    # Load the model once at startup
    model = load_model(model_path)

    # Gradio passes the uploaded image to the function and displays the returned mask
    demo = gr.Interface(
        fn=lambda image: predict_segmentation(model, preprocess_image(image)),
        inputs="image",
        outputs="image",
        title="Brain Tumor Segmentation",
        description="Upload an image of a brain scan, and the model will segment the brain tumor.",
    )

    demo.launch(share=True)


if __name__ == "__main__":
    gradio_interface("model_x81_dcs65.h5")