# File size: 4,082 Bytes
# 9b3297f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# handler.py

import torch
from transformers import Mask2FormerForUniversalSegmentation, Mask2FormerImageProcessor
from PIL import Image
import base64
import io
import os
import numpy as np

class EndpointHandler:
    """Callable handler that runs Mask2Former semantic segmentation on an image.

    Constructing the handler loads the image processor and the model named by
    ``MODEL_NAME`` and moves the model to CUDA when available (CPU otherwise).

    Calling the handler with ``{"inputs": <base64-encoded image>}`` returns
    ``{"outputs": <base64-encoded PNG>}``, where the PNG holds the predicted
    segmentation map cast to ``uint8``.  Every failure is reported as
    ``{"error": <message>}`` rather than raised.
    """

    MODEL_NAME = "gdurkin/cdl_mask2former_v4_mspc"

    def __init__(self, path=""):
        """Load the processor and model and put the model in eval mode.

        Args:
            path: Unused; the model is always loaded by name from
                ``MODEL_NAME``.
        """
        # Set device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Define label mappings (ensure these match your local environment)
        self.id2label = {
            0: 'background',
            1: 'water',
            2: 'developed',
            3: 'corn',
            4: 'soybeans',
            5: 'wheat',
            6: 'other agriculture',
            7: 'forest/wetlands',
            8: 'open lands',
            9: 'barren',
        }
        self.label2id = {label: idx for idx, label in self.id2label.items()}

        # Get the token from environment variables (None when unset)
        token = os.getenv("HF_API_TOKEN")

        # NOTE(review): recent transformers releases deprecate `use_auth_token`
        # in favour of `token` — confirm against the pinned version before
        # switching the keyword.
        self.processor = Mask2FormerImageProcessor.from_pretrained(
            self.MODEL_NAME,
            use_auth_token=token,
        )
        self.model = Mask2FormerForUniversalSegmentation.from_pretrained(
            self.MODEL_NAME,
            use_auth_token=token,
            id2label=self.id2label,
            label2id=self.label2id,
            num_labels=len(self.id2label),
            ignore_mismatched_sizes=True,
        )
        self.model.to(self.device)
        self.model.eval()

        # Debugging: Print model configuration
        print("Model configuration:", self.model.config)

    def __call__(self, data):
        """Segment the image carried in ``data["inputs"]``.

        Args:
            data: Request payload; expected to be a mapping whose ``"inputs"``
                entry is a base64-encoded image (an optional
                ``data:...;base64,`` prefix is accepted).

        Returns:
            ``{"outputs": <base64 PNG string>}`` on success, otherwise
            ``{"error": <message>}``.
        """
        # EAFP lookup: also covers payloads that are not mappings at all
        # (None, lists, ...), which previously surfaced as a cryptic TypeError.
        try:
            encoded_image = data["inputs"]
        except (KeyError, TypeError, IndexError):
            return {"error": "No 'inputs' field in request."}

        # Request boundary: report any failure to the caller instead of raising.
        try:
            image = self._decode_image(encoded_image)
            pixel_values = self._to_pixel_values(image)

            # Perform inference
            with torch.no_grad():
                outputs = self.model(pixel_values=pixel_values)

            # Resize the prediction back to the input's (H, W)
            target_sizes = [(pixel_values.shape[2], pixel_values.shape[3])]
            segmentation_maps = self.processor.post_process_semantic_segmentation(
                outputs, target_sizes=target_sizes
            )

            # Return the segmentation map as a base64 string
            return {'outputs': self._encode_png(segmentation_maps[0])}
        except Exception as e:
            return {"error": str(e)}

    @staticmethod
    def _decode_image(encoded_image):
        """Decode a base64 payload into an RGB PIL image.

        Raises:
            ValueError: If the payload is an empty string / bytes object.
        """
        if isinstance(encoded_image, str):
            encoded_image = encoded_image.strip()
            # Drop a data-URI header. "," is not in the base64 alphabet, so
            # this cannot cut into a plain base64 payload; leaving the header
            # in would corrupt the decode because its letters are valid base64.
            if encoded_image.startswith("data:") and "," in encoded_image:
                encoded_image = encoded_image.partition(",")[2]

        if isinstance(encoded_image, (str, bytes, bytearray)) and not encoded_image:
            raise ValueError("'inputs' is empty; expected a base64-encoded image.")

        image_bytes = base64.b64decode(encoded_image)
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")

    def _to_pixel_values(self, image):
        """Turn an RGB PIL image into a (1, C, H, W) float32 tensor in [0, 1]."""
        # convert("RGB") guarantees an (H, W, 3) array, so no rank check is needed.
        image_np = np.array(image).astype(np.float32) / 255.0
        pixel_values = torch.from_numpy(image_np).permute(2, 0, 1).unsqueeze(0)
        return pixel_values.to(self.device)

    @staticmethod
    def _encode_png(segmentation_map):
        """Encode a segmentation-map tensor as a base64 PNG string."""
        seg_map_np = segmentation_map.cpu().numpy()
        # Stored as uint8, which is sufficient for the ten label ids in id2label.
        seg_map_pil = Image.fromarray(seg_map_np.astype(np.uint8))

        buffered = io.BytesIO()
        seg_map_pil.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')