olatte commited on
Commit
7e5edc4
·
verified ·
1 Parent(s): 41b512d

Update prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +10 -2
prediction.py CHANGED
@@ -1,13 +1,19 @@
 
1
  from pathlib import Path
2
  import numpy as np
3
  import torch
4
  from transformers import SamModel, SamConfig, SamProcessor
5
 
 
 
 
 
6
  # Constants
7
  MODEL_DIR = Path(__file__).parent
8
  my_SAM_model = MODEL_DIR / "www/sidewalkSAM.pth"
9
 
10
  def load_model_and_processor(model_path, config_path):
 
11
  model = SamModel(config=SamConfig.from_pretrained(config_path))
12
  processor = SamProcessor.from_pretrained(config_path)
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -17,13 +23,14 @@ def load_model_and_processor(model_path, config_path):
17
  model.load_state_dict(torch.load(model_path, map_location=device))
18
  model.to(device)
19
  except Exception as e:
20
- print(f"Error loading model: {e}")
21
  exit(1)
22
 
23
- print("Model is loaded.")
24
  return model, processor, device
25
 
26
  def predict_bbox(model, image, processor, device, prediction_threshold=0.5):
 
27
  prompt = [0, 0, image.width-5, image.height-5]
28
  inputs = processor(image, input_boxes=[[prompt]], return_tensors="pt")
29
  inputs = {k: v.to(device) for k, v in inputs.items()}
@@ -37,6 +44,7 @@ def predict_bbox(model, image, processor, device, prediction_threshold=0.5):
37
  medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
38
  mask = (medsam_seg_prob > prediction_threshold).astype(np.uint8)
39
 
 
40
  return mask
41
 
42
  def get_sidewalk_prediction(image, model, processor, device, threshold=0.7):
 
1
+ import logging
2
  from pathlib import Path
3
  import numpy as np
4
  import torch
5
  from transformers import SamModel, SamConfig, SamProcessor
6
 
7
+ # Set up logging
8
+ logging.basicConfig(level=logging.INFO)
9
+ logger = logging.getLogger(__name__)
10
+
11
  # Constants
12
  MODEL_DIR = Path(__file__).parent
13
  my_SAM_model = MODEL_DIR / "www/sidewalkSAM.pth"
14
 
15
  def load_model_and_processor(model_path, config_path):
16
+ logger.info("Loading model and processor...")
17
  model = SamModel(config=SamConfig.from_pretrained(config_path))
18
  processor = SamProcessor.from_pretrained(config_path)
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
23
  model.load_state_dict(torch.load(model_path, map_location=device))
24
  model.to(device)
25
  except Exception as e:
26
+ logger.error(f"Error loading model: {e}")
27
  exit(1)
28
 
29
+ logger.info("Model and processor loaded.")
30
  return model, processor, device
31
 
32
  def predict_bbox(model, image, processor, device, prediction_threshold=0.5):
33
+ logger.info("Starting prediction...")
34
  prompt = [0, 0, image.width-5, image.height-5]
35
  inputs = processor(image, input_boxes=[[prompt]], return_tensors="pt")
36
  inputs = {k: v.to(device) for k, v in inputs.items()}
 
44
  medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
45
  mask = (medsam_seg_prob > prediction_threshold).astype(np.uint8)
46
 
47
+ logger.info("Prediction completed.")
48
  return mask
49
 
50
  def get_sidewalk_prediction(image, model, processor, device, threshold=0.7):