# Author: jiehou
# Commit: 5e916a5 ("Update app.py")
#Part 2: adapted from our student's submission: https://huggingface.co/spaces/Halima/Homework02_part2
import matplotlib as mpl
import matplotlib.pyplot as plt
import gradio as gr
import os
import cv2
import numpy as np
from scipy.ndimage.interpolation import shift
from scipy import ndimage
import tensorflow as tf
# Step 1: load the MNIST digits; x_test supplies the example images below.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Step 2: Gradio input components -- a 28x28 grayscale ('L' mode) image and
# the name of the augmentation to apply to it.
# NOTE(review): gr.inputs / gr.outputs are Gradio's legacy component
# namespaces (removed in Gradio 4); this file targets an older Gradio release.
input_module1 = gr.inputs.Image(
    label="test_image",
    image_mode='L',
    shape=(28, 28),
)
input_module2 = gr.inputs.Dropdown(
    choices=[
        "shift up",
        "shift down",
        "shift left",
        "shift right",
        "rotate left",
        "rotate right",
    ],
    label="Augmentation",
)

# Step 3: Gradio output component that displays the augmented image.
output_module1 = gr.outputs.Image(label="Augmented Image")
#step 4: Define functions for Gradio to use
# File the rendered result is written to; the function returns this path.
# NOTE(review): a fixed filename is shared by all requests, so concurrent users
# can overwrite each other's result; consider tempfile if that matters.
_OUTPUT_PATH = 'testing.png'

# Pixel offset for the shift augmentations and angle (degrees) for the
# rotations -- the same values the original if/elif chain hard-coded.
_SHIFT_PIXELS = 5
_ROTATE_DEGREES = 20

# Augmentation label -> function producing the transformed image.
# Shift offsets are [rows, cols]; vacated pixels are filled with 0.
# 'rotate left' uses +_ROTATE_DEGREES and 'rotate right' uses -_ROTATE_DEGREES,
# matching the original mapping; reshape=False keeps the input's shape.
_AUGMENTATIONS = {
    'shift up': lambda img: ndimage.shift(img, [-_SHIFT_PIXELS, 0], cval=0),
    'shift down': lambda img: ndimage.shift(img, [_SHIFT_PIXELS, 0], cval=0),
    'shift left': lambda img: ndimage.shift(img, [0, -_SHIFT_PIXELS], cval=0),
    'shift right': lambda img: ndimage.shift(img, [0, _SHIFT_PIXELS], cval=0),
    'rotate left': lambda img: ndimage.rotate(img, _ROTATE_DEGREES, reshape=False),
    'rotate right': lambda img: ndimage.rotate(img, -_ROTATE_DEGREES, reshape=False),
}


def _render_to_png(image, path):
    """Save *image* as a grayscale, axis-free matplotlib figure at *path*."""
    # Use a fresh figure per call and close it afterwards. Drawing with
    # plt.imshow onto pyplot's global current figure keeps stacking images on
    # the same axes across calls and never frees them.
    fig, ax = plt.subplots()
    try:
        ax.imshow(image, interpolation="nearest", cmap='gray')
        ax.axis('off')
        fig.savefig(path)
    finally:
        plt.close(fig)


def image_movement(input1, input2):
    """Apply the selected augmentation to an image and save a rendering of it.

    Args:
        input1: 2-D grayscale image array (e.g. from the Gradio Image input),
            or None when no image was provided.
        input2: augmentation label, one of 'shift up', 'shift down',
            'shift left', 'shift right', 'rotate left', 'rotate right'.

    Returns:
        The path of the saved PNG ('testing.png'), or None when ``input1``
        is None.

    Raises:
        ValueError: if ``input2`` is not one of the supported labels.
    """
    if input1 is None:
        # Nothing uploaded: produce no output instead of crashing in scipy.
        return None
    try:
        augment = _AUGMENTATIONS[input2]
    except KeyError:
        # Previously any unknown label silently fell through to 'rotate right'.
        raise ValueError(
            f"Unknown augmentation {input2!r}; expected one of {list(_AUGMENTATIONS)}"
        ) from None
    _render_to_png(augment(input1), _OUTPUT_PATH)
    return _OUTPUT_PATH
# Step 5: generate several example cases
# Labels a sample can be paired with (the same labels the Dropdown offers).
_SAMPLE_AUGMENTATIONS = [
    'shift up', 'shift down', 'shift left', 'shift right',
    'rotate left', 'rotate right',
]


def get_sample_images(num_images, outdir="images_folder"):
    """Write the first *num_images* MNIST test digits to PNGs for Gradio examples.

    Args:
        num_images: number of leading ``x_test`` images to export.
        outdir: directory the PNG files are written to; created if missing.

    Returns:
        A list of ``[image_path, augmentation]`` pairs, where each
        augmentation label is drawn at random from ``_SAMPLE_AUGMENTATIONS``.

    Raises:
        OSError: if OpenCV fails to write one of the images.
    """
    # Create the directory once, tolerating it already existing, rather than
    # an exists()/mkdir() check on every iteration.
    os.makedirs(outdir, exist_ok=True)
    sample_images = []
    for i in range(num_images):
        # Shape each digit as 28x28 and store it as 8-bit data for the PNG.
        data = x_test[i].reshape(28, 28).astype(np.uint8)
        img_path = os.path.join(outdir, f'local_{i:05d}.png')
        # cv2.imwrite signals failure by returning False; don't ignore it.
        if not cv2.imwrite(img_path, data):
            raise OSError(f"Failed to write sample image to {img_path!r}")
        sample_images.append([img_path, np.random.choice(_SAMPLE_AUGMENTATIONS)])
    return sample_images
sample_images = get_sample_images(10)

# Step 6: combine the input, dropdown and output components into a Gradio
# Interface and launch it.
demo = gr.Interface(
    fn=image_movement,
    inputs=[input_module1, input_module2],
    outputs=[output_module1],
    examples=sample_images,
    examples_per_page=2,
    title="CSCI4750/5750 Homework02-Part II: Image Rotation",
    description="Click examples below for a quick demo",
    theme='huggingface',
    layout='vertical',
)
demo.launch(debug=True)