File size: 4,079 Bytes
df8aa8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from typing import Dict, Any
from transformers import Sam3Processor, Sam3Model
import torch
from PIL import Image
import io
import base64
import numpy as np


class EndpointHandler:
    """Request handler that runs SAM3 text-prompted instance segmentation.

    The processor and model are loaded once in ``__init__``; each request is
    served by ``__call__``, which returns plain JSON-serialisable data.
    Invalid requests are reported as ``{"error": "<message>"}`` dictionaries
    rather than raised.
    """

    def __init__(self, path=""):
        """
        Initialize the SAM3 model and processor for text-prompted segmentation.

        Args:
            path: Path to local model files (if deploying with custom weights)
                  or empty string to use the default facebook/sam3 model
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Use local path if provided, otherwise use the default model
        model_id = path if path else "facebook/sam3"

        self.processor = Sam3Processor.from_pretrained(model_id)
        self.model = Sam3Model.from_pretrained(model_id).to(self.device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process an image with a text prompt and return segmentation masks.

        Expected input format:
        {
            "inputs": {
                "image": "<base64_encoded_image>",   # raw base64 or a data: URI
                "prompt": "text description of object to segment",  # e.g., "a red car"
                "threshold": 0.5,        # optional, default 0.5
                "mask_threshold": 0.5    # optional, default 0.5
            }
        }
        The fields may also be given at the top level, without the "inputs"
        wrapper. The request dictionary is not modified.

        Returns:
        {
            "masks": [...],     # List of binary masks as base64 encoded PNGs
            "boxes": [...],     # Bounding boxes in xyxy format
            "scores": [...],    # Confidence scores
            "num_objects": N    # Number of returned masks
        }
        or {"error": "<message>"} when the request is malformed or the image
        cannot be decoded.
        """
        # 1. Extract inputs
        # Read instead of pop so the caller's payload is left untouched.
        inputs = data.get("inputs", data) if isinstance(data, dict) else data
        if not isinstance(inputs, dict):
            return {
                "error": "Invalid payload. Expected an object with 'image' and 'prompt' fields."
            }

        image_b64 = inputs.get("image")
        text_prompt = inputs.get("prompt", None)

        if not image_b64:
            return {"error": "No image provided. Please provide a base64-encoded image."}

        if not text_prompt:
            return {"error": "No text prompt provided. Please provide a 'prompt' field."}

        # Optional parameters; reject non-numeric values here so they do not
        # fail later inside post-processing with an opaque error.
        try:
            threshold = float(inputs.get("threshold", 0.5))
            mask_threshold = float(inputs.get("mask_threshold", 0.5))
        except (TypeError, ValueError):
            return {"error": "'threshold' and 'mask_threshold' must be numbers."}

        # 2. Decode image
        try:
            image = self._decode_image(image_b64)
        except Exception as e:
            return {"error": f"Failed to decode image: {str(e)}"}

        # 3. Process inputs with text prompt
        processor_inputs = self.processor(
            images=image,
            text=text_prompt,
            return_tensors="pt"
        ).to(self.device)

        # 4. Run Inference
        with torch.no_grad():
            outputs = self.model(**processor_inputs)

        # 5. Post-process results (single image, so take the first entry)
        results = self.processor.post_process_instance_segmentation(
            outputs,
            threshold=threshold,
            mask_threshold=mask_threshold,
            target_sizes=processor_inputs.get("original_sizes").tolist()
        )[0]

        # 6. Format response
        response = {
            "masks": [],
            "boxes": [],
            "scores": []
        }

        masks = results["masks"]
        # len() rather than truthiness: masks may be a tensor, whose truth
        # value is ambiguous when it holds more than one element.
        if len(masks) > 0:
            response["masks"] = [self._encode_mask(mask) for mask in masks]

            # Convert boxes to list
            if "boxes" in results:
                response["boxes"] = results["boxes"].cpu().tolist()

            # Convert scores to list
            if "scores" in results:
                response["scores"] = results["scores"].cpu().tolist()

        response["num_objects"] = len(response["masks"])

        return response

    @staticmethod
    def _decode_image(image_b64):
        """Decode base64 image data (optionally a ``data:`` URI) into an RGB PIL image."""
        if isinstance(image_b64, str) and image_b64.startswith("data:"):
            # Strip the "data:<mime>;base64," header. b64decode would otherwise
            # silently drop its non-alphabet characters and corrupt the bytes.
            image_b64 = image_b64.partition(",")[2]
        image_data = base64.b64decode(image_b64)
        return Image.open(io.BytesIO(image_data)).convert("RGB")

    @staticmethod
    def _encode_mask(mask):
        """Encode one mask tensor as a base64 string of an 8-bit grayscale PNG."""
        # Convert boolean mask to uint8 image (0 or 255)
        mask_np = mask.cpu().numpy().astype(np.uint8) * 255
        mask_img = Image.fromarray(mask_np, mode="L")

        # Encode as base64 PNG
        buffer = io.BytesIO()
        mask_img.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")