File size: 2,552 Bytes
547626f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import torch.nn as nn
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
from PIL import Image
import io

CLASSES = ["clear", "acne", "ros", "black"]
IMG_SIZE = 224


class SkinClassifier:
    def __init__(self, model_path="model/stage1_skin_classifier.pth"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.classes = CLASSES
        self.img_size = IMG_SIZE

        # Инициализируем модель
        self.model = timm.create_model(
            "efficientnet_b0",
            pretrained=False,
            num_classes=len(self.classes)
        )

        # Загружаем веса
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        # Трансформации
        self.transform = A.Compose([
            A.Resize(self.img_size, self.img_size),
            A.Normalize(
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225)
            ),
            ToTensorV2()
        ])

    def preprocess(self, image):
        """Препроцессинг изображения"""
        if isinstance(image, bytes):
            image = Image.open(io.BytesIO(image)).convert("RGB")
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert("RGB")
        else:
            image = image.convert("RGB")

        image = np.array(image)
        transformed = self.transform(image=image)
        return transformed["image"]

    def predict(self, image):
        """Предсказание класса"""
        # Препроцессинг
        tensor = self.preprocess(image)
        tensor = tensor.unsqueeze(0).to(self.device)

        # Предсказание
        with torch.no_grad():
            outputs = self.model(tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
            prediction = torch.argmax(probabilities, dim=1)

            # Получаем вероятности для всех классов
            probs = probabilities[0].cpu().numpy()
            class_probs = {self.classes[i]: float(probs[i]) for i in range(len(self.classes))}

        return {
            "predicted_class": self.classes[prediction.item()],
            "confidence": float(probabilities[0][prediction.item()]),
            "all_probabilities": class_probs
        }