File size: 7,334 Bytes
af3fe62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ad9cb6
af3fe62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import gradio as gr
import numpy as np
import pandas as pd

# our own helper tools
import clustering
import utils

import logging
logging.getLogger().setLevel(logging.INFO)

from tensorflow import keras

# Mapping of centroid (x, y) -> damage label.
# NOTE(review): this dict lives at module level and is therefore shared by
# every call and every user of this process; confirm whether anything still
# relies on it.
# (Default thresholds 20 / 0.7 / 0.5 are supplied by the UI number inputs.)
damage_sites = {}

# Side lengths of the windows cut around each centroid for model 1 and
# model 2 respectively (passed as `window_size` to
# utils.prepare_classifier_input).
model1_windowsize = [250, 250]

# --- Model 1: inclusions vs. all other damage ------------------------------
model1 = keras.models.load_model('rwthmaterials_dp800_network1_inclusion.h5')
model1.compile()  # called without arguments; the app only uses predict()

# Output column index of model 2 -> damage type name.
damage_classes = {3: "Martensite", 2: "Interface", 0: "Notch", 1: "Shadowing"}

model2_windowsize = [100, 100]

# --- Model 2: remaining damage types ---------------------------------------
model2 = keras.models.load_model('rwthmaterials_dp800_network2_damage.h5')
model2.compile()  # called without arguments; the app only uses predict()



##
## Function to do the actual damage classification
##
def damage_classification(SEM_image, image_threshold, model1_threshold, model2_threshold):
    """Detect and classify damage sites in an SEM image.

    Pipeline:
      1. ``clustering.get_centroids`` finds candidate damage sites.
      2. ``model1`` scores a window around every candidate; candidates whose
         first output column exceeds ``model1_threshold`` become 'Inclusion'.
      3. ``model2`` scores a (smaller) window around every remaining
         candidate; its highest-scoring class is assigned when that score
         exceeds ``model2_threshold``, otherwise the site stays
         'Not Classified'.
      4. The annotated image and a CSV list of all sites are written to disk.

    Args:
        SEM_image: image from the ``gr.Image`` input (a numpy array with
            Gradio's default settings -- confirm if the component changes).
        image_threshold: grayscale threshold for the clustering step.
        model1_threshold: score threshold for the inclusion model.
        model2_threshold: score threshold for the damage-type model.

    Returns:
        ``(image, image_path, csv_path)``: the image returned by
        ``utils.show_boxes``, the path of the saved image, and the path of
        the CSV file with columns ``x``, ``y``, ``damage_type``.

    Raises:
        gr.Error: if no image was supplied.
    """
    if SEM_image is None:
        # Clicking "Classify" before uploading an image passes None.
        raise gr.Error('Please upload an SEM image first.')

    # A fresh dict per call: a module-level dict would carry sites from
    # earlier images into this result and be shared between users.
    damage_sites = {}

    ##
    ## clustering
    ##
    logging.debug('---------------: clustering :=====================')
    all_centroids = clustering.get_centroids(SEM_image, image_threshold=image_threshold,
                                             fill_holes=True, filter_close_centroids=True)

    for centroid in all_centroids:
        damage_sites[(centroid[0], centroid[1])] = 'Not Classified'

    ##
    ## Inclusions vs the rest
    ##
    # Skip the model entirely when clustering found nothing: predicting on an
    # empty batch and indexing y1_pred[:, 0] would fail.
    if len(all_centroids) > 0:
        logging.debug('---------------: prepare model 1 :=====================')
        images_model1 = utils.prepare_classifier_input(SEM_image, all_centroids,
                                                       window_size=model1_windowsize)

        logging.debug('---------------: run model 1 :=====================')
        y1_pred = model1.predict(np.asarray(images_model1, float))

        logging.debug('---------------: model 1 threshold / update dict :=====================')
        # Column 0 of model 1's output is treated as the inclusion score.
        # NOTE(review): assumes prepare_classifier_input returns exactly one
        # window per centroid, in the same order as all_centroids.
        is_inclusion = y1_pred[:, 0] > model1_threshold
        for centroid, inclusion in zip(all_centroids, is_inclusion):
            if inclusion:
                damage_sites[(centroid[0], centroid[1])] = 'Inclusion'
    logging.debug('Damage sites after model 1: %s', damage_sites)

    ##
    ## Martensite cracking, etc
    ##
    logging.debug('---------------: prepare model 2 :=====================')
    centroids_model2 = [[x, y] for (x, y), label in damage_sites.items()
                        if label == 'Not Classified']
    logging.debug('Centroids model 2: %s', centroids_model2)

    if centroids_model2:
        images_model2 = utils.prepare_classifier_input(SEM_image, centroids_model2,
                                                       window_size=model2_windowsize)
        logging.debug('Images model 2: %s', images_model2)

        logging.debug('---------------: run model 2 :=====================')
        y2_pred = model2.predict(np.asarray(images_model2, float))

        # Assign only the single best class per site. With a low threshold
        # several classes can pass it, and choosing among them by column
        # index would be arbitrary.
        best_class = np.argmax(y2_pred, axis=1)
        best_score = np.max(y2_pred, axis=1)
        for (x, y), cls, score in zip(centroids_model2, best_class, best_score):
            if score > model2_threshold:
                damage_sites[(x, y)] = damage_classes[int(cls)]

    ##
    ## show the damage sites on the image
    ##
    logging.debug("-----------------: final damage sites :=================")
    logging.debug('%s', damage_sites)

    # NOTE(review): fixed output file names; concurrent requests would
    # overwrite each other's files if the event's concurrency limit is raised.
    image_path = 'classified_damage_sites.png'
    image = utils.show_boxes(SEM_image, damage_sites,
                             save_image=True,
                             image_path=image_path)

    ##
    ## export data
    ##
    csv_path = 'classified_damage_sites.csv'
    df = pd.DataFrame(
        [[x, y, label] for (x, y), label in damage_sites.items()],
        columns=['x', 'y', 'damage_type'],
    )
    df.to_csv(csv_path)

    return image, image_path, csv_path

