import json
from pathlib import Path

import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
import wandb
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import ViTFeatureExtractor, ViTForImageClassification
class PlantDiseaseDataset(Dataset):
    """Map-style dataset that pairs image file paths with class labels.

    Args:
        image_paths: Sequence of paths to image files.
        labels: Sequence of labels, index-aligned with ``image_paths``.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per image path)."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is decoded as RGB; when a transform was supplied, its
        output is returned in place of the PIL image.
        """
        image_path = self.image_paths[idx]
        # The context manager releases the file handle as soon as the pixels
        # have been decoded, instead of leaving that to garbage collection.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        label = self.labels[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label
class PlantDiseaseClassifier:
    """Fine-tune a timm image classifier on a folder-per-class image dataset.

    The dataset is expected to be laid out as ``data_dir/<class_name>/<image>``.
    Typical use is ``prepare_data()`` -> ``create_model()`` -> ``train()`` ->
    ``save_for_huggingface()``. Metrics are logged to Weights & Biases; the
    run is started in ``__init__`` and finished at the end of ``train``.
    """

    # Lower-cased file suffixes treated as images when scanning class folders.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg')

    def __init__(self, data_dir, model_name='vit_base_patch16_224', num_classes=38):
        """Store the configuration, pick the device and start a wandb run.

        Args:
            data_dir: Root directory containing one sub-directory per class.
            model_name: Architecture name passed to ``timm.create_model``.
            num_classes: Number of outputs of the classification head.
        """
        self.data_dir = Path(data_dir)
        self.model_name = model_name
        self.num_classes = num_classes
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Initialize wandb
        wandb.init(project="plant-disease-classification")

    @staticmethod
    def _safe_ratio(numerator, denominator):
        """Return ``numerator / denominator``, or ``0.0`` for a zero denominator."""
        return numerator / denominator if denominator else 0.0

    def _collect_samples(self):
        """Scan ``data_dir`` and return ``(image_paths, labels, class_to_idx)``.

        Only sub-directories count as classes, so stray files in ``data_dir``
        cannot shift or leave gaps in the class indices. Image suffixes are
        matched case-insensitively against ``IMAGE_EXTENSIONS`` and files are
        visited in sorted order so the seeded split is reproducible.
        """
        class_dirs = sorted(p for p in self.data_dir.glob('*') if p.is_dir())
        class_to_idx = {d.name: i for i, d in enumerate(class_dirs)}
        image_paths = []
        labels = []
        for class_dir in class_dirs:
            label = class_to_idx[class_dir.name]
            for img_path in sorted(class_dir.glob('*')):
                if img_path.is_file() and img_path.suffix.lower() in self.IMAGE_EXTENSIONS:
                    image_paths.append(str(img_path))
                    labels.append(label)
        return image_paths, labels, class_to_idx

    def prepare_data(self, batch_size=32, num_workers=4):
        """Prepare dataset and create data loaders.

        Args:
            batch_size: Batch size used by both loaders.
            num_workers: Worker process count used by both loaders.

        Returns:
            Tuple ``(train_loader, val_loader)``; both are also stored on the
            instance, together with ``class_to_idx``.

        Raises:
            ValueError: If no image files are found under ``data_dir``.
        """
        # Data augmentation and normalization for training
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(20),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        # Just normalization for validation/testing
        val_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        # Collect all image paths and labels
        image_paths, labels, self.class_to_idx = self._collect_samples()
        if not image_paths:
            raise ValueError(
                f'No images with extensions {self.IMAGE_EXTENSIONS} found under '
                f'{self.data_dir}'
            )

        # Split data (stratified, fixed seed)
        train_paths, val_paths, train_labels, val_labels = train_test_split(
            image_paths, labels, test_size=0.2, stratify=labels, random_state=42
        )

        # Create datasets
        train_dataset = PlantDiseaseDataset(train_paths, train_labels, train_transform)
        val_dataset = PlantDiseaseDataset(val_paths, val_labels, val_transform)

        # Create data loaders
        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
        )
        self.val_loader = DataLoader(
            val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
        )
        return self.train_loader, self.val_loader

    def create_model(self, learning_rate=2e-5, weight_decay=0.01, t_max=10):
        """Create the pretrained model plus its loss, optimizer and scheduler.

        Args:
            learning_rate: AdamW learning rate.
            weight_decay: AdamW weight decay.
            t_max: ``T_max`` of the cosine annealing schedule; the scheduler
                is stepped once per epoch in ``train``.

        Returns:
            The model, already moved to ``self.device``.
        """
        self.model = timm.create_model(
            self.model_name,
            pretrained=True,
            num_classes=self.num_classes,
        )
        self.model = self.model.to(self.device)

        # Loss function and optimizer
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
        )
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=t_max,
        )
        return self.model

    def train_epoch(self, epoch):
        """Train for one epoch.

        Args:
            epoch: Epoch number, used only for the progress-bar label.

        Returns:
            Tuple ``(mean batch loss, accuracy in percent)``; both are ``0.0``
            when the training loader yields no batches.
        """
        self.model.train()
        total_loss = 0.0
        correct = 0
        total = 0

        progress_bar = tqdm(self.train_loader, desc=f'Epoch {epoch}')
        for batch_idx, (inputs, targets) in enumerate(progress_bar):
            inputs, targets = inputs.to(self.device), targets.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)
            loss.backward()
            self.optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            running_acc = 100. * correct / total

            progress_bar.set_postfix({
                'Loss': total_loss / (batch_idx + 1),
                'Acc': running_acc,
            })

            # Log to wandb: loss of this batch, accuracy so far in the epoch
            wandb.log({
                'train_loss': batch_loss,
                'train_acc': running_acc,
            })

        mean_loss = self._safe_ratio(total_loss, len(self.train_loader))
        return mean_loss, self._safe_ratio(100. * correct, total)

    def validate(self):
        """Evaluate on the validation loader and log the result to wandb.

        Returns:
            Tuple ``(mean batch loss, accuracy in percent)``; both are ``0.0``
            when the validation loader yields no batches.
        """
        self.model.eval()
        total_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, targets in tqdm(self.val_loader, desc='Validating'):
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)

                total_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        accuracy = self._safe_ratio(100. * correct, total)
        avg_loss = self._safe_ratio(total_loss, len(self.val_loader))

        # Log to wandb
        wandb.log({
            'val_loss': avg_loss,
            'val_acc': accuracy,
        })
        return avg_loss, accuracy

    def train(self, epochs=10):
        """Run the full training loop and keep the best checkpoint.

        After every epoch the model is validated and the scheduler stepped.
        Whenever validation accuracy improves, model weights, optimizer state
        and ``class_to_idx`` are written to ``best_model.pth``. The wandb run
        is finished once the loop completes.

        Args:
            epochs: Number of epochs to train.

        Raises:
            ValueError: If more classes were discovered than the model head
                has outputs.
        """
        # Checked before any training or logging so a misconfiguration fails
        # with a clear message rather than inside the loss computation.
        n_found = len(getattr(self, 'class_to_idx', {}))
        if n_found > self.num_classes:
            raise ValueError(
                f'Found {n_found} classes in {self.data_dir} but the model was '
                f'configured with num_classes={self.num_classes}'
            )

        # Start below any reachable accuracy so the first epoch always
        # produces a checkpoint, even when validation accuracy is 0.
        best_acc = float('-inf')
        for epoch in range(epochs):
            train_loss, train_acc = self.train_epoch(epoch)
            val_loss, val_acc = self.validate()
            self.scheduler.step()

            print(f'\nEpoch {epoch}:')
            print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')
            print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')

            # Save best model
            if val_acc > best_acc:
                best_acc = val_acc
                torch.save({
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'class_to_idx': self.class_to_idx,
                }, 'best_model.pth')

        wandb.finish()

    def save_for_huggingface(self):
        """Export the best checkpoint to ``plant_disease_model/``.

        Loads ``best_model.pth`` into ``self.model``, writes the model to the
        ``plant_disease_model`` directory and writes the index-to-class
        mapping to ``class_mapping.json``.
        """
        # Load best model; map_location lets a GPU-trained checkpoint be
        # exported on a CPU-only machine.
        checkpoint = torch.load('best_model.pth', map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])

        # Prefer the mapping stored with the weights; fall back to the one
        # built by prepare_data() for checkpoints that lack it.
        class_to_idx = checkpoint.get('class_to_idx')
        if class_to_idx is None:
            class_to_idx = self.class_to_idx
        idx_to_class = {v: k for k, v in class_to_idx.items()}

        # Save model and config
        out_dir = Path('plant_disease_model')
        save_pretrained = getattr(self.model, 'save_pretrained', None)
        if callable(save_pretrained):
            save_pretrained(str(out_dir))
        else:
            # Plain nn.Module models have no save_pretrained(), so write a
            # weights file plus a small JSON config instead.
            # NOTE(review): file and key names are intended to mirror the
            # layout timm reads from the hub - confirm against the installed
            # timm version.
            out_dir.mkdir(parents=True, exist_ok=True)
            torch.save(self.model.state_dict(), out_dir / 'pytorch_model.bin')
            config = {
                'architecture': self.model_name,
                'num_classes': self.num_classes,
                'label_names': [idx_to_class[i] for i in sorted(idx_to_class)],
            }
            with open(out_dir / 'config.json', 'w', encoding='utf-8') as fh:
                json.dump(config, fh, indent=2)

        # Save class mapping
        pd.Series(idx_to_class).to_json('class_mapping.json')
def main():
    """Run the full pipeline: load data, build the model, train, export."""
    classifier = PlantDiseaseClassifier(data_dir="path/to/dataset")
    classifier.prepare_data()
    classifier.create_model()
    classifier.train(epochs=10)
    classifier.save_for_huggingface()


if __name__ == "__main__":
    main()