| |
| |
|
|
| import os |
| import base64 |
| import torch |
| import numpy as np |
| from PIL import Image |
| import io |
|
|
class BaseHandler:
    """Base class for model request handlers.

    A request is served by calling the handler instance, which runs
    ``preprocess`` -> ``inference`` -> ``postprocess`` in that order and
    calls ``initialize`` first if the handler has not been initialized yet.

    Subclasses must override ``initialize``, ``preprocess``, ``inference``
    and ``postprocess``; the versions defined here raise
    ``NotImplementedError``. The base64 helpers at the bottom of the class
    are ready to use as-is.
    """

    def __init__(self):
        """Set up the default device and mark the handler as not loaded."""
        # CUDA device when torch reports one is available, otherwise CPU.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        # Placeholder for the model object; this class never assigns it.
        self.model = None
        # Flipped to True by __call__ after initialize() returns.
        self.initialized = False

    def initialize(self):
        """Load the model and any other resources.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def preprocess(self, data):
        """Turn raw request ``data`` into the input for ``inference``.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def inference(self, inputs):
        """Run the model on the output of ``preprocess``.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def postprocess(self, inference_output):
        """Turn the output of ``inference`` into the response.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def __call__(self, data):
        """Handle one request end to end.

        Args:
            data: Raw request payload, passed unchanged to ``preprocess``.

        Returns:
            Whatever ``postprocess`` returns.
        """
        if not self.initialized:
            self.initialize()
            # Record the successful load here so that initialize() is not
            # re-run on every request when a subclass does not set the flag
            # itself. If initialize() raises, the flag stays False and the
            # next call tries again.
            self.initialized = True

        preprocessed_data = self.preprocess(data)
        inference_output = self.inference(preprocessed_data)
        return self.postprocess(inference_output)

    def encode_image(self, image: "Image.Image") -> str:
        """Encode a PIL image as a base64 string of its PNG bytes.

        Args:
            image: Object exposing PIL's ``save(fp, format=...)`` method.

        Returns:
            Base64 text of the PNG-serialized image.
        """
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")

    def decode_image(self, image_str) -> "Image.Image":
        """Decode base64-encoded image bytes into a PIL image.

        Args:
            image_str: Base64 data accepted by ``base64.b64decode``.

        Returns:
            The object returned by ``Image.open`` on the decoded bytes.
        """
        img_data = base64.b64decode(image_str)
        return Image.open(io.BytesIO(img_data))

    def svg_to_base64(self, svg_content: str) -> str:
        """Return the base64 encoding of ``svg_content``'s UTF-8 bytes."""
        return base64.b64encode(svg_content.encode("utf-8")).decode("utf-8")