|
|
import base64 |
|
|
import io |
|
|
import os |
|
|
import sys |
|
|
from typing import Dict, List, Any |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from mmcv.runner import load_checkpoint |
|
|
from mmcv.utils import Config |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
sys.path.append(os.path.dirname(os.path.realpath(__file__))) |
|
|
|
|
|
|
|
|
from modelsforIML.mmseg.datasets.pipelines import Compose |
|
|
from modelsforIML.mmseg.models import build_segmentor |
|
|
|
|
|
|
|
|
class Pipeline:
    """Hugging Face custom inference pipeline for APSC-Net, an mmseg-style
    segmentor for image manipulation localization (IML).

    Loads a bundled config/checkpoint pair at construction time and exposes
    ``__call__`` for single-image inference, returning the predicted mask as
    a base64-encoded PNG string.
    """

    def __init__(self, model_path: str):
        """
        Initializes the pipeline by loading the model and preprocessing steps.

        Args:
            model_path (str): The path to the model checkpoint file. It's automatically
                passed by the Hugging Face infrastructure.
                NOTE(review): this argument is currently ignored — the config and
                checkpoint paths below are hard-coded. Confirm that is intended.

        Raises:
            FileNotFoundError: If the config or checkpoint file is missing.
        """

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Resolve the bundled files relative to this module, mirroring the
        # sys.path anchoring above, so loading works regardless of the
        # caller's current working directory (the original cwd-relative
        # strings broke when the process was started elsewhere).
        base_dir = os.path.dirname(os.path.realpath(__file__))
        config_path = os.path.join(base_dir, 'models for IML', 'apscnet.py')
        checkpoint_path = os.path.join(base_dir, 'models for IML', 'APSC-Net.pth')

        # Fail early with clear messages instead of letting mmcv raise an
        # opaque error; the original only checked the checkpoint.
        if not os.path.exists(config_path):
            raise FileNotFoundError(
                f"Config file not found at {config_path}. "
                "Please place 'apscnet.py' in the 'models for IML' directory."
            )
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(
                f"Checkpoint file not found at {checkpoint_path}. "
                "Please download it and place it in the 'models for IML' directory."
            )

        cfg = Config.fromfile(config_path)

        # Build the segmentor in test mode, load weights on CPU first (safe
        # even without a GPU), then move to the chosen device for eval-only use.
        self.model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
        load_checkpoint(self.model, checkpoint_path, map_location='cpu')
        self.model.to(self.device)
        self.model.eval()

        # Reuse the test-time transforms declared in the config.
        # Assumes cfg.data.test.pipeline[1] is an aug-wrapper step (e.g.
        # MultiScaleFlipAug) whose 'transforms' list holds the per-image
        # ops — TODO confirm against apscnet.py.
        test_pipeline_cfg = cfg.data.test.pipeline[1]['transforms']
        self.pipeline = Compose(test_pipeline_cfg)

    def __call__(self, inputs: Image.Image) -> Dict[str, Any]:
        """
        Performs inference on a single image.

        Args:
            inputs (Image.Image): A PIL Image to be processed.

        Returns:
            Dict[str, Any]: A dictionary containing the resulting mask as a base64 encoded string.
        """

        # Force 3-channel RGB so the numpy array has a fixed (H, W, 3) shape.
        img = np.array(inputs.convert('RGB'))

        # Minimal metadata dict for the mmseg transform pipeline.
        # NOTE(review): standard mmseg pipelines often expect extra keys
        # (e.g. 'img_fields', 'scale_factor'); presumably the transforms
        # configured here only need these — verify against the config.
        data = {'img': img, 'img_shape': img.shape, 'ori_shape': img.shape}
        data = self.pipeline(data)

        # The pipeline wraps results in a list (one entry per augmentation);
        # take the first, add a batch dimension, and move to the model device.
        img_tensor = data['img'][0].unsqueeze(0).to(self.device)

        # Inference only — no gradients needed.
        with torch.no_grad():
            result = self.model(return_loss=False, img=[img_tensor])

        # result[0] is assumed to be a per-class score map with the class
        # axis first; argmax over it yields the predicted label per pixel.
        # TODO confirm this fork does not already return a label map.
        mask_pred = result[0].argmax(0).astype(np.uint8)

        # Scale labels to the full 8-bit range. Assumes a binary (2-class)
        # task so labels are 0/1 → 0/255; with more classes this would
        # wrap around in uint8 — confirm the model is binary.
        mask_pred *= 255

        # 'L' = single-channel 8-bit grayscale PNG.
        mask_image = Image.fromarray(mask_pred, mode='L')

        # Serialize to PNG in memory and base64-encode for the JSON response.
        buffered = io.BytesIO()
        mask_image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        return {"image": img_str}