sam2-hiera-large / handler.py
mila2030's picture
Upload 20 files
4a0358b verified
from typing import Dict, Any
from transformers import pipeline
import torch
from PIL import Image
import numpy as np
class EndpointHandler:
    """Inference-endpoint handler wrapping an image-segmentation pipeline.

    The pipeline is built once in the constructor and reused by every call.
    """

    def __init__(self, path: str):
        """Load the model at *path* into an image-segmentation pipeline.

        Args:
            path: Model location forwarded as ``model=`` to ``pipeline``.
        """
        self.model = pipeline("image-segmentation", model=path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run segmentation on the image found under ``data["inputs"]``.

        Args:
            data: Request payload. ``"inputs"`` is required and is handed to
                the pipeline unchanged; no decoding or conversion happens
                here, so it must already be in a form the pipeline accepts.
                ``"point_coords"`` and ``"point_labels"`` may be supplied but
                are optional and currently ignored, because the pipeline is
                invoked with the image only.

        Returns:
            A dictionary ``{"result": <pipeline output>}``.

        Raises:
            KeyError: If ``data`` has no ``"inputs"`` entry.
        """
        try:
            image_data = data["inputs"]
        except KeyError:
            raise KeyError(
                'request payload is missing the required "inputs" key'
            ) from None

        # NOTE(review): point prompts are not forwarded to the model. Requiring
        # them while ignoring them only produced spurious KeyErrors, so they
        # are no longer read. Wire them through if prompted segmentation is
        # actually intended.
        segmentation_result = self.model(image_data)
        return {"result": segmentation_result}