File size: 1,942 Bytes
bc1fb7d
b46360a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc1fb7d
 
 
 
 
 
 
 
 
b46360a
 
bc1fb7d
b46360a
 
 
 
 
 
 
 
 
 
 
 
 
bc1fb7d
b46360a
 
 
 
bc1fb7d
b46360a
bc1fb7d
b46360a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
"""Image preprocessing for MelanoScope AI."""
import logging
from typing import Union, Optional
import numpy as np
from PIL import Image
from torchvision import transforms
from ..config.settings import ModelConfig

logger = logging.getLogger(__name__)

class ImagePreprocessor:
    """Handles image preprocessing for model inference.

    Accepts a PIL image or a NumPy array, converts it to RGB, and applies the
    resize / tensor-conversion / normalization pipeline configured in
    ``ModelConfig``. The result is returned as a NumPy array with a leading
    batch axis of size 1.
    """

    def __init__(self):
        # Build the transform pipeline once; every preprocess() call reuses it.
        self.transforms = self._create_transform_pipeline()
        logger.info("ImagePreprocessor initialized")

    def _create_transform_pipeline(self) -> transforms.Compose:
        """Create image transformation pipeline.

        Steps, in order: ``Resize(ModelConfig.IMAGE_SIZE)``, ``ToTensor()``,
        then ``Normalize`` with ``ModelConfig.NORMALIZATION_MEAN`` /
        ``ModelConfig.NORMALIZATION_STD``.
        """
        # NOTE(review): if IMAGE_SIZE is a single int, torchvision's Resize
        # matches the *shorter* edge and keeps aspect ratio (output is not
        # square); a (h, w) tuple gives an exact size — confirm which the
        # model expects.
        return transforms.Compose([
            transforms.Resize(ModelConfig.IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=ModelConfig.NORMALIZATION_MEAN,
                std=ModelConfig.NORMALIZATION_STD,
            ),
        ])

    def preprocess(self, image_input: Union[Image.Image, np.ndarray]) -> Optional[np.ndarray]:
        """Preprocess image for model inference.

        Args:
            image_input: A PIL image, or a NumPy array that
                ``PIL.Image.fromarray`` can interpret (an ``(H, W, 1)`` array
                is treated as single-channel).

        Returns:
            The transformed image as a NumPy array with a leading batch axis
            of size 1, or ``None`` if the input is missing or could not be
            converted or transformed. Failures are logged, not raised.
        """
        pil_image = self._convert_to_pil(image_input)
        if pil_image is None:
            # Conversion failure has already been logged by _convert_to_pil.
            return None

        try:
            # unsqueeze(0) adds the batch dimension before handing off as NumPy.
            return self.transforms(pil_image).unsqueeze(0).numpy()
        except Exception as e:
            # Inference boundary: callers expect None on failure, but keep the
            # traceback in the log so the root cause is diagnosable.
            logger.exception("Error preprocessing image: %s", e)
            return None

    def _convert_to_pil(self, image_input: Union[Image.Image, np.ndarray]) -> Optional[Image.Image]:
        """Convert image input to PIL Image in RGB mode.

        Args:
            image_input: A PIL image or a NumPy array.

        Returns:
            An RGB ``PIL.Image.Image``, or ``None`` (after logging) when the
            input is ``None`` or cannot be converted.
        """
        if image_input is None:
            # Explicit check: otherwise fromarray fails with an unhelpful
            # AttributeError about __array_interface__.
            logger.warning("No image provided for preprocessing")
            return None

        try:
            if isinstance(image_input, Image.Image):
                return image_input.convert("RGB")

            array = np.asarray(image_input)
            # PIL has no mode for an (H, W, 1) array and raises "Cannot handle
            # this data type"; drop the singleton channel so it is read as a
            # single-channel image and expanded to RGB by convert() below.
            if array.ndim == 3 and array.shape[-1] == 1:
                array = array[..., 0]
            return Image.fromarray(array).convert("RGB")
        except Exception as e:
            logger.exception("Error converting to PIL: %s", e)
            return None