# Fine-tune a pretrained ResNet-50 on the Animals-10 Kaggle dataset.
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
# 1. Setup Data
data_dir = './animals-10'  # Path to your Kaggle dataset

# ImageNet statistics — required because we start from ImageNet-pretrained weights.
_MEAN = [0.485, 0.456, 0.406]
_STD = [0.229, 0.224, 0.225]

# Augmentation (random flip) belongs on the TRAINING split only; the original
# code shared one transform, so validation images were also randomly flipped.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_MEAN, _STD),
])
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_MEAN, _STD),
])

# Two ImageFolder views over the same directory so each split carries its own
# transform; `dataset` also keeps the class list exported at the end of the script.
dataset = datasets.ImageFolder(data_dir, transform=train_transform)
_val_view = datasets.ImageFolder(data_dir, transform=val_transform)

train_size = int(0.8 * len(dataset))  # 80/20 train/val split
val_size = len(dataset) - train_size

# Seeded permutation so the split is reproducible across runs (the original
# unseeded random_split produced a different split every execution).
_split_gen = torch.Generator().manual_seed(42)
_perm = torch.randperm(len(dataset), generator=_split_gen).tolist()
train_ds = torch.utils.data.Subset(dataset, _perm[:train_size])
val_ds = torch.utils.data.Subset(_val_view, _perm[train_size:])

train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_ds, batch_size=32)
# 2. Modify Model (Fine-tuning)
# Start from an ImageNet-pretrained ResNet-50 backbone.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
# Swap the 1000-way ImageNet classifier head for a fresh 10-way animal head;
# the new Linear layer is randomly initialized and learned during fine-tuning.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, out_features=10)
# 3. Training Loop
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

print("Starting training...")
for epoch in range(5):  # Adjust epochs as needed
    # --- training pass ---
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Accumulate sum of per-sample losses for an exact epoch average.
        running_loss += loss.item() * images.size(0)
    train_loss = running_loss / len(train_loader.dataset)

    # --- validation pass ---
    # The original script built val_loader but never used it; evaluate here so
    # overfitting is visible during training. eval() disables dropout/batch-norm
    # updates and no_grad() skips autograd bookkeeping.
    model.eval()
    correct = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
    val_acc = correct / len(val_loader.dataset)

    print(f"Epoch {epoch+1} complete. train_loss={train_loss:.4f} val_acc={val_acc:.2%}")
# 4. Save the model weights
torch.save(model.state_dict(), 'animal_model.pth')

# Save the class names to keep track of the index-to-Italian mapping
# (one class name per line, in ImageFolder's alphabetical index order).
with open('classes.txt', 'w') as f:
    f.writelines(name + '\n' for name in dataset.classes)