|
|
|
|
|
|
|
|
import io |
|
|
import base64 |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import torch |
|
|
from transformers import SamModel, SamProcessor |
|
|
from typing import Dict, List, Any |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
# Run inference on GPU when available; model weights (and, ideally, all
# input tensors) are moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
|
class EndpointHandler():

    def __init__(self, path=""):
        """
        Called once at startup.

        Load the SAM model and processor with Hugging Face Transformers,
        preferring the weights at *path* and falling back to the public
        ``facebook/sam-vit-base`` checkpoint if the local load fails.

        Args:
            path: Directory containing the model weights (may be empty).
        """
        try:
            self.model = SamModel.from_pretrained(path).to(device)
            self.processor = SamProcessor.from_pretrained(path)
        except Exception as e:
            # Best-effort fallback so the endpoint still boots when no
            # local copy of the weights is present.
            print(f"Failed to load from local path: {e}")
            print("Attempting to load from facebook/sam-vit-base")
            self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
            self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

    @staticmethod
    def _decode_image(inputs):
        """Return an RGB PIL image from raw bytes, a base64 string, or a data URL."""
        if isinstance(inputs, str):
            # Strip an optional "data:image/...;base64," prefix, then decode.
            if inputs.startswith("data:"):
                inputs = inputs.split(",", 1)[-1]
            inputs = base64.b64decode(inputs)
        return Image.open(io.BytesIO(inputs)).convert("RGB")

    def __call__(self, data: Any) -> Any:
        """
        Called on every HTTP request.

        Args:
            data: dict with an "inputs" key holding the image (raw bytes,
                a base64 string, or a data URL) and an optional
                "parameters" dict (accepted but currently unused).

        Returns:
            A PIL RGB image of the highest-scoring segmentation mask
            (white = segmented object, black = background).
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})  # accepted but unused

        # BUG FIX: the original passed the payload straight to BytesIO,
        # which crashes for base64 / data-URL string payloads.
        img = self._decode_image(inputs)

        height, width = img.size[1], img.size[0]

        # Prompt SAM with a single positive point at the image centre.
        input_points = [[[width // 2, height // 2]]]
        input_labels = [[1]]

        model_inputs = self.processor(
            img,
            input_points=input_points,
            input_labels=input_labels,
            return_tensors="pt",
        )
        # BUG FIX: input tensors must live on the same device as the model;
        # the original left them on CPU, which fails on CUDA hosts.
        model_inputs = {
            k: v.to(device) if isinstance(v, torch.Tensor) else v
            for k, v in model_inputs.items()
        }

        with torch.no_grad():
            outputs = self.model(**model_inputs)

        try:
            original_height, original_width = model_inputs["original_sizes"][0].tolist()

            pred_masks = outputs.pred_masks.cpu()
            iou_scores = outputs.iou_scores.cpu()[0]

            # Drop the extra per-prompt axis if present:
            # (batch, 1, num_masks, H, W) -> (batch, num_masks, H, W).
            if pred_masks.ndim == 5:
                pred_masks = pred_masks.squeeze(1)

            # Keep only the mask the model scored highest.
            best_mask_idx = torch.argmax(iou_scores)
            best_mask_tensor = pred_masks[0, best_mask_idx, :, :]

            # Upscale the low-resolution mask logits to the original
            # image resolution.
            upscaled_mask = F.interpolate(
                best_mask_tensor.unsqueeze(0).unsqueeze(0).float(),
                size=(original_height, original_width),
                mode='bilinear',
                align_corners=False,
            ).squeeze()

            # Logit > 0 corresponds to probability > 0.5.
            mask_binary = (upscaled_mask > 0.0).numpy().astype(np.uint8) * 255

        except Exception as e:
            # Fallback: return a centred square so the endpoint still
            # produces a well-formed mask instead of a 500.
            print(f"Error processing masks: {e}")
            height, width = img.size[1], img.size[0]
            mask_binary = np.zeros((height, width), dtype=np.uint8)
            center_x, center_y = width // 2, height // 2
            size = min(width, height) // 8
            mask_binary[center_y-size:center_y+size, center_x-size:center_x+size] = 255

        # The original encoded the mask to base64 PNG and immediately decoded
        # it back — a no-op round-trip. Build the RGB image directly.
        mask_img = Image.fromarray(mask_binary).convert("RGB")
        return mask_img
|
|
|
|
|
def main():
    """Smoke-test the handler on a local image and write the mask to disk."""
    input_path = "/Users/rp7/Downloads/test.jpeg"
    output_path = "output.jpg"

    # The handler opens the payload with PIL, so raw image bytes are the
    # safest input format.
    with open(input_path, "rb") as f:
        img_bytes = f.read()

    handler = EndpointHandler(path=".")
    # BUG FIX: __call__ returns a PIL image, not a list of dicts — the
    # original indexed the result with [0] and read a "mask_png_base64"
    # key, both of which raise at runtime. It also sent a data-URL string
    # where the handler expects bytes.
    mask_img = handler({"inputs": img_bytes}).convert("RGB")
    mask_img.save(output_path, format="JPEG")
    print(f"Wrote mask to {output_path}")
|
|
|
|
|
# Entry point: run the local smoke test when executed as a script.
if __name__ == "__main__":
    main()
|
|
|
|
|
|