| | import gradio as gr |
| | import numpy as np |
| | import pandas as pd |
| | from PIL import Image |
| | import logging |
| | import os |
| |
|
| | |
| | |
| | |
try:
    import clustering
    import utils
except ImportError as e:
    logging.error(f"Error importing helper modules: {e}. Using dummy functions.")

    class DummyClustering:
        """Placeholder for the project's ``clustering`` module when it cannot be imported."""

        def get_centroids(self, *args, **kwargs):
            """Return two fixed placeholder centroids, ignoring every argument."""
            logging.warning("Using dummy get_centroids. Provide actual clustering.py.")
            return [(100, 100), (200, 200)]

    class DummyUtils:
        """Placeholder for the project's ``utils`` module when it cannot be imported."""

        def prepare_classifier_input(self, *args, **kwargs):
            """Return an all-zero RGB batch shaped like the real classifier input.

            Mirrors the call ``prepare_classifier_input(image, centroids, window_size=...)``
            made by ``damage_classification``: one window per centroid, each of
            ``window_size`` pixels. Without those arguments a single 250x250 window is
            returned, matching the previous placeholder.

            Returns:
                numpy array of zeros with shape (n_windows, size[0], size[1], 3).
            """
            logging.warning("Using dummy prepare_classifier_input. Provide actual utils.py.")
            # Centroids are passed positionally as the second argument by the caller.
            centroids = args[1] if len(args) > 1 else None
            window_size = kwargs.get("window_size") or (250, 250)
            n_windows = len(centroids) if centroids is not None else 1
            # Matching the batch size and window size keeps downstream model.predict
            # calls shape-consistent with the centroid list.
            return np.zeros((n_windows, window_size[0], window_size[1], 3))

        def show_boxes(self, image, damage_sites, save_image=False, image_path=None):
            """Return *image* unchanged, or a red 400x400 placeholder if it is None."""
            logging.warning("Using dummy show_boxes. Provide actual utils.py.")
            if image is None:
                return Image.new('RGB', (400, 400), color='red')
            return image

    clustering = DummyClustering()
    utils = DummyUtils()
| |
|
| |
|
| | from tensorflow import keras |
| |
|
| | |
# Let INFO-level messages (such as model-loading status) through the root logger.
logging.getLogger().setLevel(logging.INFO)

# Fixed output file names, relative to the current working directory, for the
# annotated image and the per-site CSV table.
IMAGE_PATH = "classified_damage_sites.png"
CSV_PATH = "classified_damage_sites.csv"

# Keras models; these stay None unless the model-loading code succeeds.
model1 = model2 = None
| |
|
def load_model_if_present(path, label):
    """Load a Keras model from *path*, or return None if it is missing or unloadable.

    Each model is loaded independently, so a corrupt or incompatible file for
    one network no longer prevents the other network from loading.

    Args:
        path: Filesystem path of the saved ``.h5`` model.
        label: Human-readable name used in log messages (e.g. "Model 1").

    Returns:
        The loaded model, or None if the file does not exist or loading raised.
    """
    if not os.path.exists(path):
        logging.warning(f"{label} ({path}) not found. Classification results may be inaccurate.")
        return None
    try:
        model = keras.models.load_model(path)
    except Exception:
        # Deliberately broad: any load failure degrades to "model unavailable"
        # (None), which the classification entry point rejects with a UI error.
        logging.exception(f"Error loading models: {label} ({path})")
        return None
    logging.info(f"{label} loaded successfully.")
    return model


model1 = load_model_if_present('rwthmaterials_dp800_network1_inclusion.h5', "Model 1")
model2 = load_model_if_present('rwthmaterials_dp800_network2_damage.h5', "Model 2")
| | |
| |
|
# Maps a column index of model 2's prediction output to its damage label.
damage_classes = {
    3: "Martensite",
    2: "Interface",
    0: "Notch",
    1: "Shadowing",
}

