PinHsuan commited on
Commit
ed6b96b
verified
1 Parent(s): 94c7c73

Update tongue_model.py

Browse files
Files changed (1) hide show
  1. tongue_model.py +3 -2
tongue_model.py CHANGED
@@ -53,10 +53,11 @@ class TongueArcResNet(nn.Module):
53
 
54
  # --- 2. 瀹氱京闋愯檿鐞嗚垏鎺ㄨ珫椤炲垾 ---
55
  class TongueModelWrapper:
56
- def __init__(self, model_path, num_classes=2):
57
  self.device = torch.device("cpu")
58
  self.model = TongueArcResNet(num_classes=num_classes)
59
- self.model.load_state_dict(torch.load(model_path, map_location=self.device))
 
60
  self.model.eval()
61
  self.transform = transforms.Compose([
62
  transforms.ToTensor(),
 
53
 
54
  # --- 2. 瀹氱京闋愯檿鐞嗚垏鎺ㄨ珫椤炲垾 ---
55
  class TongueModelWrapper:
56
+ def __init__(self, model_path, num_classes=3):
57
  self.device = torch.device("cpu")
58
  self.model = TongueArcResNet(num_classes=num_classes)
59
+ torch.serialization.add_safe_globals([np._core.multiarray.scalar])
60
+ self.model.load_state_dict(torch.load(model_path, map_location=self.device, weights_only=False))
61
  self.model.eval()
62
  self.transform = transforms.Compose([
63
  transforms.ToTensor(),