vision_cv / vision.py
Funmagster's picture
Add vision.py
52b24ff verified
# -*- coding: utf-8 -*-
"""vision.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1JriMvbXyr0_2BXST58NUljv9sWWmgbHC
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
import os
import numpy as np
import time
from tqdm import tqdm
class Config:
seed = 42
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 32
num_workers = 4
learning_rate = 1e-4
num_epochs = 10
num_classes = 2
img_size = 224
def seed_everything(seed):
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
seed_everything(Config.seed)
print(f"Using device: {Config.device}")
# Стандартные статистики ImageNet
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]
def get_transforms(phase='train'):
if phase == 'train':
return transforms.Compose([
transforms.Resize((256, 256)), # Сначала приводим к общему размеру
transforms.RandomResizedCrop(Config.img_size), # Случайный кроп
transforms.RandomHorizontalFlip(p=0.5), # Отражение
transforms.RandomRotation(degrees=15), # Поворот
transforms.ColorJitter(brightness=0.2, contrast=0.2), # Изменение цвета
transforms.ToTensor(),
transforms.Normalize(NORM_MEAN, NORM_STD)
])
else:
return transforms.Compose([
transforms.Resize((256, 256)),
transforms.CenterCrop(Config.img_size),
transforms.ToTensor(),
transforms.Normalize(NORM_MEAN, NORM_STD)
])
class CustomDataset(Dataset):
def __init__(self, file_paths, labels, transform=None):
self.file_paths = file_paths
self.labels = labels
self.transform = transform
def __len__(self):
return len(self.file_paths)
def __getitem__(self, idx):
img_path = self.file_paths[idx]
image = Image.open(img_path).convert("RGB")
label = self.labels[idx]
if self.transform:
image = self.transform(image)
return image, torch.tensor(label, dtype=torch.long)
def build_model(num_classes, pretrained=True):
# 1. Загружаем предобученный ResNet18
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)
# 2. (Опционально) Замораживаем веса бэкбона
# Это нужно, если данных мало. Если данных много, можно обучать всё (fine-tuning).
for param in model.parameters():
param.requires_grad = False
# 3. Заменяем "голову" (полносвязный слой)
# model.fc.in_features - это количество входов в оригинальном слое (512 для ResNet18)
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(
nn.Linear(num_ftrs, 256),
nn.ReLU(),
nn.Dropout(0.5), # Для предотвращения переобучения
nn.Linear(256, num_classes)
)
return model
model = build_model(Config.num_classes).to(Config.device)
def train_one_epoch(model, loader, criterion, optimizer, device):
model.train()
running_loss = 0.0
correct = 0
total = 0
loop = tqdm(loader, leave=True) # Прогресс-бар
for images, labels in loop:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
loop.set_description(f"Train Loss: {loss.item():.4f}")
epoch_loss = running_loss / len(loader)
epoch_acc = 100 * correct / total
return epoch_loss, epoch_acc
def validate(model, loader, criterion, device):
model.eval()
running_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
for images, labels in loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
running_loss += loss.item()
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
epoch_loss = running_loss / len(loader)
epoch_acc = 100 * correct / total
return epoch_loss, epoch_acc
import tempfile
fake_data_len = 100
fake_paths = [tempfile.NamedTemporaryFile(suffix='.jpg').name for _ in range(fake_data_len)]
for p in fake_paths:
Image.new('RGB', (300, 300)).save(p)
fake_labels = np.random.randint(0, 2, fake_data_len)
# Инициализация датасетов
train_dataset = CustomDataset(fake_paths, fake_labels, transform=get_transforms('train'))
val_dataset = CustomDataset(fake_paths, fake_labels, transform=get_transforms('val'))
# DataLoader'ы
train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True, num_workers=0) # num_workers=0 для примера
val_loader = DataLoader(val_dataset, batch_size=Config.batch_size, shuffle=False, num_workers=0)
# Оптимизатор и Лосс
# Обучаем только параметры fc (головы), если заморозили бэкбон.
# Если не замораживали, передавайте model.parameters()
optimizer = optim.Adam(model.fc.parameters(), lr=Config.learning_rate)
criterion = nn.CrossEntropyLoss()
# Основной цикл
best_acc = 0.0
print("Start Training...")
for epoch in range(Config.num_epochs):
print(f"\nEpoch {epoch+1}/{Config.num_epochs}")
train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, Config.device)
val_loss, val_acc = validate(model, val_loader, criterion, Config.device)
print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
if val_acc > best_acc:
best_acc = val_acc
torch.save(model.state_dict(), "best_model.pth")
print("Model Saved!")
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import torch
import torch.nn as nn
from torchvision import models
class AugmentationFactory:
"""Класс для создания пайплайна аугментаций"""
def __init__(self, img_size=224):
self.img_size = img_size
# Mean и Std для ImageNet (стандарт для предобученных моделей)
self.mean = (0.485, 0.456, 0.406)
self.std = (0.229, 0.224, 0.225)
def get_train_transforms(self):
return A.Compose([
A.Resize(height=256, width=256),
A.RandomCrop(height=self.img_size, width=self.img_size),
# Геометрические аугментации
A.HorizontalFlip(p=0.5),
A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
# Цветовые и шумовые аугментации (Albumentations тут очень силен)
A.OneOf([
A.GaussNoise(var_limit=(10.0, 50.0)),
A.GaussianBlur(),
A.MotionBlur(),
], p=0.3),
A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.3),
# Обязательные шаги в конце
A.Normalize(mean=self.mean, std=self.std),
ToTensorV2() # Конвертирует в torch.Tensor (C, H, W)
])
def get_val_transforms(self):
return A.Compose([
A.Resize(height=self.img_size, width=self.img_size), # Или Resize -> CenterCrop
A.Normalize(mean=self.mean, std=self.std),
ToTensorV2()
])
# Обновленный Dataset под Albumentations
class Cv2Dataset(torch.utils.data.Dataset):
def __init__(self, file_paths, labels, transforms=None):
self.file_paths = file_paths
self.labels = labels
self.transforms = transforms
def __len__(self):
return len(self.file_paths)
def __getitem__(self, idx):
path = self.file_paths[idx]
# 1. Читаем через OpenCV (BGR формат по умолчанию)
image = cv2.imread(path)
# 2. Конвертируем в RGB !!! Очень важно !!!
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# 3. Применяем аугментации
if self.transforms:
# Albumentations возвращает словарь
augmented = self.transforms(image=image)
image = augmented['image']
label = torch.tensor(self.labels[idx], dtype=torch.long)
return image, label
import torch.nn as nn
class UniversalClassifier(nn.Module):
def __init__(self, model_name, num_classes, pretrained=True, freeze_backbone=False):
super().__init__()
if model_name not in AVAILABLE_BACKBONES:
raise ValueError(f"Model {model_name} not found.")
full_model = AVAILABLE_BACKBONES[model_name](weights="DEFAULT" if pretrained else None)
self.encoder = full_model
if freeze_backbone:
for param in self.encoder.parameters():
param.requires_grad = False
self.head_layer_name = ""
if "resnet" in model_name:
self.emb_dim = self.encoder.fc.in_features
self.encoder.fc = nn.Identity()
elif "efficientnet" in model_name:
self.emb_dim = self.encoder.classifier[-1].in_features
self.encoder.classifier[-1] = nn.Identity()
elif "vit" in model_name:
self.emb_dim = self.encoder.heads.head.in_features
self.encoder.heads.head = nn.Identity()
self.head = nn.Sequential(
nn.Dropout(p=0.3),
nn.Linear(self.emb_dim, num_classes)
)
def forward(self, x):
features = self.encoder(x)
output = self.head(features)
return output
def get_features(self, x):
"""Метод специально для получения только эмбеддингов"""
return self.encoder(x)
AVAILABLE_BACKBONES = {
# Тяжелые и точные
"resnet50": models.resnet50,
"efficientnet_b0": models.efficientnet_b0, # Хороший баланс
"efficientnet_b4": models.efficientnet_b4, # Мощнее
# Легкие (для мобилок/быстрого инференса)
"resnet18": models.resnet18,
"mobilenet_v3_large": models.mobilenet_v3_large,
# Современные (Transformers)
"vit_b_16": models.vit_b_16, # Требует img_size=224
}
"""# Пример"""
# --- КОНФИГУРАЦИЯ ---
class Config:
model_name = "efficientnet_b0"
num_classes = 2
img_size = 224 # EfficientNet_B0 любит 224, B4 любит 380
batch_size = 32
device = "cuda" if torch.cuda.is_available() else "cpu"
# 1. Аугментации
aug_factory = AugmentationFactory(img_size=Config.img_size)
train_transforms = aug_factory.get_train_transforms()
# 2. Создание датасета (пример путей)
# train_dataset = Cv2Dataset(train_paths, train_labels, transforms=train_transforms)
# train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True)
# 3. Инициализация модели
model = UniversalClassifier(
model_name=Config.model_name,
num_classes=Config.num_classes,
pretrained=True,
freeze_backbone=False
).to(Config.device)
print(f"Model {Config.model_name} initialized successfully.")
dummy_input = torch.randn(2, 3, Config.img_size, Config.img_size).to(Config.device)
output = model(dummy_input)
print(f"Output shape: {output.shape}")
"""# Достать эмбединг"""
model = UniversalClassifier("resnet18", num_classes=2).to(Config.device)
def get_embeddings_clean(model, loader, device):
model.eval()
embeddings_list = []
with torch.no_grad():
for images, _ in tqdm(loader):
images = images.to(device)
features = model.get_features(images)
embeddings_list.append(features.cpu().numpy())
return np.vstack(embeddings_list)
embs = get_embeddings_clean(model, val_loader, Config.device)
embs