| | import timm
|
| | import torch
|
| | import torch.nn as nn
|
| | from torchvision import datasets, transforms
|
| | from torch.utils.data import DataLoader
|
| | from PIL import Image
|
| | import io
|
| | import os
|
| |
|
| |
|
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Square side length (pixels) every image is resized to before the model.
IMG_SIZE = 224
# Number of images per DataLoader batch.
BATCH_SIZE = 16
# Number of passes over the training set when a new model is trained.
EPOCHS = 1
# File holding the model's state_dict (read with torch.load, written with torch.save).
MODEL_PATH = "mpox_vit_model_local.pth"
# Root directory of the dataset; expected to contain "train" and optionally "valid".
DATASET_PATH = "Mpox2-1"
|
| |
|
| |
|
# Final steps shared by every pipeline: PIL image -> float tensor, then
# per-channel (x - 0.5) / 0.5 normalisation.
_to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
]

# Training pipeline: resize, then light random flip / rotation / colour
# augmentation before conversion to a normalised tensor.
transform_train = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
    ]
    + _to_normalized_tensor
)

# Validation pipeline: deterministic resize only, no augmentation.
transform_val = transforms.Compose(
    [transforms.Resize((IMG_SIZE, IMG_SIZE))] + _to_normalized_tensor
)
|
| |
|
| |
|
# Dataset layout: <DATASET_PATH>/train/<class>/... and, optionally,
# <DATASET_PATH>/valid/<class>/... (ImageFolder convention).
_train_dir = os.path.join(DATASET_PATH, "train")
_valid_dir = os.path.join(DATASET_PATH, "valid")

train_dataset = datasets.ImageFolder(_train_dir, transform=transform_train)

# The validation split is optional. Require a directory: a plain file named
# "valid" would otherwise make ImageFolder fail at import time.
val_dataset = (
    datasets.ImageFolder(_valid_dir, transform=transform_val)
    if os.path.isdir(_valid_dir)
    else None
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Compare with None explicitly; truthiness on a Dataset would call len() on it.
val_loader = (
    DataLoader(val_dataset, batch_size=BATCH_SIZE)
    if val_dataset is not None
    else None
)

# Class names come from the training sub-directory names; the model's output
# index i corresponds to classes[i].
classes = train_dataset.classes
num_classes = len(classes)
|
| |
|
| |
|
def create_model(pretrained=True):
    """Build a ViT-B/16 (224x224) classifier whose output layer fits the dataset.

    Args:
        pretrained: forwarded to ``timm.create_model``; when True, timm's
            pretrained weights are used for the backbone.

    Returns:
        The timm model with its ``head`` replaced by a new ``nn.Linear`` that
        produces ``num_classes`` logits (one per entry of ``classes``).
    """
    vit = timm.create_model('vit_base_patch16_224', pretrained=pretrained)
    # Keep the backbone's feature width; only the number of outputs changes.
    feature_dim = vit.head.in_features
    vit.head = nn.Linear(feature_dim, num_classes)
    return vit
|
| |
|
| |
|
def _load_checkpoint():
    """Restore a model from MODEL_PATH; return it in eval mode, or None if unusable."""
    if not os.path.exists(MODEL_PATH):
        return None
    size = os.path.getsize(MODEL_PATH)
    # A file this small cannot hold ViT weights; treat it as a truncated or
    # placeholder file instead of trying to unpickle it.
    if size <= 1000:
        print(f"Fichier {MODEL_PATH} trop petit ({size} octets), ignoré")
        return None
    try:
        model = create_model(pretrained=False)
        # NOTE(review): torch.load unpickles the file. Only load checkpoints
        # you trust; consider weights_only=True where the torch version allows.
        state_dict = torch.load(MODEL_PATH, map_location=device)
        model.load_state_dict(state_dict)
    except Exception as e:
        # Deliberate best-effort: any load failure (corrupt file, class-count
        # mismatch, ...) falls back to training a fresh model.
        print(f"Erreur au chargement du modèle: {e}")
        return None
    model.eval()
    model.to(device)
    print(f"Modèle chargé depuis {MODEL_PATH}")
    return model


def _save_state_dict_atomically(state_dict, path):
    """Write state_dict to a temp file, then rename it over path.

    os.replace is atomic, so readers never see a half-written checkpoint.
    """
    tmp_path = f"{path}.tmp"
    torch.save(state_dict, tmp_path)
    os.replace(tmp_path, path)


def _train_new_model():
    """Fine-tune a pretrained ViT on train_loader for EPOCHS epochs, then save it."""
    print("Entraînement d'un nouveau modèle...")
    model = create_model(pretrained=True)
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    model.train()
    for epoch in range(EPOCHS):
        # Accumulate on-device and read back once per epoch to avoid a
        # host/device sync on every batch.
        loss_sum = 0.0
        seen = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            loss_sum = loss_sum + loss.detach() * labels.size(0)
            seen += labels.size(0)
        if seen:
            print(f"Epoch {epoch + 1}/{EPOCHS} - loss: {float(loss_sum) / seen:.4f}")
    _save_state_dict_atomically(model.state_dict(), MODEL_PATH)
    print(f"Modèle sauvegardé : {MODEL_PATH}")
    model.eval()
    return model


def load_or_train_model():
    """Return a ready-to-use classifier in eval mode on ``device``.

    The saved checkpoint at MODEL_PATH is loaded when it exists and is usable.
    Otherwise a new model is fine-tuned from pretrained weights on the training
    set and its state_dict is saved to MODEL_PATH.
    """
    model = _load_checkpoint()
    if model is None:
        model = _train_new_model()
    return model
|
| |
|
| |
|
| |
|
| |
|
def _default_model():
    """Return the module-level ``model``, loading it once if it is not set yet."""
    global model
    if globals().get("model") is None:
        model = load_or_train_model()
    return model


def predict_image(image_bytes, model=None):
    """Classify one encoded image and return its predicted class name.

    Args:
        image_bytes: raw encoded image file contents (any format Pillow opens).
        model: optional classifier to use. When omitted, the module-level
            ``model`` is used. If that is not set, it is obtained once via
            ``load_or_train_model()`` and cached; that call trains a new model
            when no usable checkpoint exists.

    Returns:
        The entry of ``classes`` whose logit is highest.

    Raises:
        PIL.UnidentifiedImageError: if ``image_bytes`` is not a readable image.
    """
    # Decode first so bad input fails before any model loading happens.
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    # Reuse the validation pipeline so inference preprocessing matches it exactly.
    batch = transform_val(image).unsqueeze(0).to(device)
    net = model if model is not None else _default_model()
    with torch.no_grad():
        logits = net(batch)
    return classes[logits.argmax(dim=1).item()]
|
| |
|
| |
|
| |
|