# Window sizes handed to utils.prepare_classifier_input for each model
# (presumably [height, width] in pixels; both are square, so order is moot).
model1_windowsize = [250, 250]
model2_windowsize = [100, 100]
| |
|
| | |
def _classify_inclusions(SEM_image, centroids, threshold, damage_sites):
    """Mark as "Inclusion" every centroid whose model-1 score exceeds *threshold*.

    Updates *damage_sites* in place. A prediction failure is logged with its
    traceback and leaves all sites unchanged.
    """
    try:
        windows = utils.prepare_classifier_input(SEM_image, centroids, window_size=model1_windowsize)
        predictions = model1.predict(np.asarray(windows, dtype=float))
        # Column 0 of model 1's output is used as the inclusion score.
        inclusion_scores = predictions[:, 0]
    except Exception:
        logging.exception("Error during Model 1 prediction")
        return

    if len(inclusion_scores) != len(centroids):
        logging.warning(
            f"Model 1 returned {len(inclusion_scores)} predictions for {len(centroids)} centroids."
        )
    for coord, score in zip(centroids, inclusion_scores):
        if score > threshold:
            damage_sites[(coord[0], coord[1])] = "Inclusion"


def _classify_damage(SEM_image, damage_sites, threshold):
    """Label every still-"Not Classified" site with its most probable model-2 class.

    A site is labelled only when its best class score exceeds *threshold*;
    otherwise it stays "Not Classified". Updates *damage_sites* in place. A
    prediction failure is logged with its traceback and leaves sites unchanged.
    """
    pending = [coord for coord, label in damage_sites.items() if label == "Not Classified"]
    if not pending:
        return

    try:
        windows = utils.prepare_classifier_input(
            SEM_image, [list(coord) for coord in pending], window_size=model2_windowsize
        )
        predictions = np.asarray(model2.predict(np.asarray(windows, dtype=float)))
        # Pick the single most probable class per site, so several classes above
        # the threshold cannot overwrite each other in index order.
        best_classes = predictions.argmax(axis=1)
        best_scores = predictions.max(axis=1)
    except Exception:
        logging.exception("Error during Model 2 prediction")
        return

    if len(best_classes) != len(pending):
        logging.warning(
            f"Model 2 returned {len(best_classes)} predictions for {len(pending)} centroids."
        )
    for coord, class_idx, score in zip(pending, best_classes, best_scores):
        if score > threshold:
            damage_sites[coord] = damage_classes.get(int(class_idx), "Unknown")


def _write_outputs(SEM_image, damage_sites):
    """Save the annotated image to IMAGE_PATH and the site table to CSV_PATH.

    Returns:
        The image returned by ``utils.show_boxes``.
    """
    image_with_boxes = utils.show_boxes(SEM_image, damage_sites, save_image=True, image_path=IMAGE_PATH)

    rows = [[x, y, label] for (x, y), label in damage_sites.items()]
    pd.DataFrame(rows, columns=["x", "y", "damage_type"]).to_csv(CSV_PATH, index=False)

    logging.info(f"Generated Image Path: {IMAGE_PATH}")
    logging.info(f"Generated CSV Path: {CSV_PATH}")
    return image_with_boxes


