Harry2687 commited on
Commit
8a7e77c
·
1 Parent(s): 2b8ab40

updated layout and added example image

Browse files
Files changed (2) hide show
  1. app.py +43 -12
  2. gender_cnn/predict.py +12 -7
app.py CHANGED
@@ -2,18 +2,40 @@ from shiny import App, reactive, render, ui
2
  from shiny.types import ImgData
3
 
4
  from gender_cnn.predict import predict_gender
 
5
 
6
  app_ui = ui.page_fillable(
7
  ui.panel_title('Gender Classifier'),
8
- ui.card(
9
- ui.card_header('Input'),
10
- ui.input_file('image', 'Upload image', accept=['.png', '.jpg', '.jpeg']),
11
- ui.output_image('show_image')
12
- ),
13
- ui.card(
14
- ui.card_header('Predict'),
15
- ui.input_action_button('predict_gender', 'Make Prediction'),
16
- ui.output_text('prediction')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  )
18
  )
19
 
@@ -24,7 +46,13 @@ def server(input, output, session):
24
  return None
25
 
26
  image_path = input.image()[0]['datapath']
27
- img: ImgData = {'src': image_path, 'height': '300px', 'width': '300px'}
 
 
 
 
 
 
28
  return img
29
 
30
  @render.text
@@ -37,8 +65,11 @@ def server(input, output, session):
37
  output = predict_gender(image_path)
38
  prediction = output['prediction']
39
  weighting = output['weighting']
40
- device = output['device']
41
 
42
- return f'Prediction: {prediction}. Weighting: {str(round(weighting, 2))}. Device: {device}.'
 
 
 
 
43
 
44
  app = App(app_ui, server)
 
2
  from shiny.types import ImgData
3
 
4
  from gender_cnn.predict import predict_gender
5
+ from gender_cnn.predict import get_backend
6
 
7
  app_ui = ui.page_fillable(
8
  ui.panel_title('Gender Classifier'),
9
+ ui.output_text('show_backend'),
10
+ ui.navset_pill_list(
11
+ ui.nav_panel(
12
+ "Input and Prediction",
13
+ ui.layout_columns(
14
+ ui.card(
15
+ ui.card_header('Input'),
16
+ ui.input_file('image', 'Upload image', accept=['.png', '.jpg', '.jpeg'])
17
+ ),
18
+ ui.card(
19
+ ui.card_header('Example Image'),
20
+ ui.output_image('show_example_image', fill=True)
21
+ )
22
+ ),
23
+ ui.card(
24
+ ui.card_header('Image'),
25
+ ui.output_image('show_image', fill=True)
26
+ ),
27
+ ui.layout_columns(
28
+ ui.card(
29
+ ui.card_header('Predict'),
30
+ ui.input_action_button('predict_gender', 'Make Prediction')
31
+ ),
32
+ ui.card(
33
+ ui.card_header('Prediction'),
34
+ ui.output_text('prediction')
35
+ )
36
+ )
37
+ ),
38
+ widths=(3, 9)
39
  )
40
  )
41
 
 
46
  return None
47
 
48
  image_path = input.image()[0]['datapath']
49
+ img: ImgData = {'src': image_path, 'height': '100%'}
50
+ return img
51
+
52
+ @render.image
53
+ def show_example_image():
54
+ image_path = 'images/Male/kratos.png'
55
+ img: ImgData = {'src': image_path, 'height': '100%'}
56
  return img
57
 
58
  @render.text
 
65
  output = predict_gender(image_path)
66
  prediction = output['prediction']
67
  weighting = output['weighting']
 
68
 
69
+ return f'Prediction: {prediction}. Weighting: {str(round(weighting, 2))}.'
70
+
71
+ @render.text
72
+ def show_backend():
73
+ return f'Using device: {get_backend()[1]}.'
74
 
75
  app = App(app_ui, server)
gender_cnn/predict.py CHANGED
@@ -3,13 +3,7 @@ import torchvision.transforms as transforms
3
  from PIL import Image
4
  from .model import resnetModel_128
5
 
6
- def predict_gender(image_path: str):
7
- # Constants
8
- imsize = 128
9
- classes = ('Female', 'Male')
10
- model_name = 'resnetModel_128_epoch_2.pt'
11
-
12
- # Set Backend
13
  if torch.backends.mps.is_available():
14
  device = torch.device('mps')
15
  device_name = 'Apple Silicon GPU'
@@ -20,6 +14,17 @@ def predict_gender(image_path: str):
20
  device = torch.device('cpu')
21
  device_name = 'CPU'
22
 
 
 
 
 
 
 
 
 
 
 
 
23
  # Init model
24
  resnet = resnetModel_128().to(device)
25
  resnet.load_state_dict(torch.load(model_name, map_location=device))
 
3
  from PIL import Image
4
  from .model import resnetModel_128
5
 
6
+ def get_backend():
 
 
 
 
 
 
7
  if torch.backends.mps.is_available():
8
  device = torch.device('mps')
9
  device_name = 'Apple Silicon GPU'
 
14
  device = torch.device('cpu')
15
  device_name = 'CPU'
16
 
17
+ return [device, device_name]
18
+
19
+ def predict_gender(image_path: str):
20
+ # Constants
21
+ imsize = 128
22
+ classes = ('Female', 'Male')
23
+ model_name = 'resnetModel_128_epoch_2.pt'
24
+
25
+ # Set Backend
26
+ device, device_name = get_backend()
27
+
28
  # Init model
29
  resnet = resnetModel_128().to(device)
30
  resnet.load_state_dict(torch.load(model_name, map_location=device))