Funmagster committed on
Commit
52b24ff
·
verified ·
1 Parent(s): 4d40330

Add vision.py

Browse files
Files changed (1) hide show
  1. vision.py +376 -0
vision.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """vision.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1JriMvbXyr0_2BXST58NUljv9sWWmgbHC
8
+ """
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.optim as optim
13
+ from torch.utils.data import Dataset, DataLoader
14
+ from torchvision import models, transforms
15
+ from PIL import Image
16
+ import os
17
+ import numpy as np
18
+ import time
19
+ from tqdm import tqdm
20
+
21
class Config:
    """Hyperparameters and runtime settings for the ResNet18 training demo."""

    seed = 42  # handed to seed_everything() right after this class
    # Use the GPU when CUDA is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = 32
    # NOTE(review): the DataLoaders built later in this file pass num_workers=0
    # explicitly, so this value is currently not read anywhere visible.
    num_workers = 4
    learning_rate = 1e-4
    num_epochs = 10
    num_classes = 2
    img_size = 224  # side length of the crops produced by get_transforms()
30
+
31
def seed_everything(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA
    devices) and puts cuDNN into deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    # Local import: the file's import block does not bring `random` in.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every visible GPU, not just the current one; silently ignored
    # when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may select different kernels from run to run.
    torch.backends.cudnn.benchmark = False
36
+
37
# Fix the RNG state before anything random happens, then report the device.
seed_everything(Config.seed)
print(f"Using device: {Config.device}")
39
+
40
# Standard ImageNet per-channel statistics, passed to transforms.Normalize
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]
43
+
44
def get_transforms(phase='train'):
    """Build the torchvision preprocessing pipeline for one phase.

    Args:
        phase: ``'train'`` selects the augmenting pipeline; any other value
            selects the deterministic evaluation pipeline.

    Returns:
        A ``transforms.Compose`` ending in ``ToTensor`` followed by
        normalisation with ``NORM_MEAN`` / ``NORM_STD``.
    """
    # Both pipelines start by bringing the image to a common size and end
    # with the same tensor conversion + normalisation.
    resize = transforms.Resize((256, 256))
    tensor_steps = [
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ]

    if phase == 'train':
        augmentations = [
            transforms.RandomResizedCrop(Config.img_size),  # random crop
            transforms.RandomHorizontalFlip(p=0.5),  # mirror
            transforms.RandomRotation(degrees=15),  # rotation
            transforms.ColorJitter(brightness=0.2, contrast=0.2),  # colour change
        ]
        return transforms.Compose([resize, *augmentations, *tensor_steps])

    return transforms.Compose([
        resize,
        transforms.CenterCrop(Config.img_size),
        *tensor_steps,
    ])
62
+
63
class CustomDataset(Dataset):
    """Map-style dataset that loads images with PIL and pairs them with labels.

    Args:
        file_paths: Indexable collection of image paths.
        labels: Indexable collection of labels, indexed in step with
            ``file_paths``.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of image paths."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the label is a ``torch.long`` tensor."""
        # The context manager closes the underlying file deterministically
        # instead of leaving that to garbage collection.
        with Image.open(self.file_paths[idx]) as raw:
            image = raw.convert("RGB")
        label = self.labels[idx]

        # Explicit None check: a transform object must not be skipped just
        # because it happens to evaluate as falsy.
        if self.transform is not None:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.long)
81
+
82
def build_model(num_classes, pretrained=True, freeze_backbone=True):
    """Create a ResNet18 whose classifier is replaced by a small MLP head.

    Args:
        num_classes: Number of output logits.
        pretrained: Load the ``IMAGENET1K_V1`` weights when true, otherwise
            start from random initialisation.
        freeze_backbone: Disable gradients for all original ResNet18
            parameters so that only the new head trains (useful with little
            data). Pass ``False`` to fine-tune the whole network. Defaults to
            ``True``, which is what this function always did.

    Returns:
        The modified torchvision ResNet18.
    """
    # 1. Load ResNet18, optionally with pretrained weights.
    weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
    model = models.resnet18(weights=weights)

    # 2. Optionally freeze the backbone weights.
    if freeze_backbone:
        for param in model.parameters():
            param.requires_grad = False

    # 3. Replace the fully connected head. It is created after the freeze
    #    loop, so its parameters keep requires_grad=True.
    #    model.fc.in_features is the input width of the original layer
    #    (512 for ResNet18).
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Linear(num_ftrs, 256),
        nn.ReLU(),
        nn.Dropout(0.5),  # guards against overfitting
        nn.Linear(256, num_classes),
    )

    return model


