File size: 4,714 Bytes
f83cef1
 
 
5d6739d
878a499
3a50531
878a499
3a50531
 
 
 
f83cef1
6491258
 
3a50531
 
 
6491258
 
5d6739d
 
 
6491258
 
5d6739d
 
6491258
 
 
 
 
 
 
 
 
 
 
 
85c4bdf
6491258
3a50531
 
878a499
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85c4bdf
3a50531
878a499
 
3988a1a
3a50531
 
878a499
6491258
3a50531
6491258
3a50531
 
6491258
3a50531
7c8801e
 
6dbf008
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
import torchvision.transforms as transforms
from PIL import Image
import io
import base64
import logging
import json

# Module logging: configure the root logger at INFO level on import, then
# obtain this module's named logger (used by the handler below).
_LOG_LEVEL = logging.INFO
logging.basicConfig(level=_LOG_LEVEL)
logger = logging.getLogger(name=__name__)

class EndpointHandler:
    """Multi-label image classifier exposed as a callable inference handler.

    The TorchScript model is loaded once at construction. Each call decodes a
    single image from the request payload, runs it through the model and
    returns the names of the issues whose score exceeds ``threshold``.
    """

    def __init__(self, model_dir, threshold=0.5):
        """Load the TorchScript model and build the preprocessing pipeline.

        Args:
            model_dir: Directory containing ``model_scripted_efficientnet.pt``.
            threshold: A label is reported when its score is strictly greater
                than this value. Defaults to 0.5, the previously hard-coded cut-off.
        """
        # Prefer the GPU when one is visible; otherwise run on CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = torch.jit.load(
            f"{model_dir}/model_scripted_efficientnet.pt", map_location=self.device
        )
        self.model.eval()
        self.threshold = threshold
        # Resize to 224x224, convert to a float tensor and normalize with the
        # standard ImageNet per-channel mean/std.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])
        # Label names paired positionally with the model's output scores
        # (score i -> label i). The order must match the model's output order;
        # presumably fixed at training time. TODO confirm against training code.
        self.supported_issues = [
            "Dark Spots",
            "Dry Lips",
            "Forehead Wrinkles",
            "Jowls",
            "Nasolabial Folds",
            "Prejowl Sulcus",
            "Thin Lips",
            "Under Eye Hollow",
            "Under Eye Wrinkles",
            "Brow Asymmetry",
        ]

    def __call__(self, data):
        """Classify one image and return the detected issues.

        Args:
            data: The request payload. Accepted forms:
                * a JSON string, which is parsed and then handled as below;
                * a dict whose ``"inputs"`` value is a base64 string (an
                  optional ``...base64,`` data-URI header is stripped), raw
                  bytes, a list whose first element is such a base64 string,
                  or an already-decoded ``PIL.Image.Image``;
                * raw image bytes;
                * a ``PIL.Image.Image``.

        Returns:
            ``{"predictions": [...]}``: the names from ``supported_issues``
            whose model score is strictly greater than ``self.threshold``.

        Raises:
            ValueError: if the payload cannot be parsed or decoded, or if it
                has no recognizable image in it.
        """
        logger.info("Received data: %s", type(data))

        try:
            image = self._load_image(data)
        except Exception as e:
            logger.error("Error processing input: %s", e)
            raise ValueError(f"Error processing input: {str(e)}") from e

        if image is None:
            logger.error("Could not load image from input data")
            raise ValueError("Could not load image from input data")

        logger.info("Image loaded successfully. Applying transformations.")
        # Add a batch dimension of 1 and move to the model's device.
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logger.info("Running inference.")
            outputs = self.model(image_tensor)

        # flatten() instead of squeeze(): with a batch of one this always gives
        # a flat list of per-label scores. squeeze() would turn a single-score
        # output into a 0-d tensor whose tolist() is a bare float, breaking zip.
        scores = outputs.flatten().tolist()
        if len(scores) != len(self.supported_issues):
            # zip() below silently truncates to the shorter side; make that visible.
            logger.warning(
                "Model returned %d scores for %d labels",
                len(scores),
                len(self.supported_issues),
            )

        # NOTE(review): assumes the model emits probabilities (e.g. a final
        # sigmoid). If it emits raw logits, 0.5 is not the intended cut-off. Confirm.
        output = [
            issue
            for issue, score in zip(self.supported_issues, scores)
            if score > self.threshold
        ]
        logger.info("Predictions: %s", output)
        return {"predictions": output}

    def _load_image(self, data):
        """Return an RGB PIL image decoded from ``data``, or None if its form is unrecognized.

        Decode and parse errors propagate to the caller.
        """
        # Handle string input (from Hugging Face interface): parse it as JSON.
        if isinstance(data, str):
            logger.info("Input is string. Attempting to parse as JSON.")
            data = json.loads(data)

        if isinstance(data, dict):
            if "inputs" not in data:
                return None
            input_data = data["inputs"]
            logger.info("Input data type: %s", type(input_data))

            if isinstance(input_data, str):
                logger.info("Attempting to decode base64 string")
                return self._decode_base64_image(input_data)
            if isinstance(input_data, (bytes, bytearray)):
                logger.info("Processing raw bytes input")
                return self._open_image(input_data)
            if isinstance(input_data, list):
                logger.info("Processing list input")
                if input_data and isinstance(input_data[0], str):
                    # Decode a local copy; never mutate the caller's list.
                    return self._decode_base64_image(input_data[0])
                return None
            if isinstance(input_data, Image.Image):
                # The payload was already decoded into a PIL image upstream.
                logger.info("Processing PIL image input")
                return input_data.convert("RGB")
            return None

        if isinstance(data, (bytes, bytearray)):
            logger.info("Processing direct bytes input")
            return self._open_image(data)
        if isinstance(data, Image.Image):
            logger.info("Processing PIL image input")
            return data.convert("RGB")
        return None

    @staticmethod
    def _open_image(raw):
        """Decode encoded image bytes (any format PIL can open) into an RGB image."""
        return Image.open(io.BytesIO(raw)).convert("RGB")

    @staticmethod
    def _decode_base64_image(encoded):
        """Decode a base64 string, optionally prefixed with a ``...base64,`` header, into an RGB image.

        Raises:
            ValueError: if the string is not valid base64 or is not a decodable image.
        """
        # Strip a data-URI header such as "data:image/png;base64," if present.
        if "base64," in encoded:
            encoded = encoded.partition("base64,")[2]
        try:
            image_bytes = base64.b64decode(encoded)
            return Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            logger.error("Failed to decode base64: %s", e)
            raise ValueError(f"Failed to decode base64 image: {e}") from e

EndpointHandler = EndpointHandler  # No-op self-assignment: the class statement above already binds this name for importers.