Spaces:
Build error
Build error
import functools

import numpy as np
import torch
from PIL import Image
from transformers import SamConfig, SamModel, SamProcessor

# NOTE(review): `app` appears unused in this module and may create a circular
# import if app.py imports this module — confirm before removing.
import app
_CHECKPOINT = "facebook/sam-vit-base"
_CACHE_DIR = "/code/cache"
_WEIGHTS_PATH = "sam_model.pth"


@functools.lru_cache(maxsize=1)
def _load_model_and_processor():
    """Build the fine-tuned SAM model and its processor once, then reuse them.

    The config and preprocessing come from the ``facebook/sam-vit-base``
    checkpoint. The model weights are then replaced by the state dict stored
    in ``sam_model.pth``. The result is cached so that repeated ``pred`` calls
    do not re-read the config and weights from disk every time. If loading
    raises, nothing is cached and the next call retries.
    """
    model_config = SamConfig.from_pretrained(_CHECKPOINT, cache_dir=_CACHE_DIR)
    processor = SamProcessor.from_pretrained(_CHECKPOINT, cache_dir=_CACHE_DIR)
    # Instantiate the architecture from the config only; the fine-tuned
    # weights come from the local checkpoint below.
    model = SamModel(config=model_config)
    # NOTE(review): torch.load unpickles the file. That is acceptable for this
    # bundled checkpoint, but never point it at untrusted files (consider
    # passing weights_only=True on torch versions that support it).
    state_dict = torch.load(_WEIGHTS_PATH, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()
    return model, processor


def pred(src):
    """Segment an image with the fine-tuned SAM model.

    The processor receives only the image (no point, box or mask prompts), and
    the model is asked for a single mask (``multimask_output=False``).

    Args:
        src: Path or file-like object accepted by ``PIL.Image.open``.

    Returns:
        A tuple ``(prob, mask)``:

        - ``prob`` is a NumPy array of per-pixel sigmoid probabilities, with
          singleton dimensions squeezed away.
        - ``mask`` is ``prob > 0.5`` cast to ``uint8`` (values 0/1).

        Both arrays have the resolution of the model's ``pred_masks``. They
        are not post-processed back to the input image size.
    """
    model, processor = _load_model_and_processor()

    # Close the file handle promptly. Converting to RGB turns images in other
    # modes (e.g. grayscale, RGBA) into 3-channel arrays.
    with Image.open(src) as img:
        image = np.array(img.convert("RGB"))

    inputs = processor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    # Mask logits -> probabilities, then drop the singleton dimensions.
    prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    prob = prob.cpu().numpy().squeeze()
    # Threshold the soft mask into a hard 0/1 mask.
    mask = (prob > 0.5).astype(np.uint8)
    return prob, mask