def damage_classification(SEM_image, image_threshold, model1_threshold, model2_threshold):
    """Classify the damage sites found in an SEM image.

    Candidate sites come from ``clustering.get_centroids``. Model 1 first flags
    inclusions. Model 2 then assigns a damage class to the remaining sites. The
    annotated image and a CSV of (x, y, damage_type) rows are written to
    IMAGE_PATH and CSV_PATH.

    Args:
        SEM_image: The uploaded image (a PIL image when called from the UI).
        image_threshold: Binarization threshold passed to the centroid search.
        model1_threshold: Minimum inclusion score (0-1) for the "Inclusion" label.
        model2_threshold: Minimum best-class score (0-1) for a damage label.

    Returns:
        Tuple of (annotated image, IMAGE_PATH, CSV_PATH).

    Raises:
        gr.Error: If no image is given, a model is not loaded, a threshold is
            missing, or a model threshold lies outside [0, 1].
    """
    if SEM_image is None:
        raise gr.Error("Please upload an SEM Image before running classification.")

    if model1 is None or model2 is None:
        raise gr.Error("Models not loaded. Please ensure model files are present and valid.")

    # A cleared gr.Number yields None, which would otherwise fail deep inside the
    # prediction code and silently leave every site "Not Classified".
    if image_threshold is None:
        raise gr.Error("Please provide an image binarization threshold.")
    for name, value in (
        ("Inclusion model certainty", model1_threshold),
        ("Damage model certainty", model2_threshold),
    ):
        if value is None or not 0 <= value <= 1:
            raise gr.Error(f"{name} must be a number between 0 and 1.")

    all_centroids = clustering.get_centroids(
        SEM_image,
        image_threshold=image_threshold,
        fill_holes=True,
        filter_close_centroids=True,
    )
    damage_sites = {(c[0], c[1]): "Not Classified" for c in all_centroids}

    # len() rather than truthiness: the centroid container may be a numpy array.
    if len(all_centroids) > 0:
        _classify_inclusions(SEM_image, all_centroids, model1_threshold, damage_sites)

    _classify_damage(SEM_image, damage_sites, model2_threshold)

    image_with_boxes = _write_outputs(SEM_image, damage_sites)
    return image_with_boxes, IMAGE_PATH, CSV_PATH
| |
|
| | |
with gr.Blocks() as app:
    gr.Markdown("# Damage Classification in Dual Phase Steels")
    gr.Markdown("Upload a Scanning Electron Microscope (SEM) image and set the thresholds to classify material damage.")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload SEM Image")
            cluster_threshold_input = gr.Number(value=20, label="Image Binarization Threshold")
            model1_threshold_input = gr.Number(value=0.7, label="Inclusion Model Certainty (0-1)")
            model2_threshold_input = gr.Number(value=0.5, label="Damage Model Certainty (0-1)")
            classify_btn = gr.Button("Run Classification", variant="primary")
        with gr.Column(scale=2):
            output_image = gr.Image(label="Classified Image")
            # Hidden until a successful run provides files to download.
            download_image_btn = gr.DownloadButton(label="Download Image", value=None, visible=False)
            download_csv_btn = gr.DownloadButton(label="Download CSV", value=None, visible=False)

    def run_classification_and_update_ui(sem_image, cluster_thresh, m1_thresh, m2_thresh):
        """Run damage_classification and translate its result into UI updates.

        Returns:
            Tuple of (classified image, image download update, CSV download update).
            On any error the image is cleared, the download buttons are hidden and
            the error message is shown as a Gradio warning.
        """
        try:
            classified_img, img_path, csv_path = damage_classification(sem_image, cluster_thresh, m1_thresh, m2_thresh)
            return (
                classified_img,
                gr.update(value=img_path, visible=True),
                gr.update(value=csv_path, visible=True),
            )
        except Exception as e:
            # Top-level UI boundary: log with traceback, surface a message, and
            # reset outputs so stale download links from a prior run disappear.
            logging.exception(f"Error during classification: {e}")
            gr.Warning(f"An error occurred: {e}")
            return (
                None,
                gr.update(visible=False),
                gr.update(visible=False),
            )

    classify_btn.click(
        fn=run_classification_and_update_ui,
        inputs=[
            image_input,
            cluster_threshold_input,
            model1_threshold_input,
            model2_threshold_input,
        ],
        outputs=[
            output_image,
            download_image_btn,
            download_csv_btn,
        ],
    )
| |
|
if __name__ == "__main__":
    # Serve the UI only when executed as a script, not when imported.
    app.launch()
| |
|