# NOTE: removed web-scrape artifacts that preceded this file (file-size line,
# git-blame hashes, and a line-number gutter) — they were not valid Python.
import torch
import torchvision.transforms as transforms
from PIL import Image
import io
import base64
import logging
import json
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EndpointHandler:
    """Inference handler for a TorchScript EfficientNet multi-label classifier.

    Accepts an image as a base64 string, raw bytes, or a JSON payload with an
    ``"inputs"`` key, and returns the supported skin issues whose predicted
    score exceeds 0.5.
    """

    def __init__(self, model_dir):
        """Load the scripted model from *model_dir* and build preprocessing.

        Args:
            model_dir: Directory containing ``model_scripted_efficientnet.pt``.
        """
        # Use GPU when available, otherwise fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = torch.jit.load(
            f"{model_dir}/model_scripted_efficientnet.pt", map_location=self.device
        )
        self.model.eval()
        # Standard ImageNet normalization stats, 224x224 input.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])
        # Label order must match the model's output channel order.
        self.supported_issues = [
            "Dark Spots",
            "Dry Lips",
            "Forehead Wrinkles",
            "Jowls",
            "Nasolabial Folds",
            "Prejowl Sulcus",
            "Thin Lips",
            "Under Eye Hollow",
            "Under Eye Wrinkles",
            "Brow Asymmetry",
        ]

    @staticmethod
    def _image_from_base64(encoded):
        """Decode a (possibly data-URI prefixed) base64 string to an RGB image.

        Returns None on decode failure; the caller turns that into the
        client-facing error. Does not mutate the caller's data.
        """
        try:
            # Strip a data-URI prefix such as "data:image/png;base64,".
            if "base64," in encoded:
                encoded = encoded.split("base64,")[1]
            image_bytes = base64.b64decode(encoded)
            return Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            logger.error("Failed to decode base64: %s", e)
            return None

    def _extract_image(self, data):
        """Convert any supported request payload into an RGB PIL image.

        Supported shapes: JSON string, dict with "inputs" (base64 str, raw
        bytes, or list of base64 strs), or raw image bytes. Returns None when
        no image can be extracted; may raise (e.g. json.JSONDecodeError,
        PIL errors), which __call__ wraps in ValueError.
        """
        # Handle string input (e.g. from the Hugging Face interface).
        if isinstance(data, str):
            logger.info("Input is string. Attempting to parse as JSON.")
            data = json.loads(data)
        if isinstance(data, dict):
            if "inputs" not in data:
                return None
            input_data = data["inputs"]
            logger.info("Input data type: %s", type(input_data))
            # Base64-encoded string.
            if isinstance(input_data, str):
                logger.info("Attempting to decode base64 string")
                return self._image_from_base64(input_data)
            # Raw bytes.
            if isinstance(input_data, bytes):
                logger.info("Processing raw bytes input")
                return Image.open(io.BytesIO(input_data)).convert("RGB")
            # List input: only the first element is used.
            if isinstance(input_data, list):
                logger.info("Processing list input")
                if input_data and isinstance(input_data[0], str):
                    return self._image_from_base64(input_data[0])
            return None
        # Direct bytes payload.
        if isinstance(data, bytes):
            logger.info("Processing direct bytes input")
            return Image.open(io.BytesIO(data)).convert("RGB")
        return None

    def __call__(self, data):
        """Run inference on *data* and return ``{"predictions": [...]}``.

        Args:
            data: JSON string, dict with an "inputs" key, or raw image bytes.

        Returns:
            Dict with key "predictions": names of issues scoring above 0.5.

        Raises:
            ValueError: If the payload cannot be parsed or yields no image.
        """
        logger.info("Received data: %s", type(data))
        try:
            image = self._extract_image(data)
        except Exception as e:
            logger.error("Error processing input: %s", e)
            raise ValueError(f"Error processing input: {str(e)}")
        if image is None:
            logger.error("Could not load image from input data")
            raise ValueError("Could not load image from input data")
        logger.info("Image loaded successfully. Applying transformations.")
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logger.info("Running inference.")
            outputs = self.model(image_tensor)
        # NOTE(review): scores are thresholded directly at 0.5 — this assumes
        # the scripted model already applies sigmoid to its outputs; confirm,
        # otherwise these are raw logits and the threshold is wrong.
        predictions = outputs.squeeze().tolist()
        output = [
            issue
            for issue, prob in zip(self.supported_issues, predictions)
            if prob > 0.5
        ]
        logger.info("Predictions: %s", output)
        return {"predictions": output}
EndpointHandler = EndpointHandler # Crucial for import |