# realtime-fa-test / handler.py
# sam-brause
# try claude image solution test
# 878a499
import torch
import torchvision.transforms as transforms
from PIL import Image
import io
import base64
import logging
import json
# Set up logging.
# basicConfig runs at import time and configures the root logger at INFO so
# the handler's progress messages are emitted; it does nothing if the root
# logger already has handlers configured.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module; EndpointHandler logs through it.
logger = logging.getLogger(__name__)
class EndpointHandler:
    """Load a TorchScript model and report which supported issues it detects.

    An instance is callable: it extracts one image from the request payload,
    runs the model on it and returns the labels from ``supported_issues``
    whose score is above ``PREDICTION_THRESHOLD``.
    """

    # File name of the TorchScript model expected inside ``model_dir``.
    MODEL_FILENAME = "model_scripted_efficientnet.pt"

    # A label is reported only when its score is strictly above this value.
    PREDICTION_THRESHOLD = 0.5

    def __init__(self, model_dir):
        """Load the scripted model from ``model_dir`` and build the preprocessing.

        Args:
            model_dir: Directory that contains ``MODEL_FILENAME``.
        """
        # Load model and move to CPU or GPU as available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = torch.jit.load(
            f"{model_dir}/{self.MODEL_FILENAME}", map_location=self.device
        )
        self.model.eval()

        # Resize to 224x224, convert to a tensor and normalize each channel.
        # The mean/std values are the widely used ImageNet statistics —
        # presumably what the model was trained with; confirm before changing.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

        # Order matters: labels are paired positionally with the model's
        # output values in __call__.
        self.supported_issues = [
            "Dark Spots",
            "Dry Lips",
            "Forehead Wrinkles",
            "Jowls",
            "Nasolabial Folds",
            "Prejowl Sulcus",
            "Thin Lips",
            "Under Eye Hollow",
            "Under Eye Wrinkles",
            "Brow Asymmetry",
        ]

    def __call__(self, data):
        """Extract an image from ``data``, run the model and list detected issues.

        Accepted payloads:
            * a JSON string that decodes to one of the forms below;
            * raw image bytes;
            * a dict whose ``"inputs"`` entry is a base64 string (with or
              without a ``...base64,`` prefix), raw image bytes, an
              already-decoded PIL image, or a list whose first element is
              such a base64 string.

        Args:
            data: The request payload in one of the forms above.

        Returns:
            ``{"predictions": [...]}`` with every label from
            ``supported_issues`` whose score exceeds ``PREDICTION_THRESHOLD``.

        Raises:
            ValueError: If the payload cannot be parsed or decoded, or if it
                has a shape that carries no image.
        """
        logger.info("Received data: %s", type(data))

        try:
            image = self._load_image(data)
        except Exception as exc:
            # Boundary handler: every parsing/decoding failure is reported to
            # the caller as ValueError, with the original error chained so
            # the real cause is not lost.
            logger.exception("Error processing input: %s", exc)
            raise ValueError(f"Error processing input: {exc}") from exc

        if image is None:
            logger.error("Could not load image from input data")
            raise ValueError("Could not load image from input data")

        logger.info("Image loaded successfully. Applying transformations.")
        # unsqueeze(0) adds the batch dimension for the single image.
        batch = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logger.info("Running inference.")
            outputs = self.model(batch)

        # reshape(-1) always yields a flat list, even for a single-value
        # output, where squeeze() would collapse to a bare float.
        # NOTE(review): the 0.5 cut-off assumes the scripted model emits one
        # probability-like score per label — confirm it applies a sigmoid.
        scores = outputs.reshape(-1).tolist()
        detected = [
            label
            for label, score in zip(self.supported_issues, scores)
            if score > self.PREDICTION_THRESHOLD
        ]
        logger.info("Predictions: %s", detected)
        return {"predictions": detected}

    def _load_image(self, data):
        """Return an RGB PIL image found in ``data``, or None if there is none."""
        # Handle string input (from Hugging Face interface).
        if isinstance(data, str):
            logger.info("Input is string. Attempting to parse as JSON.")
            data = json.loads(data)

        # Handle direct bytes input.
        if isinstance(data, (bytes, bytearray)):
            logger.info("Processing direct bytes input")
            return self._open_rgb(data)

        if not isinstance(data, dict) or "inputs" not in data:
            return None

        payload = data["inputs"]
        logger.info("Input data type: %s", type(payload))

        # Handle base64 encoded string.
        if isinstance(payload, str):
            logger.info("Attempting to decode base64 string")
            return self._decode_base64_image(payload)

        # Handle raw bytes.
        if isinstance(payload, (bytes, bytearray)):
            logger.info("Processing raw bytes input")
            return self._open_rgb(payload)

        # Handle list input (from Hugging Face interface): only the first
        # element is used, and the caller's list is left untouched.
        if isinstance(payload, list):
            logger.info("Processing list input")
            if payload and isinstance(payload[0], str):
                return self._decode_base64_image(payload[0])
            return None

        # Some serving stacks (e.g. the Hugging Face inference toolkit for
        # image content types) hand over an already-decoded PIL image.
        if isinstance(payload, Image.Image):
            logger.info("Processing PIL image input")
            return payload.convert("RGB")

        return None

    @staticmethod
    def _decode_base64_image(encoded):
        """Decode a base64 string (optionally a data URI) into an RGB image."""
        # Remove potential base64 prefix such as "data:image/png;base64,".
        _, marker, tail = encoded.partition("base64,")
        if marker:
            encoded = tail
        return EndpointHandler._open_rgb(base64.b64decode(encoded))

    @staticmethod
    def _open_rgb(raw):
        """Open encoded image bytes and return a detached RGB copy."""
        # convert() loads the pixel data into a new image, so the source
        # image can be closed as soon as the block exits.
        with Image.open(io.BytesIO(raw)) as source:
            return source.convert("RGB")
# NOTE(review): this self-assignment is a no-op — the `class EndpointHandler`
# statement already binds the name at module level, so importers can find the
# class without it. Kept unchanged.
EndpointHandler = EndpointHandler # Crucial for import