model = build_model(Config.num_classes).to(Config.device)
105
+
106
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: Network to optimise; switched to train mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss callable. It is assumed to return the mean loss of
            the batch (the default reduction of ``nn.CrossEntropyLoss``).
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: mean loss per sample and accuracy in
        percent. ``(0.0, 0.0)`` when the loader yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    loop = tqdm(loader, leave=True)  # progress bar

    for images, labels in loop:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        batch_loss = loss.item()
        # Weight each batch mean by its size so a shorter final batch does
        # not count as much as a full one.
        loss_sum += batch_loss * batch_size
        total += batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()

        loop.set_description(f"Train Loss: {batch_loss:.4f}")

    if total == 0:
        # Nothing was processed; avoid dividing by zero.
        return 0.0, 0.0

    epoch_loss = loss_sum / total
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc
134
+
135
def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss callable. It is assumed to return the mean loss of
            the batch (the default reduction of ``nn.CrossEntropyLoss``).
        device: Device the batches are moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: mean loss per sample and accuracy in
        percent. ``(0.0, 0.0)`` when the loader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # Weight each batch mean by its size so a shorter final batch
            # does not count as much as a full one.
            loss_sum += loss.item() * batch_size
            total += batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()

    if total == 0:
        # Nothing was processed; avoid dividing by zero.
        return 0.0, 0.0

    epoch_loss = loss_sum / total
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc
156
+
157
import atexit
import shutil
import tempfile

# Synthetic demo data: 100 blank 300x300 JPEG files with random binary labels.
fake_data_len = 100
# Create a private directory and place the images inside it.
# NamedTemporaryFile(...).name only yields a name: the file itself is removed
# as soon as the temporary object is garbage-collected, so the name is not
# reserved and nothing ever cleans up what is later saved under it.
fake_dir = tempfile.mkdtemp(prefix="vision_fake_")
# The loaders below are still used at the end of the file, so the images are
# removed at interpreter exit rather than right after training.
atexit.register(shutil.rmtree, fake_dir, ignore_errors=True)
fake_paths = [os.path.join(fake_dir, f"img_{i}.jpg") for i in range(fake_data_len)]
for p in fake_paths:
    Image.new('RGB', (300, 300)).save(p)
fake_labels = np.random.randint(0, 2, fake_data_len)

# Dataset initialisation (the demo reuses the same files for train and val)
train_dataset = CustomDataset(fake_paths, fake_labels, transform=get_transforms('train'))
val_dataset = CustomDataset(fake_paths, fake_labels, transform=get_transforms('val'))

# DataLoaders (num_workers=0 for the example)
train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=Config.batch_size, shuffle=False, num_workers=0)

# Optimizer and loss.
# Only the parameters of the fc head are trained because the backbone is frozen.
# If the backbone is not frozen, pass model.parameters() instead.
optimizer = optim.Adam(model.fc.parameters(), lr=Config.learning_rate)
criterion = nn.CrossEntropyLoss()

# Main loop: keep the checkpoint with the best validation accuracy.
best_acc = 0.0

print("Start Training...")
for epoch in range(Config.num_epochs):
    print(f"\nEpoch {epoch+1}/{Config.num_epochs}")

    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, Config.device)
    val_loss, val_acc = validate(model, val_loader, criterion, Config.device)

    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), "best_model.pth")
        print("Model Saved!")