## ---------------------------------------------------------------------------------------------------------------
## main app interface
## -----------------------------------------------------------------------------------------------------------------
# Layout: description and citations, the SEM image input with the three
# threshold controls, a "Classify" button, then the annotated output image and
# download buttons for the saved image and the CSV list of damage sites.
with gr.Blocks() as app:
    gr.Markdown('# Damage Classification in Dual Phase Steels')
    gr.Markdown('This app classifies damage types in dual phase steels. Two models are used. The first model is used to identify inclusions in the steel. The second model is used to identify the remaining damage types: Martensite cracking, Interface Decohesion, Notch effect and Shadows.')

    gr.Markdown('If you use this app, kindly cite the following papers:')
    gr.Markdown('Kusche, C., Reclik, T., Freund, M., Al-Samman, T., Kerzel, U., & Korte-Kerzel, S. (2019). Large-area, high-resolution characterisation and classification of damage mechanisms in dual-phase steel using deep learning. PLOS ONE, 14(5), e0216493. [Link](https://doi.org/10.1371/journal.pone.0216493)')
    gr.Markdown('Medghalchi, S., Kusche, C. F., Karimi, E., Kerzel, U., & Korte-Kerzel, S. (2020). Damage analysis in dual-phase steel using deep learning: transfer from uniaxial to biaxial straining conditions by image data augmentation. JOM, 72, 4420-4430. [Link](https://link.springer.com/article/10.1007/s11837-020-04404-0)')

    image_input = gr.Image()
    with gr.Row():
        # Defaults: 20 (grayscale clustering cut-off), 0.7 (model 1), 0.5 (model 2).
        cluster_threshold_input = gr.Number(label='Cluster Threshold', value=20,
                                            info='Grayscale value at which a pixel is attributed to a potential damage site')
        model1_threshold_input = gr.Number(label='Model 1 Threshold', value=0.7, info='Threshold for the model identifying inclusions')
        model2_threshold_input = gr.Number(label='Model 2 Threshold', value=0.5, info='Threshold for the model identifying the remaining damage types')

    button = gr.Button("Classify")

    output_image = gr.Image()
    with gr.Row():
        download_image = gr.DownloadButton(label='Download Image')
        download_csv = gr.DownloadButton(label='Download Damage List')

    # damage_classification returns (annotated image, image path, CSV path);
    # the two paths become the files offered by the download buttons.
    button.click(damage_classification,
                 inputs=[image_input, cluster_threshold_input, model1_threshold_input, model2_threshold_input],
                 outputs=[output_image, download_image, download_csv])


# Alternative minimal interface (no title or citations), kept for reference:
# app = gr.Interface(damage_classification,
#                    inputs=[gr.Image(),
#                             gr.Number(label='Cluster Threshold', value = 20, info='Grayscale value at which a pixel is attributed to a potential damage site'),
#                             gr.Number(label='Model 1 Threshold', value = 0.7, info='Threshold for the model identifying inclusions'),
#                             gr.Number(label='Model 2 Threshold', value = 0.5, info='Threshold for the model identifying the remaining damage types')
#                    ],
#                    outputs=[gr.Image(),
#                             gr.DownloadButton(label='Download Image'),
#                             gr.DownloadButton(label='Download Damage List')])

# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    app.launch()