ericssonish commited on
Commit
32bb98c
·
verified ·
1 Parent(s): 71120a0

Upload models.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models.py +106 -0
models.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import timm
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import datasets, transforms
5
+ from torch.utils.data import DataLoader
6
+ from PIL import Image
7
+ import io
8
+ import os
9
+
10
+ # --- CONFIGURATION ---
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ IMG_SIZE = 224
13
+ BATCH_SIZE = 16
14
+ EPOCHS = 1 # Pour test rapide
15
+ MODEL_PATH = "mpox_vit_model_local.pth"
16
+ DATASET_PATH = "Mpox2-1"
17
+
18
+ # --- TRANSFORMATIONS ---
19
+ transform_train = transforms.Compose([
20
+ transforms.Resize((IMG_SIZE, IMG_SIZE)),
21
+ transforms.RandomHorizontalFlip(),
22
+ transforms.RandomRotation(15),
23
+ transforms.ColorJitter(brightness=0.2, contrast=0.2),
24
+ transforms.ToTensor(),
25
+ transforms.Normalize([0.5]*3, [0.5]*3)
26
+ ])
27
+
28
+ transform_val = transforms.Compose([
29
+ transforms.Resize((IMG_SIZE, IMG_SIZE)),
30
+ transforms.ToTensor(),
31
+ transforms.Normalize([0.5]*3, [0.5]*3)
32
+ ])
33
+
34
+ # --- DATASET LOCAL ---
35
+ train_dataset = datasets.ImageFolder(os.path.join(DATASET_PATH, "train"), transform=transform_train)
36
+
37
+ if os.path.exists(os.path.join(DATASET_PATH, "valid")):
38
+ val_dataset = datasets.ImageFolder(os.path.join(DATASET_PATH, "valid"), transform=transform_val)
39
+ else:
40
+ val_dataset = None
41
+
42
+ train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
43
+ val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE) if val_dataset else None
44
+
45
+ classes = train_dataset.classes
46
+ num_classes = len(classes)
47
+
48
+ # --- MODELE ---
49
+ def create_model(pretrained=True):
50
+ model = timm.create_model('vit_base_patch16_224', pretrained=pretrained)
51
+ model.head = nn.Linear(model.head.in_features, num_classes)
52
+ return model
53
+
54
+ # --- CHARGEMENT DU MODELE ---
55
+ def load_or_train_model():
56
+ need_train = True
57
+ if os.path.exists(MODEL_PATH):
58
+ if os.path.getsize(MODEL_PATH) > 1000: # Vérifie que le fichier n'est pas vide
59
+ try:
60
+ model = create_model(pretrained=False)
61
+ model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
62
+ model.eval()
63
+ model.to(device)
64
+ print(f"Modèle chargé depuis {MODEL_PATH}")
65
+ need_train = False
66
+ return model
67
+ except Exception as e:
68
+ print(f"Erreur au chargement du modèle: {e}")
69
+
70
+ # Entraînement si le modèle n'existe pas ou est corrompu
71
+ print("Entraînement d'un nouveau modèle...")
72
+ model = create_model(pretrained=True)
73
+ model.to(device)
74
+ criterion = nn.CrossEntropyLoss()
75
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
76
+ model.train()
77
+ for epoch in range(EPOCHS):
78
+ for images, labels in train_loader:
79
+ images, labels = images.to(device), labels.to(device)
80
+ optimizer.zero_grad()
81
+ outputs = model(images)
82
+ loss = criterion(outputs, labels)
83
+ loss.backward()
84
+ optimizer.step()
85
+ torch.save(model.state_dict(), MODEL_PATH)
86
+ print(f"Modèle sauvegardé : {MODEL_PATH}")
87
+ model.eval()
88
+ return model
89
+
90
+ # model = load_or_train_model()
91
+
92
+ # --- PRÉDICTION ---
93
+ def predict_image(image_bytes):
94
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
95
+ transform = transforms.Compose([
96
+ transforms.Resize((IMG_SIZE, IMG_SIZE)),
97
+ transforms.ToTensor(),
98
+ transforms.Normalize([0.5]*3, [0.5]*3)
99
+ ])
100
+ x = transform(image).unsqueeze(0).to(device)
101
+ with torch.no_grad():
102
+ outputs = model(x)
103
+ _, predicted = torch.max(outputs, 1)
104
+ return classes[predicted.item()]
105
+
106
+