195
+
196
+ import albumentations as A
197
+ from albumentations.pytorch import ToTensorV2
198
+ import cv2
199
+ import torch
200
+ import torch.nn as nn
201
+ from torchvision import models
202
+
203
class AugmentationFactory:
    """Factory for Albumentations training and validation pipelines.

    Args:
        img_size: Side length of the square image fed to the model.
    """

    def __init__(self, img_size=224):
        self.img_size = img_size

        # Size the image is resized to before the random crop. It stays at
        # 256 for every img_size up to 256; for larger targets (e.g. 380) it
        # grows by the same 256/224 ratio, because RandomCrop cannot cut a
        # window larger than the image it is given.
        self._resize_size = 256 if img_size <= 256 else round(img_size * 256 / 224)

        # ImageNet mean and std (the standard for pretrained models)
        self.mean = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

    def get_train_transforms(self):
        """Return the augmenting pipeline: resize, random crop, geometric and
        colour/noise augmentations, normalisation, tensor conversion."""
        return A.Compose([
            A.Resize(height=self._resize_size, width=self._resize_size),
            A.RandomCrop(height=self.img_size, width=self.img_size),

            # Geometric augmentations
            A.HorizontalFlip(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),

            # Colour and noise augmentations; exactly one of the three noise /
            # blur transforms is picked when the OneOf group fires.
            # NOTE(review): newer Albumentations releases appear to deprecate
            # GaussNoise(var_limit=...) and ShiftScaleRotate - verify against
            # the installed version.
            A.OneOf([
                A.GaussNoise(var_limit=(10.0, 50.0)),
                A.GaussianBlur(),
                A.MotionBlur(),
            ], p=0.3),
            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.3),

            # Mandatory final steps
            A.Normalize(mean=self.mean, std=self.std),
            ToTensorV2(),  # converts to a torch.Tensor (C, H, W)
        ])

    def get_val_transforms(self):
        """Return the deterministic pipeline: resize, normalise, to tensor."""
        return A.Compose([
            A.Resize(height=self.img_size, width=self.img_size),  # or Resize -> CenterCrop
            A.Normalize(mean=self.mean, std=self.std),
            ToTensorV2(),
        ])
