kohido commited on
Commit
bf6ec79
·
1 Parent(s): c97b4da
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -12,13 +12,15 @@ import wandb
12
 
13
  cfg = load_cfg("configs/effb0-base-breakhis.yaml")
14
 
 
 
15
  model = system.SimpleClassificationSystem.load_from_checkpoint(
16
  "runs/03_20_2025_20_19_28/wandb/histopath/03_20_2025_20_19_28/checkpoints/epoch=210-step=41778.ckpt",
17
  torch.device("cpu"),
18
  cfg=cfg.system
19
  )
20
 
21
- model.to("cuda")
22
  model.eval()
23
 
24
  print("Model loaded successfully!")
@@ -34,7 +36,7 @@ def preprocess_image(image: Image):
34
  v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
35
  ])
36
 
37
- return transform(image).unsqueeze(0).to("cuda")
38
 
39
 
40
  def predict(image_path: str):
 
12
 
13
  cfg = load_cfg("configs/effb0-base-breakhis.yaml")
14
 
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+
17
  model = system.SimpleClassificationSystem.load_from_checkpoint(
18
  "runs/03_20_2025_20_19_28/wandb/histopath/03_20_2025_20_19_28/checkpoints/epoch=210-step=41778.ckpt",
19
  torch.device("cpu"),
20
  cfg=cfg.system
21
  )
22
 
23
+ model.to(device)
24
  model.eval()
25
 
26
  print("Model loaded successfully!")
 
36
  v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
37
  ])
38
 
39
+ return transform(image).unsqueeze(0).to(device)
40
 
41
 
42
  def predict(image_path: str):