import base64
import io
import os

import numpy as np
import torch
from PIL import Image
from transformers import Mask2FormerForUniversalSegmentation, Mask2FormerImageProcessor


class EndpointHandler():
    def __init__(self, path=""):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Cropland Data Layer (CDL) class mapping used by the segmentation head.
        self.id2label = {
            0: 'background',
            1: 'water',
            2: 'developed',
            3: 'corn',
            4: 'soybeans',
            5: 'wheat',
            6: 'other agriculture',
            7: 'forest/wetlands',
            8: 'open lands',
            9: 'barren',
        }
        self.label2id = {v: k for k, v in self.id2label.items()}

        # Token for gated/private repo access, read from the endpoint's environment.
        token = os.getenv("HF_API_TOKEN")

        model_name = "gdurkin/cdl_mask2former_v4_mspc"

        self.processor = Mask2FormerImageProcessor.from_pretrained(
            model_name,
            use_auth_token=token,
        )
        self.model = Mask2FormerForUniversalSegmentation.from_pretrained(
            model_name,
            use_auth_token=token,
            id2label=self.id2label,
            label2id=self.label2id,
            num_labels=len(self.id2label),
            ignore_mismatched_sizes=True,
        )
        self.model.to(self.device)
        self.model.eval()

        print("Model configuration:", self.model.config)

    def __call__(self, data):
        try:
            if "inputs" in data:
                image_base64 = data["inputs"]
            else:
                return {"error": "No 'inputs' field in request."}

            # Decode the base64 payload into an RGB image.
            image_bytes = base64.b64decode(image_base64)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

            # Scale pixel values to [0, 1] and convert to a float32 tensor.
            image_np = np.array(image).astype(np.float32) / 255.0
            input_tensor = torch.from_numpy(image_np)

            # Add a batch dimension for a single image.
            if input_tensor.ndim == 3:
                input_tensor = input_tensor.unsqueeze(0)
            elif input_tensor.ndim != 4:
                return {"error": "Input tensor must be 3D or 4D"}

            # NHWC -> NCHW, the layout the model expects.
            input_tensor = input_tensor.permute(0, 3, 1, 2)
            input_tensor = input_tensor.to(self.device)

            with torch.no_grad():
                outputs = self.model(pixel_values=input_tensor)

            # Resize the predicted mask back to the input resolution.
            target_sizes = [(input_tensor.shape[2], input_tensor.shape[3])]
            predicted_segmentation_maps = self.processor.post_process_semantic_segmentation(
                outputs, target_sizes=target_sizes
            )
            predicted_segmentation_map = predicted_segmentation_maps[0]

            seg_map_np = predicted_segmentation_map.cpu().numpy()

            # Encode the class-index map as a PNG and return it base64-encoded.
            seg_map_pil = Image.fromarray(seg_map_np.astype(np.uint8))
            buffered = io.BytesIO()
            seg_map_pil.save(buffered, format="PNG")
            seg_map_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')

            return {'outputs': seg_map_base64}

        except Exception as e:
            return {"error": str(e)}
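# Minimal local smoke test: a sketch of how a client payload could be built and the
# response decoded; it is not part of the Inference Endpoints request contract.
# "test_tile.png" is a placeholder path; any RGB image can stand in for a real tile.
if __name__ == "__main__":
    handler = EndpointHandler()

    # Encode an example image the same way a client would before sending the request.
    with open("test_tile.png", "rb") as f:
        payload = {"inputs": base64.b64encode(f.read()).decode("utf-8")}

    result = handler(payload)

    if "outputs" in result:
        # Decode the returned PNG back into a class-index map for inspection.
        mask = np.array(Image.open(io.BytesIO(base64.b64decode(result["outputs"]))))
        print("Segmentation map shape:", mask.shape)
        print("Classes present:", np.unique(mask).tolist())
    else:
        print("Error:", result["error"])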