# Author: Trang Dang
# Last commit: 29f785c ("adjust size")
import functools

import numpy as np
import torch
from PIL import Image
from transformers import SamConfig, SamModel, SamProcessor

import app
@functools.lru_cache(maxsize=1)
def _load_model_and_processor():
    """Build the fine-tuned SAM model and its processor once per process.

    The model configuration and the preprocessing processor are taken from the
    "facebook/sam-vit-base" checkpoint (cached under "/code/cache"). The weights
    are then replaced by the state dict stored in the local "sam_model.pth",
    loaded onto the CPU. The model is switched to eval mode before being
    returned.

    Results are memoized with ``lru_cache``, so repeated calls to ``pred`` reuse
    the same objects instead of re-reading config and weights every time.
    Exceptions are not cached; a failed load is retried on the next call.

    Returns:
        A ``(model, processor)`` tuple.
    """
    cache_dir = "/code/cache"
    model_config = SamConfig.from_pretrained("facebook/sam-vit-base", cache_dir=cache_dir)
    processor = SamProcessor.from_pretrained("facebook/sam-vit-base", cache_dir=cache_dir)
    # Create the bare architecture from the config, then load the fine-tuned weights.
    model = SamModel(config=model_config)
    # NOTE(review): torch.load unpickles the file, so only load checkpoints you
    # trust. On torch >= 1.13, passing weights_only=True would restrict it to tensors.
    state_dict = torch.load("sam_model.pth", map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model, processor


def pred(src):
    """Segment an image with the fine-tuned SAM model.

    Args:
        src: Anything ``PIL.Image.open`` accepts (a path or a binary file object).

    Returns:
        A ``(probabilities, mask)`` tuple of NumPy arrays:
        - ``probabilities``: the sigmoid of the model's ``pred_masks`` logits,
          with singleton dimensions squeezed out.
        - ``mask``: those probabilities thresholded at 0.5, as ``uint8`` (0/1).

        No resizing or post-processing is applied here.
        NOTE(review): SAM's ``pred_masks`` are usually low-resolution. Use
        ``processor.post_process_masks`` if callers need masks at the input
        image size; confirm what callers expect.
    """
    model, processor = _load_model_and_processor()

    # Force 3-channel RGB so the processor always gets the same layout.
    # The context manager closes the underlying file once the pixels are copied.
    with Image.open(src) as image:
        new_image = np.array(image.convert("RGB"))
    inputs = processor(new_image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    # Logits -> probabilities.
    single_patch_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    # Remove every remaining size-1 axis. For a single image and a single mask
    # this presumably leaves a 2-D (H, W) map (TODO confirm).
    single_patch_prob = single_patch_prob.cpu().numpy().squeeze()
    # Soft mask -> hard binary mask.
    single_patch_prediction = (single_patch_prob > 0.5).astype(np.uint8)
    return single_patch_prob, single_patch_prediction