Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import tensorflow as tf | |
| import os | |
| import requests | |
| import tempfile | |
| import matplotlib.pyplot as plt | |
| from tensorflow.keras.models import Sequential | |
| from tensorflow.keras.layers import Flatten, Dense, Reshape | |
| from tensorflow.keras.losses import SparseCategoricalCrossentropy | |
| from io import StringIO | |
| import datetime | |
| import tensorboard | |
| from tensorboard import program | |
# Report whether TensorFlow can see a GPU.
# By the time this runs TensorFlow has been imported and asked to enumerate
# GPUs, so the CUDA runtime has typically already read CUDA_VISIBLE_DEVICES;
# changing that environment variable here would not hide any device.
# tf.config.set_visible_devices is TensorFlow's own mechanism for that.
try:
    gpu = len(tf.config.list_physical_devices('GPU')) > 0
    if gpu:
        # TensorFlow places ops on a visible GPU automatically.
        st.write("GPU is available!")  # Inform the user
    else:
        st.write("GPU is not available. Using CPU.")
except RuntimeError as e:
    st.write(f"Error checking GPU: {e}")
    # Best effort: hide all GPUs so TensorFlow falls back to the CPU. This call
    # raises RuntimeError itself if devices were already initialized, in which
    # case there is nothing more that can be done here.
    try:
        tf.config.set_visible_devices([], 'GPU')
    except RuntimeError:
        pass
@st.cache_resource
def run_tensorboard(log_dir):
    """Start an in-process TensorBoard server for ``log_dir`` and return its URL.

    Streamlit re-executes this script on every widget interaction, so without
    caching each click on "Show TensorBoard" would launch another server.
    ``st.cache_resource`` keeps a single server per distinct ``log_dir`` for
    the lifetime of the Streamlit process and returns the same URL afterwards.

    Args:
        log_dir: Directory TensorBoard reads event files from.

    Returns:
        The URL string returned by ``TensorBoard.launch()``.
    """
    tb = program.TensorBoard()
    # The first argv entry is a placeholder for the program name.
    tb.configure(argv=[None, '--logdir', log_dir])
    # NOTE(review): the returned URL points at the machine running this app;
    # a browser viewing a remotely hosted deployment may not reach it.
    return tb.launch()
# Local file names of the TFRecord splits and the Hugging Face URLs they are
# fetched from (each URL is the dataset prefix followed by the file name).
_DATASET_URL_PREFIX = "https://huggingface.co/datasets/louiecerv/cardiac_images/resolve/main/"
TRAIN_FILE = "train_images.tfrecords"
VAL_FILE = "val_images.tfrecords"
TRAIN_URL = _DATASET_URL_PREFIX + TRAIN_FILE
VAL_URL = _DATASET_URL_PREFIX + VAL_FILE

