| |
|
|
| import io |
| import base64 |
| import numpy as np |
| from PIL import Image |
| import torch |
| from transformers import SamModel, SamProcessor |
| from typing import Dict, List, Any |
| import torch.nn.functional as F |
|
|
| |
# Device used for the SAM model weights and the per-request tensors:
# CUDA whenever torch reports a usable GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
|
|
class EndpointHandler:
    """Inference-endpoint handler that segments the object at the image centre with SAM.

    Each request is prompted with a single foreground point at the centre of
    the image. The highest-scoring of SAM's candidate masks is returned as a
    binary (0/255) grayscale PIL image the same size as the input image.
    """

    # Hub checkpoint used when loading from the ``path`` given to __init__ fails.
    FALLBACK_CHECKPOINT = "facebook/sam-vit-base"

    def __init__(self, path=""):
        """Load the SAM model and processor once at startup.

        Args:
            path: Directory (or Hub id) holding the SAM weights and processor
                config. If loading from it raises, the handler falls back to
                ``FALLBACK_CHECKPOINT``.
        """
        try:
            self.model = SamModel.from_pretrained(path).to(device)
            self.processor = SamProcessor.from_pretrained(path)
        except Exception as e:
            # Best-effort fallback so the endpoint still starts when the local
            # repository does not ship usable weights.
            print("Failed to load from local path: {}".format(e))
            print("Attempting to load from {}".format(self.FALLBACK_CHECKPOINT))
            self.model = SamModel.from_pretrained(self.FALLBACK_CHECKPOINT).to(device)
            self.processor = SamProcessor.from_pretrained(self.FALLBACK_CHECKPOINT)

    def __call__(self, data: Dict[str, Any]) -> "Image.Image":
        """Segment the object at the centre of the request image.

        Args:
            data: Request payload. ``data["inputs"]`` must be a PIL image or a
                base64-encoded image string (optionally a ``data:`` URI).
                The payload itself is not modified.

        Returns:
            A single-channel PIL image (uint8, values 0 or 255) with the same
            width and height as the input image.

        Raises:
            ValueError: if ``data["inputs"]`` is missing or None.
            TypeError: if ``data["inputs"]`` is neither a PIL image nor a str.
        """
        # ``get`` rather than ``pop`` so the caller's payload is left intact.
        inputs = data.get("inputs")
        if inputs is None:
            raise ValueError("Missing 'inputs' key in the payload.")

        img = self._decode_image(inputs)
        width, height = img.size

        # One positive point prompt at the image centre, given as (x, y) in
        # original-image pixels; the processor rescales it to the model input.
        input_points = [[[width // 2, height // 2]]]
        input_labels = [[1]]
        model_inputs = self.processor(
            img,
            input_points=input_points,
            input_labels=input_labels,
            return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            outputs = self.model(**model_inputs)

        try:
            mask_binary = self._best_mask(outputs, model_inputs)
        except Exception as e:
            # NOTE(review): deliberate best-effort fallback kept from the
            # original code. It returns a fabricated centre square that callers
            # cannot distinguish from a real prediction; consider raising.
            print("Error processing masks: {}".format(e))
            mask_binary = self._fallback_mask(height, width)

        return Image.fromarray(mask_binary)

    @staticmethod
    def _decode_image(inputs):
        """Return ``inputs`` as an RGB PIL image (accepts a PIL image or a base64 / data-URI string)."""
        if isinstance(inputs, Image.Image):
            return inputs.convert("RGB")
        if isinstance(inputs, str):
            if inputs.startswith("data:"):
                # Drop the "data:<mime>;base64," header of a data URI.
                inputs = inputs.split(",", 1)[1]
            image_bytes = base64.b64decode(inputs)
            with Image.open(io.BytesIO(image_bytes)) as opened:
                # convert() returns a fully loaded copy, so closing is safe.
                return opened.convert("RGB")
        raise TypeError("Unsupported input type. Expected a PIL Image or a base64 encoded string.")

    @staticmethod
    def _best_mask(outputs, model_inputs):
        """Pick the highest-IoU mask and return it as a uint8 0/255 array at original resolution."""
        original_size = tuple(int(v) for v in model_inputs["original_sizes"][0].tolist())
        reshaped_size = tuple(int(v) for v in model_inputs["reshaped_input_sizes"][0].tolist())
        # Spatial size of the (resized + padded) tensor actually fed to the model.
        padded_size = tuple(int(v) for v in model_inputs["pixel_values"].shape[-2:])

        pred_masks = outputs.pred_masks.cpu()
        iou_scores = outputs.iou_scores.cpu()[0]
        # Drop the point-batch axis (a single point prompt is used).
        if pred_masks.ndim == 5:
            pred_masks = pred_masks.squeeze(1)

        best_mask_idx = torch.argmax(iou_scores)
        best_logits = pred_masks[0, best_mask_idx, :, :]

        upscaled = EndpointHandler._mask_to_original(
            best_logits, padded_size, reshaped_size, original_size
        )
        # Logits > 0 are foreground.
        return (upscaled > 0.0).numpy().astype(np.uint8) * 255

    @staticmethod
    def _mask_to_original(mask_logits, padded_size, reshaped_size, original_size):
        """Resize a 2-D low-resolution SAM logit map onto the original image grid.

        The SAM processor resizes the image to ``reshaped_size`` and pads it
        (bottom/right) to ``padded_size``. The low-resolution masks therefore
        cover the padded canvas. To map them back:
        1. upsample the logits to ``padded_size``;
        2. crop away the padding;
        3. resize to ``original_size``.

        This mirrors transformers' ``SamImageProcessor.post_process_masks``.
        Interpolating straight to ``original_size`` would stretch the padding
        into the mask for non-square images.

        Args:
            mask_logits: 2-D tensor of mask logits.
            padded_size: (height, width) of the padded model input.
            reshaped_size: (height, width) of the resized image before padding.
            original_size: (height, width) of the original image.

        Returns:
            2-D float tensor of logits with shape ``original_size``.
        """
        logits = mask_logits.float()[None, None]
        logits = F.interpolate(logits, size=padded_size, mode="bilinear", align_corners=False)
        logits = logits[..., : reshaped_size[0], : reshaped_size[1]]
        logits = F.interpolate(logits, size=original_size, mode="bilinear", align_corners=False)
        return logits[0, 0]

    @staticmethod
    def _fallback_mask(height, width):
        """Build a placeholder mask: a 255-filled square centred in a zero (height, width) uint8 array.

        The square's side is 2 * (min(height, width) // 8).
        """
        mask = np.zeros((height, width), dtype=np.uint8)
        center_x, center_y = width // 2, height // 2
        half = min(width, height) // 8
        mask[center_y - half:center_y + half, center_x - half:center_x + half] = 255
        return mask
|
|
def main(input_path="/Users/rp7/Downloads/test.jpeg", output_path="output.png", model_path="."):
    """Run EndpointHandler once on a local image file and save the returned mask.

    The image is sent as a base64 data URI, exactly like an HTTP JSON request.

    Args:
        input_path: Image file to segment. The default is the original
            hard-coded developer path; pass your own.
        output_path: Where the mask image is written.
        model_path: Directory passed to ``EndpointHandler(path=...)``.
    """
    import mimetypes

    with open(input_path, "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode("utf-8")

    # The handler strips the data-URI header, so the MIME type is only a label.
    # Guess it from the file name rather than always claiming JPEG.
    mime_type = mimetypes.guess_type(input_path)[0] or "image/jpeg"
    payload = {"inputs": "data:{};base64,{}".format(mime_type, img_b64)}

    handler = EndpointHandler(path=model_path)
    result_img = handler(payload)

    result_img.save(output_path)
    print("Wrote mask to {}".format(output_path))


if __name__ == "__main__":
    import argparse

    _parser = argparse.ArgumentParser(description="Run the SAM EndpointHandler on a local image.")
    _parser.add_argument("input_path", nargs="?", default="/Users/rp7/Downloads/test.jpeg",
                         help="image file to segment")
    _parser.add_argument("output_path", nargs="?", default="output.png",
                         help="where to write the mask PNG")
    _parser.add_argument("--model-path", default=".",
                         help="directory holding the SAM weights")
    _args = _parser.parse_args()
    main(_args.input_path, _args.output_path, _args.model_path)
|
|
|
|