amosfang commited on
Commit
fbba5da
·
verified ·
1 Parent(s): f3d3023

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -24,9 +24,9 @@ def resize_image(image, input_shape=(224, 224, 3)):
24
 
25
  return image_resized
26
 
27
- def load_model():
28
  model_dir = snapshot_download(REPO_ID)
29
- # saved_model_dir = os.path.join(download_dir, "saved_model")
30
  unet_model = load_model(model_dir)
31
  return unet_model
32
 
@@ -41,9 +41,9 @@ def ensemble_predict(X_array):
41
 
42
  X_array = np.expand_dims(X_array, axis=0)
43
 
44
- unet_model = load_model('REPO_ID/train_2024-02-14 11-20-17/base_u_net.0098-acc-0.75-val_acc-0.74-loss-0.79.h5')
45
- vgg16_model = load_model('REPO_ID/vgg16_u_net.0092-acc-0.74-val_acc-0.74-loss-0.82.h5')
46
- resnet50_model = load_model('REPO_ID/resnet50_u_net.0095-acc-0.79-val_acc-0.76-loss-0.72.h5')
47
 
48
  pred_y_unet = unet_model.predict(X_array)
49
  pred_y_vgg16 = vgg16_model.predict(X_array)
 
24
 
25
  return image_resized
26
 
27
+ def load_model_file(filename):
28
  model_dir = snapshot_download(REPO_ID)
29
+ saved_model_dir = os.path.join(download_dir, filename)
30
  unet_model = load_model(model_dir)
31
  return unet_model
32
 
 
41
 
42
  X_array = np.expand_dims(X_array, axis=0)
43
 
44
+ unet_model = load_model_file('base_u_net.0098-acc-0.75-val_acc-0.74-loss-0.79.h5')
45
+ vgg16_model = load_model_file('vgg16_u_net.0092-acc-0.74-val_acc-0.74-loss-0.82.h5')
46
+ resnet50_model = load_model_file('resnet50_u_net.0095-acc-0.79-val_acc-0.76-loss-0.72.h5')
47
 
48
  pred_y_unet = unet_model.predict(X_array)
49
  pred_y_vgg16 = vgg16_model.predict(X_array)