# Downloads are kept in the platform's temporary directory so that later
# script reruns in the same environment can find them again (the OS may still
# clean this directory up eventually).
tmpdir = tempfile.gettempdir()
def download_file(url, local_filename, target_dir):
    """Download ``url`` to ``target_dir/local_filename`` unless it already exists.

    While streaming, a Streamlit progress bar is updated whenever the server
    sent a Content-Length header.

    Args:
        url: HTTP(S) URL to fetch.
        local_filename: File name to store the download under.
        target_dir: Directory for the file; created if missing.

    Returns:
        The path of the local file.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failures or timeouts.
    """
    os.makedirs(target_dir, exist_ok=True)
    filepath = os.path.join(target_dir, local_filename)
    if os.path.exists(filepath):
        st.write(f"File already exists: {filepath}")
        return filepath

    # Stream into a ".part" file and rename it into place only once the
    # download has completed, so an interrupted download never leaves a
    # truncated file that the existence check above would accept later.
    partial_path = filepath + ".part"
    # The timeout bounds the connect and each socket read, so a stalled
    # server cannot hang the app indefinitely.
    with requests.get(url, stream=True, timeout=60) as r:
        r.raise_for_status()
        # 0 when the server omits Content-Length; progress is then not shown.
        total_size = int(r.headers.get('content-length', 0))
        progress_bar = st.empty()  # Placeholder updated in place
        downloaded_size = 0
        try:
            with open(partial_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if not chunk:
                        continue
                    f.write(chunk)
                    downloaded_size += len(chunk)
                    if total_size > 0:
                        # Clamp: decoded bytes can exceed Content-Length for
                        # compressed transfers, and st.progress rejects
                        # integer values above 100.
                        progress_percent = min(100, downloaded_size * 100 // total_size)
                        progress_bar.progress(progress_percent, text=f"Downloading {local_filename}...")
        except BaseException:
            # Also covers interrupts: discard the incomplete file, then re-raise.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
    os.replace(partial_path, filepath)
    return filepath
# Fetch both TFRecord splits into tmpdir; download_file returns early for any
# file that is already present.
train_file_path = download_file(url=TRAIN_URL, local_filename=TRAIN_FILE, target_dir=tmpdir)
val_file_path = download_file(url=VAL_URL, local_filename=VAL_FILE, target_dir=tmpdir)
# Parsing schema for each serialized example in the TFRecord files: three
# scalar int64 features ('height', 'width', 'depth') and three scalar byte
# strings ('name', 'image_raw', 'label_raw').
# NOTE(review): read_and_decode ignores height/width/depth and assumes 256x256.
image_feature_description = {
    key: tf.io.FixedLenFeature([], dtype)
    for key, dtype in (
        ('height', tf.int64),
        ('width', tf.int64),
        ('depth', tf.int64),
        ('name', tf.string),
        ('image_raw', tf.string),
        ('label_raw', tf.string),
    )
}
def _parse_image_function(example_proto):
    """Parse one serialized example into a dict of feature tensors.

    The keys and dtypes are those of ``image_feature_description``.
    """
    features = tf.io.parse_single_example(
        serialized=example_proto,
        features=image_feature_description,
    )
    return features
def read_and_decode(example):
    """Turn a parsed TFRecord example into an ``(image, label)`` pair.

    Both byte buffers are decoded into flat vectors of 65536 (= 256 * 256)
    elements and reshaped to ``(256, 256, 1)``. Image bytes are read as int64,
    cast to float32 and multiplied by 1/1024. Label bytes are read as uint8
    and returned unscaled.
    """
    def _to_plane(raw_bytes, dtype):
        # Decode a raw buffer and reshape it into a single-channel 256x256 plane.
        flat = tf.io.decode_raw(raw_bytes, dtype)
        # Fix the static length so the reshape below has a fully known shape.
        flat.set_shape([65536])
        return tf.reshape(flat, [256, 256, 1])

    pixels = _to_plane(example['image_raw'], tf.int64)
    # NOTE(review): 1024 presumably reflects the source intensity range — confirm.
    image = tf.cast(pixels, tf.float32) * (1. / 1024)
    label = _to_plane(example['label_raw'], tf.uint8)
    return image, label
# Input pipelines: raw TFRecords -> parsed feature dicts -> (image, label)
# pairs -> batched datasets.
tf_autotune = tf.data.AUTOTUNE
BUFFER_SIZE = 10
BATCH_SIZE = 1

raw_training_dataset = tf.data.TFRecordDataset(train_file_path)
parsed_training_dataset = raw_training_dataset.map(_parse_image_function)
train = parsed_training_dataset.map(read_and_decode, num_parallel_calls=tf_autotune)

raw_val_dataset = tf.data.TFRecordDataset(val_file_path)
parsed_val_dataset = raw_val_dataset.map(_parse_image_function)
val = parsed_val_dataset.map(read_and_decode)

# Training data: cache the decoded examples, shuffle within a small window,
# batch, repeat indefinitely (epoch length comes from steps_per_epoch in
# fit) and prefetch in the background.
train_dataset = (
    train.cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .repeat()
    .prefetch(buffer_size=tf_autotune)
)
test_dataset = val.batch(BATCH_SIZE)

st.write(train_dataset)
def create_mask(pred_mask):
    """Convert a batch of per-pixel class scores into the first sample's mask.

    Takes the arg-max over the last (class) axis, appends a channel axis of
    size 1, and returns the first element of the batch.
    """
    class_ids = tf.argmax(pred_mask, axis=-1)
    return tf.expand_dims(class_ids, axis=-1)[0]
def display(display_list):
    """Render images side by side in a single matplotlib figure in Streamlit.

    Args:
        display_list: Up to three tensors/arrays, shown left to right under the
            titles 'Input Image', 'Label' and 'Prediction'. Each must hold
            256 * 256 elements; it is reshaped to (256, 256) for display.
    """
    title = ['Input Image', 'Label', 'Prediction']
    fig = plt.figure(figsize=(10, 10))
    try:
        for i, item in enumerate(display_list):
            ax = fig.add_subplot(1, len(display_list), i + 1)
            ax.set_title(title[i])
            ax.imshow(tf.reshape(item, [256, 256]), cmap='gray')
            ax.axis('off')
        st.pyplot(fig)
    finally:
        # pyplot keeps every figure created via plt.figure() open until it is
        # closed explicitly; this runs on every rerun and after every epoch,
        # so close it to avoid accumulating figures in memory.
        plt.close(fig)
def show_predictions(dataset=None, num=1):
    """Display input image, label and predicted mask side by side.

    Prediction uses the module-level ``model``.

    Args:
        dataset: Optional batched dataset of ``(image, label)`` pairs. When
            given, the first ``num`` batches are predicted and the first
            element of each batch is shown. When None, the module-level
            ``sample_image`` / ``sample_label`` pair is used instead.
        num: Number of batches to take from ``dataset``.
    """
    # Compare against None explicitly rather than relying on the truthiness of
    # a dataset object.
    if dataset is not None:
        for image, label in dataset.take(num):
            pred_mask = model.predict(image)
            display([image[0], label[0], create_mask(pred_mask)])
    else:
        # Add a leading batch axis: the model is called on batched input.
        prediction = create_mask(model.predict(sample_image[tf.newaxis, ...]))
        display([sample_image, sample_label, prediction])
class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that shows a sample prediction after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        # Keras passes a 0-based epoch index; the message reports it 1-based.
        show_predictions()
        st.write(f'\nSample Prediction after epoch {epoch + 1}\n')
st.title("Cardiac Images Dataset")

# Render the first two training examples. Each one is also stored in
# sample_image / sample_label, so after the loop they hold the last example
# shown; show_predictions() uses them when called without a dataset.
for image, label in train.take(2):
    sample_image = image
    sample_label = label
    display([sample_image, sample_label])

# Reset Keras' global state before the model is built.
tf.keras.backend.clear_session()
# Baseline model: a fully connected network mapping the whole 256x256 image to
# two class scores per pixel.
model = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(256, 256, 1)),  # Define input shape
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation='relu'),
    # No activation: these are per-pixel logits. The loss below uses
    # from_logits=True and applies the softmax over the last (class) axis
    # itself. A softmax here would normalize across all 256*256*2 outputs
    # instead of across each pixel's 2 classes, and would feed probabilities
    # to a loss that expects logits.
    tf.keras.layers.Dense(256*256*2),
    tf.keras.layers.Reshape((256, 256, 2)),
])

