Spaces:
Build error
Build error
| from transformers import SamModel, SamConfig, SamProcessor | |
| import torch | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import app | |
| import os | |
| import json | |
| from PIL import Image | |
# Process-wide cache of (model, processor) pairs keyed by the files they were
# built from, so repeated calls to ``pred`` do not re-read the checkpoint and
# configs from disk and re-instantiate the model every time.
_MODEL_CACHE = {}


def _load_model_and_processor(model_file, config_file, cache_dir):
    """Return a cached ``(SamModel, SamProcessor)`` pair, loading it on first use.

    The model architecture comes from the ``facebook/sam-vit-base`` config and
    its weights from the ``state_dict`` stored in ``model_file``. The processor
    is built from the same base checkpoint with the keyword overrides stored in
    the JSON ``config_file``. A failed load raises and is not cached.
    """
    key = (model_file, config_file, cache_dir)
    cached = _MODEL_CACHE.get(key)
    if cached is not None:
        return cached

    # -- load model configuration
    model_config = SamConfig.from_pretrained("facebook/sam-vit-base", cache_dir=cache_dir)
    model = SamModel(config=model_config)
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files
    # (consider weights_only=True on torch versions that support it).
    model.load_state_dict(torch.load(model_file, map_location=torch.device('cpu')))
    # Put the module in evaluation mode once; pred only ever runs inference.
    model.eval()

    with open(config_file, "r", encoding="utf-8") as f:  # modified config json file
        modified_config_dict = json.load(f)
    processor = SamProcessor.from_pretrained("facebook/sam-vit-base",
                                             **modified_config_dict,
                                             cache_dir=cache_dir)

    _MODEL_CACHE[key] = (model, processor)
    return model, processor


def pred(src, *, threshold=0.30, model_file="sam_model.pth",
         config_file="sam-config.json", cache_dir="/code/cache"):
    """Run the fine-tuned SAM model on one image and return its predicted mask.

    The model and processor are loaded on the first call for a given
    (model_file, config_file, cache_dir) combination and reused afterwards.

    Args:
        src: Path or file object accepted by ``PIL.Image.open``.
        threshold: Probability above which a pixel is marked foreground in the
            hard mask (default 0.30, the value previously hard-coded).
        model_file: Path to the ``state_dict`` checkpoint for ``SamModel``.
        config_file: JSON file of keyword overrides passed to
            ``SamProcessor.from_pretrained``.
        cache_dir: Hugging Face cache directory for the base config/processor.

    Returns:
        ``(pred_prob, pred_prediction)``:
            - ``pred_prob`` is the sigmoid probability map as a NumPy array,
              with all singleton dimensions squeezed out.
            - ``pred_prediction`` is that map binarized to ``uint8``: 1 where
              ``pred_prob > threshold``, else 0.

        No resizing back to the input image size is performed here.
    """
    model, processor = _load_model_and_processor(model_file, config_file, cache_dir)

    # -- process image; the context manager releases the file handle once the
    # RGB copy (an independent image) has been converted to an array.
    with Image.open(src) as image:
        new_image = np.array(image.convert("RGB"))
    print()
    print("image shape:", new_image.shape)
    inputs = processor(new_image, return_tensors="pt")

    # forward pass on the pixels only (no point/box prompts are supplied)
    print("predicting...")
    with torch.no_grad():
        outputs = model(pixel_values=inputs["pixel_values"],
                        multimask_output=False)

    # apply sigmoid to turn the raw mask scores into probabilities
    print("apply sigmoid...")
    pred_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    pred_prob = pred_prob.cpu().numpy().squeeze()
    # convert soft mask to hard mask
    pred_prediction = (pred_prob > threshold).astype(np.uint8)
    return pred_prob, pred_prediction