Pushpesh commited on
Commit
903bf4c
·
1 Parent(s): 00bfb5f

Added model

Browse files
Files changed (3) hide show
  1. app.py +3 -1
  2. model/__init__.py +0 -0
  3. model/model.py +9 -0
app.py CHANGED
@@ -3,9 +3,11 @@ import torch
3
  from PIL import Image
4
  import numpy as np
5
  from app.utils import recover_light_sources
 
6
 
7
  device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
- model=torch.load('model/model_epoch_49.pth',map_location=device)
 
9
 
10
  def evaluate(model,image):
11
  model.eval()
 
3
  from PIL import Image
4
  import numpy as np
5
  from app.utils import recover_light_sources
6
+ from model.model import model
7
 
8
  device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+ chk=torch.load('model/model_epoch_49.pth',map_location=device)
10
+ model.load_state_dict(chk['model_state_dict'])
11
 
12
  def evaluate(model,image):
13
  model.eval()
model/__init__.py ADDED
File without changes
model/model.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import segmentation_models_pytorch as smp
2
+
3
+ model = smp.Unet(
4
+ encoder_name="mit_b1",
5
+ encoder_weights="imagenet",
6
+ in_channels=3,
7
+ classes=3,
8
+ activation=None
9
+ )