jiehou commited on
Commit
06de54b
·
1 Parent(s): b10357c

Create new file

Browse files
Files changed (1) hide show
  1. app.py +98 -0
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Part 2: adapted from our student's submission: https://huggingface.co/spaces/Halima/Homework02_part2
2
+ import matplotlib as mpl
3
+ import matplotlib.pyplot as plt
4
+ import gradio as gr
5
+ from scipy.ndimage.interpolation import shift
6
+ from scipy import ndimage
7
+
8
+ import tensorflow as tf
9
+
10
+ #step 1: load the dataset
11
+ (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
12
+
13
+
14
+ #step 2: Define input components on Gradio
15
+
16
+ input_module1 = gr.inputs.Image(label = "test_image", image_mode='L', shape = (28,28))
17
+
18
+ input_module2 = gr.inputs.Dropdown(choices=["shift up", "shift down", "shift left", "shift right", "rotate left", "rotate right"], label = "Augmentation")
19
+
20
+
21
+ #step 3: Define output components on Gradio
22
+ output_module1 = gr.outputs.Image(label = "Augmented Image")
23
+
24
+
25
+ #step 4: Define functions for Gradio to use
26
+ def image_movement(input1, input2):
27
+
28
+ if input2 == 'shift up':
29
+ digit_image_u = shift(input1, [-5, 0], cval=0)
30
+ plt.imshow(digit_image_u,interpolation="nearest", cmap='gray')
31
+ plt.axis('off')
32
+ plt.savefig('testing.png')
33
+
34
+ elif input2 == 'shift down':
35
+ digit_image_d = shift(input1, [+5, 0], cval=0)
36
+ plt.imshow(digit_image_d,interpolation="nearest", cmap='gray')
37
+ plt.axis('off')
38
+ plt.savefig('testing.png')
39
+
40
+ elif input2 == 'shift left':
41
+ digit_image_l = shift(input1, [0, -5], cval=0)
42
+ plt.imshow(digit_image_l,interpolation="nearest", cmap='gray')
43
+ plt.axis('off')
44
+ plt.savefig('testing.png')
45
+
46
+ elif input2 == 'shift right':
47
+ digit_image_r = shift(input1, [0, +5], cval=0)
48
+ plt.imshow(digit_image_r,interpolation="nearest", cmap='gray')
49
+ plt.axis('off')
50
+ plt.savefig('testing.png')
51
+
52
+ elif input2 == 'rotate left':
53
+ rotate_image_right = ndimage.rotate(input1, 20, reshape=False)
54
+ plt.imshow(rotate_image_right,interpolation="nearest", cmap='gray')
55
+ plt.axis('off')
56
+ plt.savefig('testing.png')
57
+
58
+ else:
59
+ rotate_image_left = ndimage.rotate(input1, -20, reshape=False)
60
+ plt.imshow(rotate_image_left,interpolation="nearest", cmap='gray')
61
+ plt.axis('off')
62
+ plt.savefig('testing.png')
63
+
64
+ return 'testing.png'
65
+
66
+
67
+
68
+
69
+
70
+ # Step 5: generate several example cases
71
+ def get_sample_images(num_images):
72
+ sample_images = []
73
+ for i in range(num_images):
74
+ test_feature = x_test[i]
75
+ test_feature_2d =test_feature.reshape(28,28)
76
+
77
+ # Make it unsigned integers:
78
+ data = test_feature_2d.astype(np.uint8)
79
+
80
+ sample_images.append([data,int(np.random.choice(['shift up', 'shift down', 'shift left', 'shift right', 'rotate left', 'rotate right']))]) # ["image path", "K"]
81
+ return sample_images
82
+
83
+ sample_images = get_sample_images(10)
84
+
85
+
86
+
87
+ # Step 6: Put all three component together into the gradio's interface function
88
+ gr.Interface(fn=image_movement,
89
+ inputs=[input_module1, input_module2],
90
+ outputs=[output_module1],
91
+ examples_per_page = 2,
92
+ examples = sample_images,
93
+ title="CSCI4750/5750 Homework02-Part II: Image Rotation",
94
+ description= "Click examples below for a quick demo",
95
+ theme = 'huggingface',
96
+ layout = 'vertical'
97
+
98
+ ).launch(debug = True)