File size: 3,617 Bytes
32bb98c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import timm
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image
import io
import os

# --- CONFIGURATION ---
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
IMG_SIZE = 224  # side length, in pixels, that every image is resized to
BATCH_SIZE = 16  # number of samples per DataLoader batch
EPOCHS = 1  # a single pass over the training data, for a quick test
MODEL_PATH = "mpox_vit_model_local.pth"  # file the state_dict is saved to / loaded from
DATASET_PATH = "Mpox2-1"  # dataset root; its "train" and "valid" sub-folders are read

# --- TRANSFORMS ---
# Training pipeline: resize, random augmentation (horizontal flip, rotation,
# brightness/contrast jitter), tensor conversion, then per-channel
# normalization with mean 0.5 and std 0.5.
transform_train = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# Validation pipeline: same resize and normalization, without the random
# augmentation steps.
transform_val = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# --- LOCAL DATASET ---
# Training images are read from "<DATASET_PATH>/train".
train_dataset = datasets.ImageFolder(
    os.path.join(DATASET_PATH, "train"),
    transform=transform_train,
)

# The validation split is optional: it is only built when a "valid" entry
# exists under the dataset root; otherwise it stays None.
val_dataset = (
    datasets.ImageFolder(os.path.join(DATASET_PATH, "valid"), transform=transform_val)
    if os.path.exists(os.path.join(DATASET_PATH, "valid"))
    else None
)

# Only the training loader shuffles its samples.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
if val_dataset:
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
else:
    val_loader = None

# Class names come from the training dataset; their count sizes the model head.
classes = train_dataset.classes
num_classes = len(classes)

# --- MODEL ---
def create_model(pretrained=True):
    """Build the timm 'vit_base_patch16_224' model with a head sized for this dataset.

    Args:
        pretrained: Forwarded unchanged to ``timm.create_model``.

    Returns:
        The timm model whose ``head`` attribute has been replaced by a new
        ``nn.Linear`` layer with ``num_classes`` outputs.
    """
    vit = timm.create_model('vit_base_patch16_224', pretrained=pretrained)
    # Keep the input width of the existing head and swap in a fresh output layer.
    head_in_features = vit.head.in_features
    vit.head = nn.Linear(head_in_features, num_classes)
    return vit

# --- MODEL LOADING ---
def _load_saved_model():
    """Return the model restored from ``MODEL_PATH``, or ``None`` if it cannot be loaded."""
    # A missing file, or one of 1000 bytes or fewer (treated as empty), means
    # there is nothing usable to load.
    if not os.path.exists(MODEL_PATH) or os.path.getsize(MODEL_PATH) <= 1000:
        return None
    try:
        model = create_model(pretrained=False)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
        model.eval()
        model.to(device)
    except Exception as e:
        # Deliberate best effort: any failure while loading (e.g. a corrupted
        # file) is reported and the caller falls back to training.
        print(f"Erreur au chargement du modèle: {e}")
        return None
    print(f"Modèle chargé depuis {MODEL_PATH}")
    return model


def _train_new_model():
    """Train a pretrained model on ``train_loader`` and save its weights to ``MODEL_PATH``."""
    print("Entraînement d'un nouveau modèle...")
    model = create_model(pretrained=True)
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    model.train()
    for _ in range(EPOCHS):
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
    # Only the state_dict is written, which is what _load_saved_model expects.
    torch.save(model.state_dict(), MODEL_PATH)
    print(f"Modèle sauvegardé : {MODEL_PATH}")
    model.eval()
    return model


def load_or_train_model():
    """Return a model in eval mode, loaded from disk when possible, else freshly trained.

    The weights stored at ``MODEL_PATH`` are used when that file exists, is
    larger than 1000 bytes and loads without error. Otherwise a new model is
    trained for ``EPOCHS`` epochs on ``train_loader`` and its weights are
    saved to ``MODEL_PATH``.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    model = _load_saved_model()
    if model is None:
        # Training is the fallback when the model file is absent or unusable.
        model = _train_new_model()
    return model

# Model used by predict_image(). It is loaded lazily on the first prediction
# so that importing this module does not trigger loading or training.
model = None


# --- PREDICTION ---
def predict_image(image_bytes):
    """Classify an encoded image and return the name of the predicted class.

    Args:
        image_bytes: Raw bytes of an image file that PIL can open.

    Returns:
        The entry of ``classes`` at the index of the highest model output.
    """
    global model
    if model is None:
        # The module-level assignment of ``model`` was commented out, so the
        # first prediction failed with NameError; obtain the model on demand.
        model = load_or_train_model()
    with Image.open(io.BytesIO(image_bytes)) as raw:
        image = raw.convert("RGB")
    # Inference uses the same preprocessing as the validation pipeline, so the
    # shared transform is reused instead of being rebuilt on every call.
    x = transform_val(image).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(x)
    predicted = outputs.argmax(dim=1)
    return classes[predicted.item()]