Boonyaratt commited on
Commit
d9df2d4
·
1 Parent(s): e068adf

change path

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -22,6 +22,7 @@ import torch.nn as nn
22
  import torchvision.models as models
23
  from torchvision import transforms
24
  from torchvision.models import EfficientNet_B0_Weights
 
25
 
26
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
27
 
@@ -97,8 +98,9 @@ class MultimodalNet(nn.Module):
97
  FORCE_TAB_DIM = 38
98
  FORCE_NUM_CLASSES = None # ตั้งเป็นเลขจริงถ้าอยากบังคับ, หรือปล่อย None ให้ดึงจาก ckpt/ดีฟอลต์
99
 
100
- CKPT_PATH = "/content/best_multimodal.pt" # <-- แก้เป็น path ของคุณ
101
- ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
 
102
 
103
  # อย่าอ่าน tab_in_dim จาก ckpt แล้วเผลอได้ 14 มาอีก
104
  tab_in_dim = FORCE_TAB_DIM
 
22
  import torchvision.models as models
23
  from torchvision import transforms
24
  from torchvision.models import EfficientNet_B0_Weights
25
+ from pathlib import Path
26
 
27
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
28
 
 
98
  FORCE_TAB_DIM = 38
99
  FORCE_NUM_CLASSES = None # ตั้งเป็นเลขจริงถ้าอยากบังคับ, หรือปล่อย None ให้ดึงจาก ckpt/ดีฟอลต์
100
 
101
+ BASE_DIR = Path(__file__).resolve().parent
102
+ CKPT_PATH = BASE_DIR / "best_multimodal.pt" # หรือ BASE_DIR / "models" / "best_multimodal.pt"
103
+ ckpt = torch.load(str(CKPT_PATH), map_location=DEVICE)
104
 
105
  # อย่าอ่าน tab_in_dim จาก ckpt แล้วเผลอได้ 14 มาอีก
106
  tab_in_dim = FORCE_TAB_DIM