admin commited on
Commit
5cbce3b
·
1 Parent(s): 5ecf3c1
Files changed (2) hide show
  1. README.md +1 -1
  2. model.py +4 -4
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: ☰
4
  colorFrom: red
5
  colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 6.3.0
8
  app_file: app.py
9
  pinned: true
10
  license: mit
 
4
  colorFrom: red
5
  colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 6.6.0
8
  app_file: app.py
9
  pinned: true
10
  license: mit
model.py CHANGED
@@ -21,10 +21,10 @@ class EvalNet:
21
  self.model = eval("models.%s()" % m_ver)
22
  linear_output = self._set_outsize()
23
  self._set_classifier(cls_num, linear_output)
24
- checkpoint = torch.load(saved_model_path, map_location="cpu")
25
- if torch.cuda.is_available():
26
- checkpoint = torch.load(saved_model_path)
27
-
28
  self.model.load_state_dict(checkpoint, False)
29
  self.model.eval()
30
 
 
21
  self.model = eval("models.%s()" % m_ver)
22
  linear_output = self._set_outsize()
23
  self._set_classifier(cls_num, linear_output)
24
+ checkpoint = torch.load(
25
+ saved_model_path,
26
+ map_location=torch.device("cuda:0") if torch.cuda.is_available() else "cpu",
27
+ )
28
  self.model.load_state_dict(checkpoint, False)
29
  self.model.eval()
30