# Optimizer, loss and metrics used for training.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

# Capture the model summary as text.
model_summary = StringIO()
model.summary(print_fn=lambda line: model_summary.write(line + '\n'))
# The summary is a fixed-width table: st.text keeps its alignment, whereas
# st.markdown would reflow it.
st.text(model_summary.getvalue())
st.title("Model Architecture")

plot_filename = "model_plot.png"
try:
    # plot_model depends on pydot and Graphviz, which may be missing.
    tf.keras.utils.plot_model(model, to_file=plot_filename, show_shapes=True)
except Exception as e:
    st.error(f"An error occurred: {e}")
else:
    # Display the plot only if the image file was actually produced.
    if os.path.exists(plot_filename):
        st.image(plot_filename, caption="Neural Network Architecture", use_container_width=True)

# Show a prediction of the (untrained) model as an example.
show_predictions(test_dataset)
# TensorBoard logs for this run go to a fresh timestamped sub-directory of
# "logs"; histogram_freq=1 records weight histograms every epoch.
logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)

if st.button("Train Model"):
    # NOTE(review): Streamlit re-executes the whole script on each widget
    # interaction and `model` is rebuilt at module level every time, so the
    # weights trained here do not survive the next click — consider
    # st.session_state if they should persist.
    EPOCHS = 20
    # TFRecord datasets have unknown cardinality, so count records with one
    # pass. The generator avoids holding every parsed example in memory the
    # way list(...) would. Both splits are batched before any repeat, so one
    # pass is ceil(count / BATCH_SIZE) batches.
    num_train_examples = sum(1 for _ in parsed_training_dataset)
    num_val_examples = sum(1 for _ in parsed_val_dataset)
    STEPS_PER_EPOCH = (num_train_examples + BATCH_SIZE - 1) // BATCH_SIZE
    # test_dataset is not repeated, so asking for more validation batches
    # than it holds would exhaust it.
    VALIDATION_STEPS = (num_val_examples + BATCH_SIZE - 1) // BATCH_SIZE
    model_history = model.fit(train_dataset, epochs=EPOCHS,
                              steps_per_epoch=STEPS_PER_EPOCH,
                              validation_steps=VALIDATION_STEPS,
                              validation_data=test_dataset,
                              callbacks=[DisplayCallback(), tensorboard_callback])

    # Per-epoch training statistics recorded by Keras.
    loss = model_history.history['loss']
    val_loss = model_history.history['val_loss']
    accuracy = model_history.history['accuracy']
    val_accuracy = model_history.history['val_accuracy']
    epochs = range(len(loss))

    st.title('Training and Validation Loss')
    fig, ax = plt.subplots()
    ax.plot(epochs, loss, 'r', label='Training loss')
    ax.plot(epochs, val_loss, 'bo', label='Validation loss')
    ax.set_title('Training and Validation Loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss Value')
    ax.set_ylim([0, 1])
    ax.legend()
    st.pyplot(fig)
    # Release the figure; pyplot otherwise keeps it open across reruns.
    plt.close(fig)
