kerzel's picture
another version from gemini
c836a9e
raw
history blame
9.32 kB
import gradio as gr
import numpy as np
import pandas as pd
from PIL import Image
import logging
import os # Import os for path checks
# Placeholder imports for clustering and utils.
# In a real scenario, these files (clustering.py, utils.py)
# would contain your actual implementation.
# Import the two project helper modules independently. Each one gets its own
# fallback so that a missing utils.py does not throw away a working
# clustering.py (and vice versa); the dummies only exist so the app can launch.
try:
    import clustering
except ImportError as e:
    logging.error(f"Error importing helper modules: {e}. Using dummy functions.")

    class DummyClustering:
        """Stand-in for ``clustering.py`` used when the real module is missing."""

        def get_centroids(self, *args, **kwargs):
            """Ignore all arguments and return two fixed centroid coordinates."""
            logging.warning("Using dummy get_centroids. Provide actual clustering.py.")
            # Fixed demo coordinates; they are not derived from the image.
            return [(100, 100), (200, 200)]

    clustering = DummyClustering()

try:
    import utils
except ImportError as e:
    logging.error(f"Error importing helper modules: {e}. Using dummy functions.")

    class DummyUtils:
        """Stand-in for ``utils.py`` used when the real module is missing."""

        def prepare_classifier_input(self, *args, **kwargs):
            """Ignore all arguments and return one all-zero 250x250x3 sample."""
            logging.warning("Using dummy prepare_classifier_input. Provide actual utils.py.")
            # Example shape only: it does not follow the requested window size
            # or the number of centroids.
            return np.zeros((1, 250, 250, 3))

        def show_boxes(self, image, damage_sites, save_image=False, image_path=None):
            """Return *image* unchanged, or a red 400x400 placeholder if it is None.

            ``damage_sites``, ``save_image`` and ``image_path`` are accepted for
            signature compatibility but ignored: no boxes are drawn and no file
            is written.
            """
            logging.warning("Using dummy show_boxes. Provide actual utils.py.")
            if image is None:
                return Image.new('RGB', (400, 400), color='red')
            return image

    utils = DummyUtils()
from tensorflow import keras
# --- Basic Setup ---
logging.getLogger().setLevel(logging.INFO)

# --- Constants and Model Loading ---
# Output files written by damage_classification().
IMAGE_PATH = "classified_damage_sites.png"
CSV_PATH = "classified_damage_sites.csv"

# Model files are looked up relative to the current working directory.
MODEL1_PATH = 'rwthmaterials_dp800_network1_inclusion.h5'
MODEL2_PATH = 'rwthmaterials_dp800_network2_damage.h5'


def _load_model(path, label):
    """Load one Keras model from *path*, returning None when it is unavailable.

    Args:
        path: File name of the saved model.
        label: Human-readable name used in log messages (e.g. "Model 1").

    Returns:
        The loaded model, or None if the file does not exist or loading fails.
        Failures are logged, never raised, so the app can still launch.
    """
    if not os.path.exists(path):
        logging.warning(f"{label} ({path}) not found. Classification results may be inaccurate.")
        return None
    try:
        loaded = keras.models.load_model(path)
    except Exception as e:
        # Handled per model so one broken file does not prevent loading the other.
        logging.error(f"Error loading {label} ({path}): {e}")
        return None
    logging.info(f"{label} loaded successfully.")
    return loaded


# Load models once at startup to improve performance.
model1 = _load_model(MODEL1_PATH, "Model 1")
model2 = _load_model(MODEL2_PATH, "Model 2")

# Maps a model 2 output column index to its damage label.
damage_classes = {3: "Martensite", 2: "Interface", 0: "Notch", 1: "Shadowing"}

