# -*- coding: utf-8 -*- """vision.ipynb Automatically generated by Colab. Original file is located at https://colab.research.google.com/drive/1JriMvbXyr0_2BXST58NUljv9sWWmgbHC """ import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import Dataset, DataLoader from torchvision import models, transforms from PIL import Image import os import numpy as np import time from tqdm import tqdm class Config: seed = 42 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") batch_size = 32 num_workers = 4 learning_rate = 1e-4 num_epochs = 10 num_classes = 2 img_size = 224 def seed_everything(seed): np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.backends.cudnn.deterministic = True seed_everything(Config.seed) print(f"Using device: {Config.device}") # Стандартные статистики ImageNet NORM_MEAN = [0.485, 0.456, 0.406] NORM_STD = [0.229, 0.224, 0.225] def get_transforms(phase='train'): if phase == 'train': return transforms.Compose([ transforms.Resize((256, 256)), # Сначала приводим к общему размеру transforms.RandomResizedCrop(Config.img_size), # Случайный кроп transforms.RandomHorizontalFlip(p=0.5), # Отражение transforms.RandomRotation(degrees=15), # Поворот transforms.ColorJitter(brightness=0.2, contrast=0.2), # Изменение цвета transforms.ToTensor(), transforms.Normalize(NORM_MEAN, NORM_STD) ]) else: return transforms.Compose([ transforms.Resize((256, 256)), transforms.CenterCrop(Config.img_size), transforms.ToTensor(), transforms.Normalize(NORM_MEAN, NORM_STD) ]) class CustomDataset(Dataset): def __init__(self, file_paths, labels, transform=None): self.file_paths = file_paths self.labels = labels self.transform = transform def __len__(self): return len(self.file_paths) def __getitem__(self, idx): img_path = self.file_paths[idx] image = Image.open(img_path).convert("RGB") label = self.labels[idx] if self.transform: image = self.transform(image) return image, torch.tensor(label, dtype=torch.long) def build_model(num_classes, pretrained=True): # 1. Загружаем предобученный ResNet18 model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None) # 2. (Опционально) Замораживаем веса бэкбона # Это нужно, если данных мало. Если данных много, можно обучать всё (fine-tuning). for param in model.parameters(): param.requires_grad = False # 3. Заменяем "голову" (полносвязный слой) # model.fc.in_features - это количество входов в оригинальном слое (512 для ResNet18) num_ftrs = model.fc.in_features model.fc = nn.Sequential( nn.Linear(num_ftrs, 256), nn.ReLU(), nn.Dropout(0.5), # Для предотвращения переобучения nn.Linear(256, num_classes) ) return model model = build_model(Config.num_classes).to(Config.device) def train_one_epoch(model, loader, criterion, optimizer, device): model.train() running_loss = 0.0 correct = 0 total = 0 loop = tqdm(loader, leave=True) # Прогресс-бар for images, labels in loop: images, labels = images.to(device), labels.to(device) optimizer.zero_grad() outputs = model(images) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() loop.set_description(f"Train Loss: {loss.item():.4f}") epoch_loss = running_loss / len(loader) epoch_acc = 100 * correct / total return epoch_loss, epoch_acc def validate(model, loader, criterion, device): model.eval() running_loss = 0.0 correct = 0 total = 0 with torch.no_grad(): for images, labels in loader: images, labels = images.to(device), labels.to(device) outputs = model(images) loss = criterion(outputs, labels) running_loss += loss.item() _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() epoch_loss = running_loss / len(loader) epoch_acc = 100 * correct / total return epoch_loss, epoch_acc import tempfile fake_data_len = 100 fake_paths = [tempfile.NamedTemporaryFile(suffix='.jpg').name for _ in range(fake_data_len)] for p in fake_paths: Image.new('RGB', (300, 300)).save(p) fake_labels = np.random.randint(0, 2, fake_data_len) # Инициализация датасетов train_dataset = CustomDataset(fake_paths, fake_labels, transform=get_transforms('train')) val_dataset = CustomDataset(fake_paths, fake_labels, transform=get_transforms('val')) # DataLoader'ы train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True, num_workers=0) # num_workers=0 для примера val_loader = DataLoader(val_dataset, batch_size=Config.batch_size, shuffle=False, num_workers=0) # Оптимизатор и Лосс # Обучаем только параметры fc (головы), если заморозили бэкбон. # Если не замораживали, передавайте model.parameters() optimizer = optim.Adam(model.fc.parameters(), lr=Config.learning_rate) criterion = nn.CrossEntropyLoss() # Основной цикл best_acc = 0.0 print("Start Training...") for epoch in range(Config.num_epochs): print(f"\nEpoch {epoch+1}/{Config.num_epochs}") train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, Config.device) val_loss, val_acc = validate(model, val_loader, criterion, Config.device) print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%") print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%") if val_acc > best_acc: best_acc = val_acc torch.save(model.state_dict(), "best_model.pth") print("Model Saved!") import albumentations as A from albumentations.pytorch import ToTensorV2 import cv2 import torch import torch.nn as nn from torchvision import models class AugmentationFactory: """Класс для создания пайплайна аугментаций""" def __init__(self, img_size=224): self.img_size = img_size # Mean и Std для ImageNet (стандарт для предобученных моделей) self.mean = (0.485, 0.456, 0.406) self.std = (0.229, 0.224, 0.225) def get_train_transforms(self): return A.Compose([ A.Resize(height=256, width=256), A.RandomCrop(height=self.img_size, width=self.img_size), # Геометрические аугментации A.HorizontalFlip(p=0.5), A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5), # Цветовые и шумовые аугментации (Albumentations тут очень силен) A.OneOf([ A.GaussNoise(var_limit=(10.0, 50.0)), A.GaussianBlur(), A.MotionBlur(), ], p=0.3), A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.3), # Обязательные шаги в конце A.Normalize(mean=self.mean, std=self.std), ToTensorV2() # Конвертирует в torch.Tensor (C, H, W) ]) def get_val_transforms(self): return A.Compose([ A.Resize(height=self.img_size, width=self.img_size), # Или Resize -> CenterCrop A.Normalize(mean=self.mean, std=self.std), ToTensorV2() ]) # Обновленный Dataset под Albumentations class Cv2Dataset(torch.utils.data.Dataset): def __init__(self, file_paths, labels, transforms=None): self.file_paths = file_paths self.labels = labels self.transforms = transforms def __len__(self): return len(self.file_paths) def __getitem__(self, idx): path = self.file_paths[idx] # 1. Читаем через OpenCV (BGR формат по умолчанию) image = cv2.imread(path) # 2. Конвертируем в RGB !!! Очень важно !!! image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # 3. Применяем аугментации if self.transforms: # Albumentations возвращает словарь augmented = self.transforms(image=image) image = augmented['image'] label = torch.tensor(self.labels[idx], dtype=torch.long) return image, label import torch.nn as nn class UniversalClassifier(nn.Module): def __init__(self, model_name, num_classes, pretrained=True, freeze_backbone=False): super().__init__() if model_name not in AVAILABLE_BACKBONES: raise ValueError(f"Model {model_name} not found.") full_model = AVAILABLE_BACKBONES[model_name](weights="DEFAULT" if pretrained else None) self.encoder = full_model if freeze_backbone: for param in self.encoder.parameters(): param.requires_grad = False self.head_layer_name = "" if "resnet" in model_name: self.emb_dim = self.encoder.fc.in_features self.encoder.fc = nn.Identity() elif "efficientnet" in model_name: self.emb_dim = self.encoder.classifier[-1].in_features self.encoder.classifier[-1] = nn.Identity() elif "vit" in model_name: self.emb_dim = self.encoder.heads.head.in_features self.encoder.heads.head = nn.Identity() self.head = nn.Sequential( nn.Dropout(p=0.3), nn.Linear(self.emb_dim, num_classes) ) def forward(self, x): features = self.encoder(x) output = self.head(features) return output def get_features(self, x): """Метод специально для получения только эмбеддингов""" return self.encoder(x) AVAILABLE_BACKBONES = { # Тяжелые и точные "resnet50": models.resnet50, "efficientnet_b0": models.efficientnet_b0, # Хороший баланс "efficientnet_b4": models.efficientnet_b4, # Мощнее # Легкие (для мобилок/быстрого инференса) "resnet18": models.resnet18, "mobilenet_v3_large": models.mobilenet_v3_large, # Современные (Transformers) "vit_b_16": models.vit_b_16, # Требует img_size=224 } """# Пример""" # --- КОНФИГУРАЦИЯ --- class Config: model_name = "efficientnet_b0" num_classes = 2 img_size = 224 # EfficientNet_B0 любит 224, B4 любит 380 batch_size = 32 device = "cuda" if torch.cuda.is_available() else "cpu" # 1. Аугментации aug_factory = AugmentationFactory(img_size=Config.img_size) train_transforms = aug_factory.get_train_transforms() # 2. Создание датасета (пример путей) # train_dataset = Cv2Dataset(train_paths, train_labels, transforms=train_transforms) # train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True) # 3. Инициализация модели model = UniversalClassifier( model_name=Config.model_name, num_classes=Config.num_classes, pretrained=True, freeze_backbone=False ).to(Config.device) print(f"Model {Config.model_name} initialized successfully.") dummy_input = torch.randn(2, 3, Config.img_size, Config.img_size).to(Config.device) output = model(dummy_input) print(f"Output shape: {output.shape}") """# Достать эмбединг""" model = UniversalClassifier("resnet18", num_classes=2).to(Config.device) def get_embeddings_clean(model, loader, device): model.eval() embeddings_list = [] with torch.no_grad(): for images, _ in tqdm(loader): images = images.to(device) features = model.get_features(images) embeddings_list.append(features.cpu().numpy()) return np.vstack(embeddings_list) embs = get_embeddings_clean(model, val_loader, Config.device) embs