# prediction / model.py
# ivanm151's picture
# init commit
# 547626f
import torch
import torch.nn as nn
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
from PIL import Image
import io
# Output labels of the classifier; SkinClassifier maps output index i to CLASSES[i],
# so the order here must match the order the checkpoint was trained with.
# NOTE(review): "ros" and "black" look like abbreviations (presumably rosacea
# and blackheads) — confirm against the training code.
CLASSES = ["clear", "acne", "ros", "black"]
# Side length (height and width) that SkinClassifier resizes every input image to.
IMG_SIZE = 224
class SkinClassifier:
    """EfficientNet-B0 image classifier over the labels listed in ``CLASSES``.

    The constructor builds the network with ``timm``, loads its weights from a
    checkpoint file and prepares the resize/normalize pipeline. ``predict``
    runs one image through the model and reports per-class probabilities.
    """

    def __init__(self, model_path="model/stage1_skin_classifier.pth"):
        """Build the model, load its weights and set up preprocessing.

        Args:
            model_path: Path to a checkpoint containing a ``state_dict`` that
                fits ``efficientnet_b0`` with ``len(CLASSES)`` output units.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.classes = CLASSES
        self.img_size = IMG_SIZE

        # Create the architecture only; the weights come from the checkpoint.
        self.model = timm.create_model(
            "efficientnet_b0",
            pretrained=False,
            num_classes=len(self.classes),
        )

        # SECURITY NOTE: torch.load is pickle-based; depending on the installed
        # PyTorch version its default may execute arbitrary code embedded in
        # the file. Only load checkpoints from a trusted source, and consider
        # passing weights_only=True where the PyTorch version supports it.
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        # Inference mode: disables dropout and freezes batch-norm statistics.
        self.model.eval()

        # Preprocessing: square resize, per-channel normalization (the
        # standard ImageNet mean/std values), then conversion to a tensor.
        self.transform = A.Compose([
            A.Resize(self.img_size, self.img_size),
            A.Normalize(
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
            ),
            ToTensorV2(),
        ])

    def preprocess(self, image):
        """Convert an input image into the tensor the model consumes.

        Args:
            image: Encoded image data as a bytes-like object (``bytes``,
                ``bytearray`` or ``memoryview``), a ``numpy.ndarray`` accepted
                by ``PIL.Image.fromarray``, or an object exposing
                ``convert("RGB")`` such as a ``PIL.Image.Image``.

        Returns:
            The ``"image"`` entry produced by ``self.transform`` (expected to
            be a channel-first ``torch.Tensor`` coming from ``ToTensorV2``).
        """
        if isinstance(image, (bytes, bytearray, memoryview)):
            # Any bytes-like payload is decoded from an in-memory buffer.
            image = Image.open(io.BytesIO(image))
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Force three channels so grayscale/RGBA/palette inputs match the
        # three-channel normalization statistics.
        pixels = np.array(image.convert("RGB"))
        return self.transform(image=pixels)["image"]

    def predict(self, image):
        """Classify a single image.

        Args:
            image: Any input accepted by ``preprocess``.

        Returns:
            A dict with ``"predicted_class"`` (label of the most probable
            class), ``"confidence"`` (its probability as a float) and
            ``"all_probabilities"`` (label -> probability for every class).
        """
        # Add the batch dimension and move the input to the model's device.
        batch = self.preprocess(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.model(batch)
            probabilities = torch.softmax(logits, dim=1)

        # Copy the single result row to the CPU once and work on it there.
        row = probabilities[0].cpu()
        best = int(torch.argmax(row))
        scores = row.tolist()

        return {
            "predicted_class": self.classes[best],
            "confidence": scores[best],
            "all_probabilities": dict(zip(self.classes, scores)),
        }