Spaces:
Sleeping
Sleeping
Update tongue_model.py
Browse files- tongue_model.py +3 -2
tongue_model.py
CHANGED
|
@@ -53,10 +53,11 @@ class TongueArcResNet(nn.Module):
|
|
| 53 |
|
| 54 |
# --- 2. 瀹氱京闋愯檿鐞嗚垏鎺ㄨ珫椤炲垾 ---
|
| 55 |
class TongueModelWrapper:
|
| 56 |
-
def __init__(self, model_path, num_classes=
|
| 57 |
self.device = torch.device("cpu")
|
| 58 |
self.model = TongueArcResNet(num_classes=num_classes)
|
| 59 |
-
|
|
|
|
| 60 |
self.model.eval()
|
| 61 |
self.transform = transforms.Compose([
|
| 62 |
transforms.ToTensor(),
|
|
|
|
| 53 |
|
| 54 |
# --- 2. 瀹氱京闋愯檿鐞嗚垏鎺ㄨ珫椤炲垾 ---
|
| 55 |
class TongueModelWrapper:
|
| 56 |
+
def __init__(self, model_path, num_classes=3):
|
| 57 |
self.device = torch.device("cpu")
|
| 58 |
self.model = TongueArcResNet(num_classes=num_classes)
|
| 59 |
+
torch.serialization.add_safe_globals([np._core.multiarray.scalar])
|
| 60 |
+
self.model.load_state_dict(torch.load(model_path, map_location=self.device, weights_only=False))
|
| 61 |
self.model.eval()
|
| 62 |
self.transform = transforms.Compose([
|
| 63 |
transforms.ToTensor(),
|