import base64
import io
import json
import logging

import torch
import torchvision.transforms as transforms
from PIL import Image

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Separator between a data-URI header (e.g. "data:image/png;base64,") and the payload.
_BASE64_MARKER = "base64,"


def _open_rgb(raw):
    """Decode encoded image bytes (any format PIL can read) into an RGB image.

    The source image is opened in a ``with`` block so PIL releases it once the
    RGB copy has been made (``Image.convert`` returns a converted copy).
    """
    with Image.open(io.BytesIO(raw)) as img:
        return img.convert("RGB")


def _decode_base64_image(encoded):
    """Decode a base64-encoded image string into an RGB image.

    Everything up to and including the first ``"base64,"`` is discarded, so both
    bare base64 and ``data:image/...;base64,...`` strings are accepted.
    """
    logger.info("Attempting to decode base64 string")
    _, marker, payload = encoded.partition(_BASE64_MARKER)
    return _open_rgb(base64.b64decode(payload if marker else encoded))


def _extract_image(data):
    """Return an RGB image from a request payload, or None if no image is present.

    Accepted payload shapes (a ``str`` is first parsed as JSON):
      * raw image bytes;
      * ``{"inputs": <base64 str | image bytes>}``;
      * ``{"inputs": [<base64 str | image bytes>, ...]}`` - only the first
        element is used.

    Raises:
        Whatever ``json``, ``base64`` or PIL raise on malformed data; the caller
        converts these into ``ValueError``.
    """
    if isinstance(data, str):
        logger.info("Input is string. Attempting to parse as JSON.")
        data = json.loads(data)

    if isinstance(data, (bytes, bytearray)):
        logger.info("Processing direct bytes input")
        return _open_rgb(data)

    if not isinstance(data, dict) or "inputs" not in data:
        return None

    input_data = data["inputs"]
    logger.info("Input data type: %s", type(input_data))

    if isinstance(input_data, list):
        logger.info("Processing list input")
        if not input_data:
            return None
        # Read the first element rather than rewriting it in place, so the
        # caller's payload is left untouched.
        input_data = input_data[0]

    if isinstance(input_data, str):
        return _decode_base64_image(input_data)
    if isinstance(input_data, (bytes, bytearray)):
        logger.info("Processing raw bytes input")
        return _open_rgb(input_data)
    return None


class EndpointHandler:
    """Inference handler wrapping a TorchScript multi-label image classifier.

    Each request carries one image. Every label in ``supported_issues`` is scored
    independently, and the labels whose score exceeds ``threshold`` are returned.
    """

    def __init__(self, model_dir, threshold=0.5):
        """Load the scripted model and build the preprocessing pipeline.

        Args:
            model_dir: Directory containing ``model_scripted_efficientnet.pt``.
            threshold: A label is reported when its score is strictly greater
                than this value (default 0.5, the previously hard-coded value).
        """
        # Load model and move to CPU or GPU as available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = torch.jit.load(
            f"{model_dir}/model_scripted_efficientnet.pt", map_location=self.device
        )
        self.model.eval()
        self.threshold = threshold

        # Resize to 224x224 (presumably the training resolution), convert to a
        # float tensor in [0, 1], then normalise with the standard ImageNet
        # mean/std - NOTE(review): confirm these match the training pipeline.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

        # Label names, paired positionally with the model's output scores.
        # NOTE(review): order must match the model's output units - verify
        # against the training code.
        self.supported_issues = [
            "Dark Spots",
            "Dry Lips",
            "Forehead Wrinkles",
            "Jowls",
            "Nasolabial Folds",
            "Prejowl Sulcus",
            "Thin Lips",
            "Under Eye Hollow",
            "Under Eye Wrinkles",
            "Brow Asymmetry",
        ]

    def __call__(self, data):
        """Classify one image and return the detected issues.

        Args:
            data: A JSON string, a dict with an ``"inputs"`` key (base64 string,
                image bytes, or a list whose first element is either), or raw
                image bytes.

        Returns:
            ``{"predictions": [label, ...]}`` listing every label whose score is
            greater than ``self.threshold``.

        Raises:
            ValueError: if the payload cannot be parsed or decoded, or contains
                no image.
        """
        logger.info("Received data: %s", type(data))
        try:
            image = _extract_image(data)
        except Exception as e:
            # Boundary: any parse/decode failure (bad JSON, bad base64,
            # unreadable image) is surfaced as ValueError with the cause chained.
            logger.exception("Error processing input: %s", e)
            raise ValueError(f"Error processing input: {e}") from e

        if image is None:
            logger.error("Could not load image from input data")
            raise ValueError("Could not load image from input data")

        logger.info("Image loaded successfully. Applying transformations.")
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logger.info("Running inference.")
            outputs = self.model(image_tensor)

        # flatten() rather than squeeze(): a (1, 1) output still yields a list
        # instead of a bare float that zip() cannot iterate.
        scores = outputs.flatten().tolist()
        if len(scores) != len(self.supported_issues):
            logger.warning(
                "Model returned %d scores for %d labels; unmatched entries are ignored",
                len(scores),
                len(self.supported_issues),
            )

        # NOTE(review): the cut-off assumes the scripted model emits
        # probabilities (e.g. sigmoid inside the graph); if it emits logits,
        # 0.5 is the wrong threshold - TODO confirm.
        output = [
            issue
            for issue, score in zip(self.supported_issues, scores)
            if score > self.threshold
        ]
        logger.info("Predictions: %s", output)
        return {"predictions": output}


# NOTE(review): this self-assignment is a no-op (the class is already bound to
# this name); kept only because the original author marked it as required.
EndpointHandler = EndpointHandler  # Crucial for import