File size: 2,231 Bytes
a783ca8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 | import argparse
import tensorflow as tf
from tensorflow.keras.preprocessing.image import load_img, img_to_array
import os
import numpy as np
import uuid
def _load_images(image_dir, target_size):
    """Load every regular file in *image_dir* as a grayscale array scaled to [0, 1].

    Files are visited in sorted name order because ``os.listdir`` returns entries
    in arbitrary order; sorting makes the sample order reproducible across runs.

    Args:
        image_dir: Directory to scan (non-recursive; sub-directories are skipped).
        target_size: ``(height, width)`` passed to ``load_img`` for resizing.

    Returns:
        ``np.array`` stacking the per-image arrays (presumably shaped
        ``(N, height, width, 1)`` under Keras' default channels-last format --
        confirm against the model's input). Has ``len == 0`` when the directory
        contains no regular files.
    """
    images = []
    for name in sorted(os.listdir(image_dir)):
        full_path = os.path.join(image_dir, name)
        if not os.path.isfile(full_path):
            continue
        image = load_img(full_path, target_size=target_size, color_mode='grayscale')
        images.append(img_to_array(image) / 255.0)
    return np.array(images)


def _compute_gradients(model, data, labels):
    """Return the gradients of categorical cross-entropy w.r.t. ``model.trainable_variables``.

    The whole of *data* is run through the model as a single batch.
    ``categorical_crossentropy`` yields one loss value per sample, and
    ``tape.gradient`` on a non-scalar target differentiates its sum, so each
    returned gradient is summed over samples. Entries are ``None`` for
    variables the loss does not depend on.
    """
    with tf.GradientTape() as tape:
        # training=True is kept from the original: layers run in training mode
        # (e.g. dropout active). NOTE(review): confirm this is intended.
        predictions = model(data, training=True)
        loss = tf.keras.losses.categorical_crossentropy(labels, predictions)
    return tape.gradient(loss, model.trainable_variables)


def _save_gradients(gradients, variables, save_dir):
    """Save each non-None gradient to ``save_dir`` as ``gradient_<index>_<uuid>.npy``.

    ``<index>`` is the position in ``variables`` (``model.trainable_variables``),
    so a saved file can be mapped back to the variable it belongs to. The uuid
    suffix keeps repeated runs from overwriting earlier files.
    """
    os.makedirs(save_dir, exist_ok=True)
    for index, (grad, variable) in enumerate(zip(gradients, variables)):
        if grad is None:
            # The loss does not depend on this variable; there is nothing to save.
            print(f"Skipping {variable.name}: no gradient.")
            continue
        path = os.path.join(save_dir, f'gradient_{index:04d}_{uuid.uuid4()}.npy')
        # convert_to_tensor densifies sparse IndexedSlices (e.g. from embeddings).
        np.save(path, tf.convert_to_tensor(grad).numpy())


def main(image_dir, model_path, gradients_save_path, target_size=(64, 64)):
    """Compute and save loss gradients of a Keras model on a directory of images.

    Pipeline: load every image in *image_dir*, load the model, predict pseudo
    labels, compute gradients of categorical cross-entropy between a training-mode
    forward pass and those pseudo labels, and save one ``.npy`` file per
    trainable variable.

    Args:
        image_dir: Directory containing the input images.
        model_path: Path passed to ``tf.keras.models.load_model``.
        gradients_save_path: Output directory; created if missing.
        target_size: ``(height, width)`` every image is resized to. Defaults to
            ``(64, 64)``, the size the loader has always used.

    If *image_dir* contains no files, a message is printed and nothing is
    computed or saved.
    """
    data = _load_images(image_dir, target_size)

    model = tf.keras.models.load_model(model_path)
    model.summary()

    if len(data) == 0:
        print("The data variable is empty!")
        return  # previously fell through and crashed on the unset pseudo_labels

    # Soft class probabilities are used directly as targets.
    # NOTE(review): if hard pseudo labels were intended, one-hot the argmax instead.
    pseudo_labels = model.predict(data)
    gradients = _compute_gradients(model, data, pseudo_labels)
    _save_gradients(gradients, model.trainable_variables, gradients_save_path)
if __name__ == '__main__':
    # Command-line entry point: each option is a string with a default, so the
    # script runs with no arguments at all.
    arg_parser = argparse.ArgumentParser(
        description='Load images, use model to predict and compute gradients.',
    )
    option_specs = (
        ('--image_dir', '/content/brain_tumor_dataset', 'Directory where images are located.'),
        ('--model_path', '/content/brain_tumor_classifier.h5', 'Path to the model file.'),
        ('--gradients_save_path', 'saved_gradients', 'Directory where gradients will be saved.'),
    )
    for flag, default_value, help_text in option_specs:
        arg_parser.add_argument(flag, type=str, default=default_value, help=help_text)
    cli_args = arg_parser.parse_args()
    main(
        image_dir=cli_args.image_dir,
        model_path=cli_args.model_path,
        gradients_save_path=cli_args.gradients_save_path,
    )
|