import logging
import os  # Used to check that model files exist before loading them.

import gradio as gr
import numpy as np
import pandas as pd
from PIL import Image
from tensorflow import keras

# --- Basic Setup ---
# Configure logging before the first log call (including the helper-import
# fallback below) so every message goes through the same handler at INFO.
logging.basicConfig(level=logging.INFO)
logging.getLogger().setLevel(logging.INFO)


# --- Fallback helpers ---
class DummyClustering:
    """Stand-in for the ``clustering`` module, used when it cannot be imported."""

    def get_centroids(self, *args, **kwargs):
        logging.warning("Using dummy get_centroids. Provide actual clustering.py.")
        # Fixed demo centroids. If clustering is critical for the app, raising
        # or returning an empty list may be preferable.
        return [(100, 100), (200, 200)]


class DummyUtils:
    """Stand-in for the ``utils`` module, used when it cannot be imported."""

    def prepare_classifier_input(self, image, centroids, window_size=(250, 250), **kwargs):
        logging.warning("Using dummy prepare_classifier_input. Provide actual utils.py.")
        # One blank RGB window per centroid, sized like the requested window,
        # so prediction rows line up with the centroid list callers index into.
        # NOTE(review): the real model input shape is not visible here — adjust.
        return np.zeros((len(centroids), window_size[0], window_size[1], 3))

    def show_boxes(self, image, damage_sites, save_image=False, image_path=None):
        logging.warning("Using dummy show_boxes. Provide actual utils.py.")
        # No boxes are drawn: return the input image, or a red placeholder
        # when no image was provided.
        result = image if image is not None else Image.new('RGB', (400, 400), color='red')
        # Honour save_image so the caller's download path refers to a file
        # produced by this run rather than a missing or stale one.
        if save_image and image_path and isinstance(result, Image.Image):
            result.save(image_path)
        return result


# Placeholder imports for clustering and utils. In a real deployment these
# files (clustering.py, utils.py) contain the actual implementation. Each one
# falls back to its dummy independently, so one missing module does not
# disable the other.
try:
    import clustering
except ImportError as e:
    logging.error(f"Error importing helper module 'clustering': {e}. Using dummy functions.")
    clustering = DummyClustering()

try:
    import utils
except ImportError as e:
    logging.error(f"Error importing helper module 'utils': {e}. Using dummy functions.")
    utils = DummyUtils()


# --- Constants and Model Loading ---
IMAGE_PATH = "classified_damage_sites.png"
CSV_PATH = "classified_damage_sites.csv"
# NOTE(review): these output paths are fixed and shared by every session, so
# concurrent runs may overwrite each other's files — consider per-request
# temporary files if the app serves several users at once.

MODEL1_PATH = 'rwthmaterials_dp800_network1_inclusion.h5'
MODEL2_PATH = 'rwthmaterials_dp800_network2_damage.h5'


def _load_model(path, name):
    """Load a Keras model from ``path``; return None if it is missing or fails to load.

    Each model goes through this helper separately, so a failure loading one
    model does not stop the other from being loaded. Failures are logged
    rather than raised so the app can still launch.
    """
    if not os.path.exists(path):
        logging.warning(f"{name} ({path}) not found. Classification results may be inaccurate.")
        return None
    try:
        model = keras.models.load_model(path)
    except Exception as e:  # Deliberate best effort: keep launching the app.
        logging.error(f"Error loading {name} ({path}): {e}")
        return None
    logging.info(f"{name} loaded successfully.")
    return model


# Load models once at startup to improve performance.
model1 = _load_model(MODEL1_PATH, "Model 1")
model2 = _load_model(MODEL2_PATH, "Model 2")

# Maps a Model 2 output column index to its damage label.
damage_classes = {3: "Martensite", 2: "Interface", 0: "Notch", 1: "Shadowing"}

model1_windowsize = [250, 250]
model2_windowsize = [100, 100]


