File size: 3,042 Bytes
592adee
2f7cfdc
592adee
d05bd8d
592adee
 
d816a26
 
d05bd8d
592adee
d05bd8d
 
2f7cfdc
 
d816a26
2f7cfdc
d816a26
 
 
 
 
 
 
 
 
 
592adee
d05bd8d
2f7cfdc
 
d05bd8d
2f7cfdc
d05bd8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f7cfdc
d05bd8d
2f7cfdc
 
d816a26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d05bd8d
2f7cfdc
d816a26
2f7cfdc
d05bd8d
592adee
d05bd8d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# handler.py

import io
import base64
import numpy as np
from PIL import Image
import torch
from transformers import SamModel, SamProcessor
from typing import Dict, List, Any

class EndpointHandler():
    """Custom Inference Endpoints handler that runs SAM segmentation.

    Request contract (``__call__``): a dict with a base64-encoded image in
    either ``data["inputs"]`` (str), ``data["inputs"]["image"]``, or
    ``data["image"]``.
    Response: a one-element list ``[{"mask_png_base64": str, "num_masks": int}]``
    where the mask is a binary PNG (0/255) for the first predicted mask.
    """

    def __init__(self, path=""):
        """
        Called once at startup.
        Load the SAM model using Hugging Face Transformers.

        Args:
            path: Local directory containing the model weights. If loading
                fails, falls back to ``facebook/sam-vit-base`` from the Hub.
        """
        try:
            # Load the model and processor from the local path
            self.model = SamModel.from_pretrained(path)
            self.processor = SamProcessor.from_pretrained(path)
        except Exception as e:
            # Fallback to loading from a known SAM model if local loading fails.
            # Deliberately broad: any load failure (missing files, bad config)
            # should trigger the Hub fallback rather than crash startup.
            print(f"Failed to load from local path: {e}")
            print("Attempting to load from facebook/sam-vit-base")
            self.model = SamModel.from_pretrained("facebook/sam-vit-base")
            self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        # Use the GPU when one is available, and disable dropout/batch-norm
        # training behavior for deterministic inference.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Called on every HTTP request.
        Expecting base64 encoded image in the 'inputs' field or 'image' field.

        Raises:
            ValueError: if no recognizable image field is present.
        """
        image_bytes = self._extract_image_bytes(data)

        # Decode and normalize to RGB (SAM expects 3-channel input).
        img = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # Prepare inputs for the model and move them to the model's device.
        inputs = self.processor(img, return_tensors="pt").to(self.device)

        # Generate masks using the model (no gradients needed at inference).
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Rescale the low-resolution predicted masks back to the original
        # image size; [0] selects the (only) image in the batch.
        masks = self.processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu()
        )[0]

        # SAM predicts several candidate masks per prompt, so the squeezed
        # array may still be 3-D; keep indexing into the leading axis until
        # a single 2-D mask remains (Image.fromarray requires 2-D here).
        mask = masks[0].squeeze().cpu().numpy()
        while mask.ndim > 2:
            mask = mask[0]
        mask_binary = (mask > 0.0).astype(np.uint8) * 255

        # Encode the binary mask as a base64 PNG for the JSON response.
        out = io.BytesIO()
        Image.fromarray(mask_binary).save(out, format="PNG")
        out.seek(0)
        mask_base64 = base64.b64encode(out.getvalue()).decode('utf-8')

        # Return in the expected format
        return [{"mask_png_base64": mask_base64, "num_masks": len(masks)}]

    def _extract_image_bytes(self, data: Dict[str, Any]) -> bytes:
        """Pull the base64 image out of the request dict and decode it.

        Accepts ``data["inputs"]`` as a base64 string, ``data["inputs"]["image"]``,
        or ``data["image"]``; raises ValueError otherwise.
        """
        if "inputs" in data:
            if isinstance(data["inputs"], str):
                # Base64 encoded image
                return base64.b64decode(data["inputs"])
            if isinstance(data["inputs"], dict) and "image" in data["inputs"]:
                # Nested structure with image field
                return base64.b64decode(data["inputs"]["image"])
            raise ValueError("Invalid input format. Expected base64 encoded image string.")
        if "image" in data:
            # Direct image field
            return base64.b64decode(data["image"])
        raise ValueError("No image found in request. Expected 'inputs' or 'image' field with base64 encoded image.")