File size: 3,534 Bytes
06de54b
 
 
 
c003149
33f989c
 
06de54b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abfbd7c
06de54b
 
 
 
 
7210fa6
 
 
 
 
 
 
06de54b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#Part 2: adapted from our student's submission: https://huggingface.co/spaces/Halima/Homework02_part2
import matplotlib as mpl
import matplotlib.pyplot as plt
import gradio as gr
import os
import cv2
import numpy as np
from scipy.ndimage.interpolation import shift
from scipy import ndimage

import tensorflow as tf

# Step 1: load the MNIST dataset via Keras.
# Only x_test is used in this script (get_sample_images exports its first
# digits as example images); x_train, y_train and y_test are unused.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()


# Step 2: define the Gradio input components.
# NOTE(review): gr.inputs / gr.outputs are the legacy Gradio component
# namespaces; confirm they exist in the Gradio version this Space pins.

# Grayscale ('L') image input; shape=(28,28) presumably makes Gradio resize
# uploads to the MNIST digit size -- confirm against the Gradio version in use.
input_module1 = gr.inputs.Image(label = "test_image", image_mode='L', shape = (28,28))

# These choice strings must match exactly the names image_movement dispatches on.
input_module2 = gr.inputs.Dropdown(choices=["shift up", "shift down", "shift left", "shift right", "rotate left", "rotate right"], label = "Augmentation")


# Step 3: define the Gradio output component; it receives the PNG file path
# returned by image_movement.
output_module1 = gr.outputs.Image(label = "Augmented Image")


#step 4: Define functions for Gradio to use
def augment_image(image, augmentation):
    """Return *image* transformed by the named augmentation.

    Args:
        image: 2-D array holding the digit (presumably the 28x28 grayscale
            array Gradio produces for the image input -- confirm).
        augmentation: one of 'shift up', 'shift down', 'shift left',
            'shift right', 'rotate left', 'rotate right'.

    Returns:
        A new array with the same shape as *image*. Shifts move the content
        by 5 pixels and fill the vacated area with 0. 'rotate left' uses
        +20 degrees and 'rotate right' -20 degrees, with reshape=False so
        the output keeps the input's shape.

    Raises:
        ValueError: if *image* is None or *augmentation* is not a known name.
    """
    if image is None:
        raise ValueError("No input image provided.")

    # Dropdown label -> transformation. A table keeps the mapping in one
    # place instead of repeating the same branch body per label.
    operations = {
        'shift up': lambda img: ndimage.shift(img, [-5, 0], cval=0),
        'shift down': lambda img: ndimage.shift(img, [5, 0], cval=0),
        'shift left': lambda img: ndimage.shift(img, [0, -5], cval=0),
        'shift right': lambda img: ndimage.shift(img, [0, 5], cval=0),
        'rotate left': lambda img: ndimage.rotate(img, 20, reshape=False),
        'rotate right': lambda img: ndimage.rotate(img, -20, reshape=False),
    }
    operation = operations.get(augmentation)
    if operation is None:
        # Reject unknown values (e.g. None when no dropdown option is picked)
        # instead of silently falling through to one of the rotations.
        raise ValueError(
            f"Unknown augmentation {augmentation!r}; "
            f"expected one of {sorted(operations)}"
        )
    return operation(image)


def image_movement(input1, input2):
    """Apply augmentation *input2* to image *input1* and save it as a PNG.

    This is the Gradio callback (fn=image_movement).

    Args:
        input1: the digit image array from the image input.
        input2: the augmentation name selected in the dropdown.

    Returns:
        The path 'testing.png' of the rendered grayscale figure.

    Raises:
        ValueError: propagated from augment_image for a missing image or an
            unknown augmentation name.
    """
    augmented = augment_image(input1, input2)

    # Draw on a dedicated figure and always close it, so repeated calls do
    # not keep stacking images onto pyplot's shared current figure.
    fig, ax = plt.subplots()
    try:
        ax.imshow(augmented, interpolation="nearest", cmap='gray')
        ax.axis('off')
        # NOTE(review): this fixed output path is shared by all requests, so
        # concurrent users could overwrite each other's result.
        fig.savefig('testing.png')
    finally:
        plt.close(fig)

    return 'testing.png'
    
    
 


# Step 5: generate several example cases
def get_sample_images(num_images, outdir="images_folder"):
    """Export the first *num_images* MNIST test digits as PNG example files.

    Args:
        num_images: how many leading entries of the module-level ``x_test``
            to export.
        outdir: directory the PNGs are written to; created if missing.

    Returns:
        A list of ``[image_path, augmentation]`` pairs (the shape Gradio's
        ``examples`` expects for the two inputs), where ``augmentation`` is
        drawn at random from the dropdown choices.

    Raises:
        OSError: if OpenCV fails to write one of the image files.
    """
    augmentations = ['shift up', 'shift down', 'shift left', 'shift right',
                     'rotate left', 'rotate right']

    # Create the output directory once, up front; exist_ok avoids the race
    # between an existence check and the mkdir call.
    os.makedirs(outdir, exist_ok=True)

    sample_images = []
    for i in range(num_images):
        # Reshape to 28x28 and cast to unsigned 8-bit for writing as a PNG.
        digit = x_test[i].reshape(28, 28).astype(np.uint8)
        img_path = os.path.join(outdir, 'local_%05d.png' % (i,))
        # cv2.imwrite signals failure through its return value instead of
        # raising; without this check, examples would point at missing files.
        if not cv2.imwrite(img_path, digit):
            raise OSError(f"Failed to write sample image to {img_path!r}")
        sample_images.append([img_path, np.random.choice(augmentations)])
    return sample_images
sample_images = get_sample_images(10)


# Step 6: combine the input/output components and the image_movement
# callback into a single Gradio interface, then launch it in debug mode.
demo = gr.Interface(
    fn=image_movement,
    inputs=[input_module1, input_module2],
    outputs=[output_module1],
    examples=sample_images,
    examples_per_page=2,
    title="CSCI4750/5750 Homework02-Part II: Image Rotation",
    description="Click examples below for a quick demo",
    theme='huggingface',
    layout='vertical',
)
demo.launch(debug=True)