if st.button("Evaluate Model"):
    # NOTE(review): this evaluates whatever `model` holds in the current script
    # run; after a Streamlit rerun that is a freshly built model, not the one
    # trained by an earlier "Train Model" click.
    # return_dict=True labels each value with its own metric name. Zipping a
    # positional result list against model.metrics_names may not line up
    # across Keras versions, and fails when evaluate returns a single scalar.
    results_dict = model.evaluate(test_dataset, verbose=0, return_dict=True)

    st.subheader("Model Evaluation Results")
    for metric, value in results_dict.items():
        st.write(f"**{metric.capitalize()}:** {value:.4f}")
if st.button("Show TensorBoard"):
    # Root directory that the training runs write their logs under.
    log_dir = "logs"
    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs(log_dir, exist_ok=True)
    url = run_tensorboard(log_dir)
    # NOTE(review): whether Streamlit's markdown renders an <iframe>, and
    # whether the browser can reach the server-side URL, depend on the
    # Streamlit version and deployment — verify.
    st.markdown(f"<iframe src='{url}' width='100%' height='800'></iframe>", unsafe_allow_html=True)
if st.button("CNN"):
    # Fully convolutional alternative to the dense baseline: downsample
    # 256 -> 16 with strided convolutions and pooling, then upsample back to
    # 256x256x2 per-pixel class probabilities with one transposed convolution.
    tf.keras.backend.clear_session()
    inputs = tf.keras.Input(shape=(256, 256, 1), name="InputLayer")
    x = tf.keras.layers.Conv2D(filters=100, kernel_size=5, strides=2, padding="same",
                               activation="relu", name="Conv1")(inputs)      # 128x128
    x = tf.keras.layers.MaxPool2D(pool_size=2, strides=2, padding="same")(x)  # 64x64
    x = tf.keras.layers.Conv2D(filters=200, kernel_size=5, strides=2, padding="same",
                               activation="relu", name="Conv2")(x)           # 32x32
    x = tf.keras.layers.MaxPool2D(pool_size=2, strides=2, padding="same")(x)  # 16x16
    x = tf.keras.layers.Conv2D(filters=300, kernel_size=3, strides=1, padding="same",
                               activation="relu", name="Conv3")(x)           # 16x16
    x = tf.keras.layers.Conv2D(filters=300, kernel_size=3, strides=1, padding="same",
                               activation="relu", name="Conv4")(x)           # 16x16
    x = tf.keras.layers.Conv2D(filters=2, kernel_size=1, strides=1, padding="same",
                               activation="relu", name="Conv5")(x)           # 16x16x2
    # Stride 16 restores 256x256; softmax runs over the 2 output channels.
    outputs = tf.keras.layers.Conv2DTranspose(filters=2, kernel_size=31, strides=16,
                                              padding="same", activation="softmax",
                                              name="UpSampling")(x)
    model = tf.keras.Model(inputs=inputs, outputs=outputs, name="CNN_Segmentation")
    # The output layer already applies softmax, so the loss keeps its default
    # from_logits=False.
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=['accuracy'],
    )

    # Capture and display the model summary (fixed-width, hence st.text).
    model_summary = StringIO()
    model.summary(print_fn=lambda line: model_summary.write(line + '\n'))
    st.text(model_summary.getvalue())

    # plot_model only writes an image file; show that file in the page. It
    # depends on pydot and Graphviz, which may be missing.
    cnn_plot_filename = "model.png"
    try:
        tf.keras.utils.plot_model(model, to_file=cnn_plot_filename, show_shapes=True)
    except Exception as e:
        st.error(f"An error occurred: {e}")
    else:
        if os.path.exists(cnn_plot_filename):
            st.image(cnn_plot_filename, caption="CNN Architecture", use_container_width=True)

    # Show a prediction of the untrained CNN as an example.
    show_predictions(test_dataset)

    # Log this CNN run to its own timestamped TensorBoard directory.
    logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
    tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)

    EPOCHS = 20
    # Count records with one pass (TFRecord cardinality is unknown) without
    # materializing them in a list. Batching precedes any repeat, so one pass
    # is ceil(count / BATCH_SIZE) batches; test_dataset is not repeated, so
    # validation must not request more batches than it holds.
    num_train_examples = sum(1 for _ in parsed_training_dataset)
    num_val_examples = sum(1 for _ in parsed_val_dataset)
    STEPS_PER_EPOCH = (num_train_examples + BATCH_SIZE - 1) // BATCH_SIZE
    VALIDATION_STEPS = (num_val_examples + BATCH_SIZE - 1) // BATCH_SIZE
    model_history = model.fit(train_dataset, epochs=EPOCHS,
                              steps_per_epoch=STEPS_PER_EPOCH,
                              validation_steps=VALIDATION_STEPS,
                              validation_data=test_dataset,
                              callbacks=[DisplayCallback(), tensorboard_callback])

    # Per-epoch training statistics recorded by Keras.
    loss = model_history.history['loss']
    val_loss = model_history.history['val_loss']
    accuracy = model_history.history['accuracy']
    val_accuracy = model_history.history['val_accuracy']
    epochs = range(len(loss))

    st.title('Training and Validation Loss')
    fig, ax = plt.subplots()
    ax.plot(epochs, loss, 'r', label='Training loss')
    ax.plot(epochs, val_loss, 'bo', label='Validation loss')
    ax.set_title('Training and Validation Loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss Value')
    ax.set_ylim([0, 1])
    ax.legend()
    st.pyplot(fig)
    # Release the figure; pyplot otherwise keeps it open across reruns.
    plt.close(fig)