Digambar29 commited on
Commit
8585b93
·
1 Parent(s): 40cf0fb

Updated inference.py to clear the model not found runtime error

Browse files
Files changed (1) hide show
  1. model/inference.py +4 -9
model/inference.py CHANGED
@@ -3,18 +3,17 @@ import torch
3
  import torch.nn as nn
4
  from torchvision import models, transforms
5
  from PIL import Image
 
6
 
7
  # device
8
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
 
10
- # load checkpoint
11
- from pathlib import Path
12
-
13
  BASE_DIR = Path(__file__).resolve().parent
14
  MODEL_PATH = BASE_DIR / "emotion_recognition_model.pth"
15
 
 
16
  checkpoint = torch.load(MODEL_PATH, map_location=device)
17
-
18
  classes = checkpoint["classes"]
19
 
20
  # recreate model
@@ -25,7 +24,7 @@ model.load_state_dict(checkpoint["model_state"])
25
  model.to(device)
26
  model.eval()
27
 
28
- # preprocessing (SAME as training)
29
  transform = transforms.Compose([
30
  transforms.Resize((224, 224)),
31
  transforms.ToTensor(),
@@ -41,7 +40,3 @@ def predict(pil_image: Image.Image):
41
  logits = model(x)
42
  idx = logits.argmax(dim=1).item()
43
  return classes[idx]
44
-
45
- if __name__ == "__main__":
46
- img = Image.open("Image.jpg") # any face image
47
- print(predict(img))
 
3
  import torch.nn as nn
4
  from torchvision import models, transforms
5
  from PIL import Image
6
+ from pathlib import Path
7
 
8
  # device
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
 
11
+ # resolve model path safely
 
 
12
  BASE_DIR = Path(__file__).resolve().parent
13
  MODEL_PATH = BASE_DIR / "emotion_recognition_model.pth"
14
 
15
+ # load checkpoint
16
  checkpoint = torch.load(MODEL_PATH, map_location=device)
 
17
  classes = checkpoint["classes"]
18
 
19
  # recreate model
 
24
  model.to(device)
25
  model.eval()
26
 
27
+ # preprocessing
28
  transform = transforms.Compose([
29
  transforms.Resize((224, 224)),
30
  transforms.ToTensor(),
 
40
  logits = model(x)
41
  idx = logits.argmax(dim=1).item()
42
  return classes[idx]