# --- Core Processing Function ---
def damage_classification(SEM_image, image_threshold, model1_threshold, model2_threshold):
    """Classify damage sites in an SEM image and write the results to disk.

    Args:
        SEM_image: Image handed to the clustering/utils helpers (a PIL image
            when called from the UI below).
        image_threshold: Binarization threshold forwarded to
            ``clustering.get_centroids``.
        model1_threshold: A site is labeled "Inclusion" when column 0 of
            Model 1's output exceeds this value.
        model2_threshold: A remaining site gets Model 2's most confident class
            when that class's score exceeds this value.

    Returns:
        ``(image_with_boxes, IMAGE_PATH, CSV_PATH)``: the value returned by
        ``utils.show_boxes`` plus the paths of the saved image and CSV file.

    Raises:
        gr.Error: if no image or threshold is given, or a model is not loaded.
    """
    if SEM_image is None:
        raise gr.Error("Please upload an SEM Image before running classification.")
    # A cleared gr.Number field can arrive as None; fail clearly instead of
    # deep inside clustering or numpy comparisons.
    if any(t is None for t in (image_threshold, model1_threshold, model2_threshold)):
        raise gr.Error("Please provide values for all thresholds.")
    if model1 is None or model2 is None:
        raise gr.Error("Models not loaded. Please ensure model files are present and valid.")

    # Step 1: Clustering to find damage centroids.
    all_centroids = clustering.get_centroids(
        SEM_image,
        image_threshold=image_threshold,
        fill_holes=True,
        filter_close_centroids=True,
    )
    # Every site starts unclassified; later steps overwrite its label.
    damage_sites = {(c[0], c[1]): "Not Classified" for c in all_centroids}

    # Step 2: Model 1 identifies inclusions (column 0 of its output).
    if len(all_centroids) > 0:
        try:
            images_model1 = utils.prepare_classifier_input(
                SEM_image, all_centroids, window_size=model1_windowsize
            )
            y1_pred = model1.predict(np.asarray(images_model1, dtype=float))
            for idx in np.flatnonzero(y1_pred[:, 0] > model1_threshold):
                coord = all_centroids[idx]
                damage_sites[(coord[0], coord[1])] = "Inclusion"
        except Exception:
            # Best effort: affected sites stay "Not Classified"; keep traceback.
            logging.exception("Error during Model 1 prediction")

    # Step 3: Model 2 classifies the sites Model 1 did not mark as inclusions.
    centroids_model2 = [list(k) for k, v in damage_sites.items() if v == "Not Classified"]
    if centroids_model2:
        try:
            images_model2 = utils.prepare_classifier_input(
                SEM_image, centroids_model2, window_size=model2_windowsize
            )
            y2_pred = np.asarray(model2.predict(np.asarray(images_model2, dtype=float)))
            # Take each site's most confident class and accept it only if it
            # clears the threshold, so with a low threshold the best class wins
            # rather than whichever qualifying class happens to come last.
            best_class = y2_pred.argmax(axis=1)
            best_score = y2_pred.max(axis=1)
            for sample_idx in np.flatnonzero(best_score > model2_threshold):
                class_idx = int(best_class[sample_idx])
                coord = centroids_model2[sample_idx]
                damage_sites[(coord[0], coord[1])] = damage_classes.get(class_idx, "Unknown")
        except Exception:
            logging.exception("Error during Model 2 prediction")

    # Step 4: Draw boxes and save the annotated image to IMAGE_PATH.
    # utils.show_boxes is assumed to return a PIL Image object.
    image_with_boxes = utils.show_boxes(SEM_image, damage_sites, save_image=True, image_path=IMAGE_PATH)

    # Step 5: Export the classification table as CSV.
    df = pd.DataFrame(
        [[x, y, label] for (x, y), label in damage_sites.items()],
        columns=["x", "y", "damage_type"],
    )
    df.to_csv(CSV_PATH, index=False)

    logging.info(f"Generated Image Path: {IMAGE_PATH}")
    logging.info(f"Generated CSV Path: {CSV_PATH}")

    return image_with_boxes, IMAGE_PATH, CSV_PATH


# --- Gradio Interface Definition ---
with gr.Blocks() as app:
    gr.Markdown("# Damage Classification in Dual Phase Steels")
    gr.Markdown("Upload a Scanning Electron Microscope (SEM) image and set the thresholds to classify material damage.")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload SEM Image")
            cluster_threshold_input = gr.Number(value=20, label="Image Binarization Threshold")
            model1_threshold_input = gr.Number(value=0.7, label="Inclusion Model Certainty (0-1)")
            model2_threshold_input = gr.Number(value=0.5, label="Damage Model Certainty (0-1)")
            classify_btn = gr.Button("Run Classification", variant="primary")

        with gr.Column(scale=2):
            output_image = gr.Image(label="Classified Image")
            # Download buttons start hidden and are revealed after a successful run.
            download_image_btn = gr.DownloadButton(label="Download Image", value=None, visible=False)
            download_csv_btn = gr.DownloadButton(label="Download CSV", value=None, visible=False)

    def run_classification_and_update_ui(sem_image, cluster_thresh, m1_thresh, m2_thresh):
        """Run the classification and return updates for the output components.

        Returns a 3-tuple matching ``outputs`` of the click handler: the
        classified image, then updates for the image and CSV download buttons.
        Any error (including ``gr.Error`` from input validation) is logged,
        surfaced as a ``gr.Warning`` toast, and the outputs are cleared/hidden.
        """
        try:
            classified_img, img_path, csv_path = damage_classification(
                sem_image, cluster_thresh, m1_thresh, m2_thresh
            )
        except Exception as e:
            logging.exception("Error during classification")
            gr.Warning(f"An error occurred: {e}")
            # Clear the image and keep the download buttons hidden on error.
            return (
                None,
                gr.update(visible=False),
                gr.update(visible=False),
            )
        # gr.update changes component properties (value and visibility) in place.
        return (
            classified_img,
            gr.update(value=img_path, visible=True),
            gr.update(value=csv_path, visible=True),
        )

    # Connect the button's click event to the wrapper function.
    classify_btn.click(
        fn=run_classification_and_update_ui,
        inputs=[
            image_input,
            cluster_threshold_input,
            model1_threshold_input,
            model2_threshold_input,
        ],
        outputs=[
            output_image,
            download_image_btn,
            download_csv_btn,
        ],
    )

if __name__ == "__main__":
    app.launch()