|
|
import torch |
|
|
import torchvision.transforms as transforms |
|
|
from PIL import Image |
|
|
import io |
|
|
import base64 |
|
|
import logging |
|
|
import json |
|
|
|
|
|
|
|
|
# Configure root logging once at import time so handler activity is visible
# in the endpoint's container logs (Inference Endpoints capture stdout/stderr).
logging.basicConfig(level=logging.INFO)


# Module-level logger used by EndpointHandler below.
logger = logging.getLogger(__name__)
|
|
|
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for a multi-label facial-issue
    classifier backed by a TorchScript EfficientNet.

    Accepts an image as raw bytes, a base64 string (optionally a data URI),
    or a JSON payload with an ``"inputs"`` key, and returns the supported
    issue labels whose predicted score exceeds 0.5.
    """

    def __init__(self, model_dir):
        """Load the scripted model and build the preprocessing pipeline.

        Args:
            model_dir: Directory containing ``model_scripted_efficientnet.pt``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = torch.jit.load(
            f"{model_dir}/model_scripted_efficientnet.pt", map_location=self.device
        )
        self.model.eval()
        # Standard ImageNet normalization constants — assumes the model was
        # trained with these statistics (TODO: confirm against training code).
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])
        # Output labels; order must match the model's output-unit order.
        self.supported_issues = [
            "Dark Spots",
            "Dry Lips",
            "Forehead Wrinkles",
            "Jowls",
            "Nasolabial Folds",
            "Prejowl Sulcus",
            "Thin Lips",
            "Under Eye Hollow",
            "Under Eye Wrinkles",
            "Brow Asymmetry",
        ]

    @staticmethod
    def _image_from_base64(encoded):
        """Decode a (possibly data-URI-prefixed) base64 string to an RGB image.

        Returns the PIL image, or ``None`` when decoding fails — the failure
        is logged and the caller falls through to its generic error path.
        """
        try:
            # Strip a "data:image/...;base64," prefix if present. Use a local
            # copy: never mutate the caller's payload.
            if "base64," in encoded:
                encoded = encoded.split("base64,", 1)[1]
            image_bytes = base64.b64decode(encoded)
            return Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            logger.error("Failed to decode base64: %s", e)
            return None

    def _load_image(self, data):
        """Extract a PIL RGB image from any of the supported payload shapes.

        Supported shapes: a JSON string, raw bytes, or a dict whose "inputs"
        value is a base64 string, raw bytes, or a list whose first element is
        a base64 string. Returns ``None`` when no image can be extracted.
        """
        if isinstance(data, str):
            logger.info("Input is string. Attempting to parse as JSON.")
            data = json.loads(data)

        if isinstance(data, dict):
            if "inputs" not in data:
                return None
            payload = data["inputs"]
            logger.info("Input data type: %s", type(payload))

            if isinstance(payload, str):
                logger.info("Attempting to decode base64 string")
                return self._image_from_base64(payload)
            if isinstance(payload, bytes):
                logger.info("Processing raw bytes input")
                return Image.open(io.BytesIO(payload)).convert("RGB")
            if isinstance(payload, list):
                logger.info("Processing list input")
                if payload and isinstance(payload[0], str):
                    return self._image_from_base64(payload[0])
            return None

        if isinstance(data, bytes):
            logger.info("Processing direct bytes input")
            return Image.open(io.BytesIO(data)).convert("RGB")

        return None

    def __call__(self, data):
        """Run inference on a single image payload.

        Args:
            data: Raw bytes, a JSON string, or a dict with an ``"inputs"``
                key (base64 string, raw bytes, or a one-element list of a
                base64 string).

        Returns:
            ``{"predictions": [<issue name>, ...]}`` — labels scoring > 0.5.

        Raises:
            ValueError: if the payload cannot be parsed or holds no image.
        """
        logger.info("Received data: %s", type(data))
        try:
            image = self._load_image(data)
        except Exception as e:
            logger.error("Error processing input: %s", e)
            raise ValueError(f"Error processing input: {str(e)}")

        if image is None:
            logger.error("Could not load image from input data")
            raise ValueError("Could not load image from input data")

        logger.info("Image loaded successfully. Applying transformations.")
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logger.info("Running inference.")
            outputs = self.model(image_tensor)

        # NOTE(review): scores are compared to 0.5 directly — this assumes the
        # scripted model already applies a sigmoid. Confirm; otherwise apply
        # torch.sigmoid(outputs) here before thresholding.
        predictions = outputs.squeeze().tolist()
        if not isinstance(predictions, list):
            # A single-output model squeezes to a bare scalar; normalize to a
            # list so zip() below works for any output width.
            predictions = [predictions]
        output = [
            issue
            for issue, prob in zip(self.supported_issues, predictions)
            if prob > 0.5
        ]
        logger.info("Predictions: %s", output)
        return {"predictions": output}
|
|
|
|
|
EndpointHandler = EndpointHandler |