Spaces:
Sleeping
Sleeping
Commit ·
d9df2d4
1
Parent(s): e068adf
change path
Browse files
app.py
CHANGED
|
@@ -22,6 +22,7 @@ import torch.nn as nn
|
|
| 22 |
import torchvision.models as models
|
| 23 |
from torchvision import transforms
|
| 24 |
from torchvision.models import EfficientNet_B0_Weights
|
|
|
|
| 25 |
|
| 26 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 27 |
|
|
@@ -97,8 +98,9 @@ class MultimodalNet(nn.Module):
|
|
| 97 |
FORCE_TAB_DIM = 38
|
| 98 |
FORCE_NUM_CLASSES = None # ตั้งเป็นเลขจริงถ้าอยากบังคับ, หรือปล่อย None ให้ดึงจาก ckpt/ดีฟอลต์
|
| 99 |
|
| 100 |
-
|
| 101 |
-
|
|
|
|
| 102 |
|
| 103 |
# อย่าอ่าน tab_in_dim จาก ckpt แล้วเผลอได้ 14 มาอีก
|
| 104 |
tab_in_dim = FORCE_TAB_DIM
|
|
|
|
| 22 |
import torchvision.models as models
|
| 23 |
from torchvision import transforms
|
| 24 |
from torchvision.models import EfficientNet_B0_Weights
|
| 25 |
+
from pathlib import Path
|
| 26 |
|
| 27 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 28 |
|
|
|
|
| 98 |
FORCE_TAB_DIM = 38
|
| 99 |
FORCE_NUM_CLASSES = None # ตั้งเป็นเลขจริงถ้าอยากบังคับ, หรือปล่อย None ให้ดึงจาก ckpt/ดีฟอลต์
|
| 100 |
|
| 101 |
+
BASE_DIR = Path(__file__).resolve().parent
|
| 102 |
+
CKPT_PATH = BASE_DIR / "best_multimodal.pt" # หรือ BASE_DIR / "models" / "best_multimodal.pt"
|
| 103 |
+
ckpt = torch.load(str(CKPT_PATH), map_location=DEVICE)
|
| 104 |
|
| 105 |
# อย่าอ่าน tab_in_dim จาก ckpt แล้วเผลอได้ 14 มาอีก
|
| 106 |
tab_in_dim = FORCE_TAB_DIM
|