# Window sizes handed to utils.prepare_classifier_input for each model.
model1_windowsize = [250, 250]
model2_windowsize = [100, 100]
# --- Core Processing Function (Your original logic) ---
def damage_classification(SEM_image, image_threshold, model1_threshold, model2_threshold):
    """
    Classify damage sites in an SEM image and export the results.

    Steps: find damage centroids by clustering, let model 1 flag inclusions,
    let model 2 label the sites that are still unclassified, draw the result
    boxes, and write the site table to CSV_PATH.

    Args:
        SEM_image: The uploaded image, passed through to the clustering and
            utils helpers.
        image_threshold: Threshold forwarded to clustering.get_centroids.
        model1_threshold: A site is an "Inclusion" when column 0 of the
            model 1 prediction exceeds this value.
        model2_threshold: A model 2 class is assigned when its prediction
            exceeds this value.

    Returns:
        Tuple (image_with_boxes, IMAGE_PATH, CSV_PATH), where image_with_boxes
        is whatever utils.show_boxes returns.

    Raises:
        gr.Error: If no image was given, a model is not loaded, or any
            threshold is missing.
    """
    if SEM_image is None:
        raise gr.Error("Please upload an SEM Image before running classification.")
    if model1 is None or model2 is None:
        raise gr.Error("Models not loaded. Please ensure model files are present and valid.")
    # A missing threshold would otherwise only fail inside the prediction
    # try-blocks below and silently leave every site "Not Classified".
    if any(t is None for t in (image_threshold, model1_threshold, model2_threshold)):
        raise gr.Error("Please provide a value for every threshold before running classification.")

    # Step 1: Clustering to find damage centroids
    all_centroids = clustering.get_centroids(
        SEM_image,
        image_threshold=image_threshold,
        fill_holes=True,
        filter_close_centroids=True,
    )
    if all_centroids is None:
        # Treat "nothing found" reported as None like an empty result.
        all_centroids = []
    damage_sites = {(c[0], c[1]): "Not Classified" for c in all_centroids}

    # Step 2: Model 1 to identify inclusions
    if len(all_centroids) > 0:
        try:
            images_model1 = utils.prepare_classifier_input(
                SEM_image, all_centroids, window_size=model1_windowsize
            )
            y1_pred = model1.predict(np.asarray(images_model1, dtype=float))
            for idx in np.where(y1_pred[:, 0] > model1_threshold)[0]:
                coord = all_centroids[idx]
                damage_sites[(coord[0], coord[1])] = "Inclusion"
        except Exception as e:
            # Best effort: keep going with whatever is classified so far,
            # but record the traceback for diagnosis.
            logging.exception("Error during Model 1 prediction: %s", e)

    # Step 3: Model 2 to classify remaining damage types
    centroids_model2 = [list(k) for k, v in damage_sites.items() if v == "Not Classified"]
    if centroids_model2:
        try:
            images_model2 = utils.prepare_classifier_input(
                SEM_image, centroids_model2, window_size=model2_windowsize
            )
            y2_pred = model2.predict(np.asarray(images_model2, dtype=float))
            # Row index = sample, column index = class, for every prediction
            # above the threshold.
            # NOTE(review): when several classes exceed the threshold for one
            # sample, the last write (highest class index) wins - confirm
            # this is intended rather than picking the most probable class.
            sample_indices, class_indices = np.nonzero(y2_pred > model2_threshold)
            for sample_idx, class_idx in zip(sample_indices, class_indices):
                label = damage_classes.get(int(class_idx), "Unknown")
                coord = centroids_model2[sample_idx]
                damage_sites[(coord[0], coord[1])] = label
        except Exception as e:
            logging.exception("Error during Model 2 prediction: %s", e)

    # Step 4: Draw boxes on image and save output image
    # NOTE(review): utils.show_boxes is assumed to return a PIL Image and to
    # write IMAGE_PATH when save_image=True - confirm against utils.py.
    image_with_boxes = utils.show_boxes(SEM_image, damage_sites, save_image=True, image_path=IMAGE_PATH)

    # Step 5: Export CSV file
    data = [[x, y, label] for (x, y), label in damage_sites.items()]
    df = pd.DataFrame(data, columns=["x", "y", "damage_type"])
    df.to_csv(CSV_PATH, index=False)

    # Log file paths to ensure they are correct
    logging.info(f"Generated Image Path: {IMAGE_PATH}")
    logging.info(f"Generated CSV Path: {CSV_PATH}")

    return image_with_boxes, IMAGE_PATH, CSV_PATH
# --- Gradio Interface Definition ---
with gr.Blocks() as app:
    # Page title and a one-line usage hint.
    gr.Markdown("# Damage Classification in Dual Phase Steels")
    gr.Markdown("Upload a Scanning Electron Microscope (SEM) image and set the thresholds to classify material damage.")

    with gr.Row():
        # Left column: the image upload, the three thresholds and the run button.
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload SEM Image")
            cluster_threshold_input = gr.Number(value=20, label="Image Binarization Threshold")
            model1_threshold_input = gr.Number(value=0.7, label="Inclusion Model Certainty (0-1)")
            model2_threshold_input = gr.Number(value=0.5, label="Damage Model Certainty (0-1)")
            classify_btn = gr.Button("Run Classification", variant="primary")

        # Right column: the annotated image and the two download buttons.
        with gr.Column(scale=2):
            output_image = gr.Image(label="Classified Image")
            # Both buttons start hidden with no file attached; the callback
            # below reveals them once a run has produced output files.
            download_image_btn = gr.DownloadButton(label="Download Image", value=None, visible=False)
            download_csv_btn = gr.DownloadButton(label="Download CSV", value=None, visible=False)

    def run_classification_and_update_ui(sem_image, cluster_thresh, m1_thresh, m2_thresh):
        """
        Run damage_classification and translate its outcome into UI updates.

        Returns a 3-tuple matching the click outputs: the image to display,
        then one update each for the image and CSV download buttons. On any
        exception the error is logged, shown as a Gradio warning, the image
        is cleared and both buttons are hidden.
        """
        try:
            result = damage_classification(sem_image, cluster_thresh, m1_thresh, m2_thresh)
            annotated, png_file, csv_file = result
            # Attach the generated files to the buttons and reveal them.
            ui_state = (
                annotated,
                gr.update(value=png_file, visible=True),
                gr.update(value=csv_file, visible=True),
            )
        except Exception as e:
            logging.error(f"Error during classification: {e}")
            gr.Warning(f"An error occurred: {e}")
            # No image and no downloads after a failed run.
            ui_state = (
                None,
                gr.update(visible=False),
                gr.update(visible=False),
            )
        return ui_state

    # Wire the button: four inputs in, image plus two button updates out.
    classify_btn.click(
        fn=run_classification_and_update_ui,
        inputs=[image_input, cluster_threshold_input, model1_threshold_input, model2_threshold_input],
        outputs=[output_image, download_image_btn, download_csv_btn],
    )
# Launch the Gradio app only when this file is run as a script, so that
# importing the module does not start a server.
if __name__ == "__main__":
    app.launch()