240
+
241
# Dataset variant for Albumentations pipelines
class Cv2Dataset(torch.utils.data.Dataset):
    """Map-style dataset that reads images with OpenCV.

    Args:
        file_paths: Indexable collection of image paths.
        labels: Indexable collection of labels, indexed in step with
            ``file_paths``.
        transforms: Optional Albumentations-style callable, invoked as
            ``transforms(image=...)`` and expected to return a mapping with
            an ``'image'`` entry.
    """

    def __init__(self, file_paths, labels, transforms=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transforms = transforms

    def __len__(self):
        """Return the number of image paths."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the label is a ``torch.long`` tensor.

        Raises:
            OSError: If OpenCV cannot read the image at the stored path.
        """
        path = self.file_paths[idx]

        # 1. Read through OpenCV (BGR channel order by default). imread
        #    signals failure by returning None rather than raising, so fail
        #    here with the offending path instead of inside cvtColor.
        image = cv2.imread(path)
        if image is None:
            raise OSError(f"cv2.imread could not read image: {path}")
        # 2. Convert to RGB - very important for models pretrained on RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # 3. Apply the augmentations (Albumentations returns a dictionary).
        if self.transforms is not None:
            image = self.transforms(image=image)['image']

        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return image, label
267
+
268
+ import torch.nn as nn
269
+
270
class UniversalClassifier(nn.Module):
    """Backbone from ``AVAILABLE_BACKBONES`` with a fresh linear classifier.

    The backbone's own classification layer is replaced by ``nn.Identity`` so
    the encoder emits embeddings; a Dropout + Linear head maps them to logits.

    Args:
        model_name: Key of ``AVAILABLE_BACKBONES``.
        num_classes: Number of output logits.
        pretrained: Request ``weights="DEFAULT"`` when true, else no weights.
        freeze_backbone: Disable gradients for the encoder parameters.

    Raises:
        ValueError: If ``model_name`` is not registered, or is registered but
            has no rule for replacing its classification layer.
    """

    def __init__(self, model_name, num_classes, pretrained=True, freeze_backbone=False):
        super().__init__()

        if model_name not in AVAILABLE_BACKBONES:
            raise ValueError(f"Model {model_name} not found.")

        self.encoder = AVAILABLE_BACKBONES[model_name](weights="DEFAULT" if pretrained else None)

        if freeze_backbone:
            for param in self.encoder.parameters():
                param.requires_grad = False

        # NOTE(review): never populated or read in this file; kept so the
        # attribute stays available to any external code that expects it.
        self.head_layer_name = ""

        if "resnet" in model_name:
            self.emb_dim = self.encoder.fc.in_features
            self.encoder.fc = nn.Identity()

        elif "efficientnet" in model_name or "mobilenet" in model_name:
            # Both families end in a `classifier` Sequential whose last
            # module is the Linear layer producing the ImageNet logits.
            self.emb_dim = self.encoder.classifier[-1].in_features
            self.encoder.classifier[-1] = nn.Identity()

        elif "vit" in model_name:
            self.emb_dim = self.encoder.heads.head.in_features
            self.encoder.heads.head = nn.Identity()

        else:
            # Fail with a clear message instead of an AttributeError on
            # self.emb_dim further down.
            raise ValueError(f"No head replacement rule for model {model_name}.")

        self.head = nn.Sequential(
            nn.Dropout(p=0.3),
            nn.Linear(self.emb_dim, num_classes),
        )

    def forward(self, x):
        """Return class logits for the batch ``x``."""
        return self.head(self.encoder(x))

    def get_features(self, x):
        """Return only the encoder embeddings, skipping the head."""
        return self.encoder(x)
311
+
312
# Registry mapping a model name to the torchvision constructor that builds it.
AVAILABLE_BACKBONES = {
    # Heavy and accurate
    "resnet50": models.resnet50,
    "efficientnet_b0": models.efficientnet_b0,  # good balance
    "efficientnet_b4": models.efficientnet_b4,  # more powerful

    # Lightweight (for mobile / fast inference)
    "resnet18": models.resnet18,
    "mobilenet_v3_large": models.mobilenet_v3_large,

    # Modern (Transformers)
    "vit_b_16": models.vit_b_16,  # requires img_size=224
}
325
+
326
# Notebook section heading exported as a bare string ("Example").
"""# Пример"""

# --- CONFIGURATION ---
class Config:
    """Settings for the UniversalClassifier example.

    Rebinds the name ``Config`` used by the ResNet18 demo earlier in the file;
    note that ``device`` is a plain string here, not a ``torch.device``.
    """

    model_name = "efficientnet_b0"
    num_classes = 2
    img_size = 224  # author's note: EfficientNet-B0 favours 224, B4 favours 380
    batch_size = 32
    device = "cuda" if torch.cuda.is_available() else "cpu"
335
+
336
# 1. Augmentations
aug_factory = AugmentationFactory(img_size=Config.img_size)
train_transforms = aug_factory.get_train_transforms()

# 2. Dataset creation (example paths)
# train_dataset = Cv2Dataset(train_paths, train_labels, transforms=train_transforms)
# train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True)

# 3. Model initialisation
model = UniversalClassifier(
    model_name=Config.model_name,
    num_classes=Config.num_classes,
    pretrained=True,
    freeze_backbone=False
).to(Config.device)

print(f"Model {Config.model_name} initialized successfully.")
# Shape smoke test. Run it in eval mode and without autograd so the random
# batch does not update the BatchNorm running statistics of the pretrained
# encoder and is not perturbed by Dropout.
dummy_input = torch.randn(2, 3, Config.img_size, Config.img_size).to(Config.device)
model.eval()
with torch.no_grad():
    output = model(dummy_input)
model.train()  # back to the mode a freshly constructed module is in
print(f"Output shape: {output.shape}")
356
+
357
# Notebook section heading exported as a bare string ("Extract embeddings").
"""# Достать эмбединг"""

model = UniversalClassifier("resnet18", num_classes=2).to(Config.device)


def get_embeddings_clean(model, loader, device):
    """Collect encoder embeddings for every batch of ``loader``.

    Args:
        model: Object exposing ``eval()`` and ``get_features(images)``.
        loader: Iterable of ``(images, labels)`` batches; labels are ignored.
        device: Device the images are moved to.

    Returns:
        A NumPy array with the per-batch features stacked along axis 0. For a
        loader that yields nothing, an empty float32 array of shape
        ``(0, emb_dim)`` (``emb_dim`` read from the model, 0 if absent).
    """
    model.eval()
    embeddings_list = []

    with torch.no_grad():
        for images, _ in tqdm(loader):
            images = images.to(device)
            features = model.get_features(images)
            embeddings_list.append(features.cpu().numpy())

    if not embeddings_list:
        # np.vstack raises ValueError when given no arrays.
        return np.empty((0, getattr(model, "emb_dim", 0)), dtype=np.float32)

    return np.vstack(embeddings_list)


embs = get_embeddings_clean(model, val_loader, Config.device)

# Bare expression kept from the notebook, where it displays the array; it has
# no effect when the file runs as a script.
embs
376
+