This model is a simple CNN (Convolutional Neural Network) trained to classify images as either cats or dogs.
The model was trained on the microsoft/cats_vs_dogs dataset from Hugging Face.
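The card does not include the definition of `CatDogClassifier`. To show roughly what a "simple CNN" for this task can look like in PyTorch, here is a hypothetical minimal architecture. `SimpleCatDogCNN` and its layer sizes are illustrative assumptions, not the network behind `model_weights.pth`, so those weights will not load into it.

```python
import torch
import torch.nn as nn


class SimpleCatDogCNN(nn.Module):
    """Hypothetical small CNN for 2-class (cat/dog) classification.

    Illustration only: NOT the architecture that produced model_weights.pth.
    """

    def __init__(self, num_classes: int = 2):
        super().__init__()
        # Three conv blocks: conv -> ReLU -> 2x2 max-pool (halves height and width each time).
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),   # 224 -> 112
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),   # 112 -> 56
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),   # 56 -> 28
        )
        # Global average pooling makes the classifier head independent of input resolution.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(64, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.pool(x).flatten(1)   # (N, 64)
        return self.classifier(x)     # raw logits, shape (N, num_classes)


if __name__ == "__main__":
    model = SimpleCatDogCNN()
    dummy = torch.randn(1, 3, 224, 224)
    print(model(dummy).shape)  # torch.Size([1, 2])
```

The head outputs two raw logits per image, which matches how the usage code below picks the class with the highest score along dimension 1.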
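To classify a single image:

```python
import torch
from PIL import Image
from torchvision import transforms

# Load model
# CatDogClassifier is the model class from this project's training code (not shown here);
# it must be defined or imported before this point.
model = CatDogClassifier()
model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))
model.eval()

# Prepare image: resize to 224x224 and normalize with the ImageNet channel mean/std
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
image = Image.open('your_image.jpg').convert('RGB')
image_tensor = transform(image).unsqueeze(0)  # add batch dimension: (1, 3, 224, 224)

# Predict
with torch.no_grad():
    outputs = model(image_tensor)
    predicted = outputs.argmax(dim=1)  # index of the highest-scoring class

classes = ['Cat', 'Dog']
print(f"Prediction: {classes[predicted.item()]}")
```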
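The snippet above reports only the winning class. A minimal sketch for also getting a confidence score applies softmax to the logits, assuming the model returns two raw (unnormalized) scores per image and that index 0 is cat and index 1 is dog, as in the `classes` list above. `predict_with_confidence` is a hypothetical helper, not part of this project.

```python
from typing import List, Tuple

import torch
import torch.nn.functional as F

# Assumed to match the label order used in training (0 = cat, 1 = dog).
CLASSES = ["Cat", "Dog"]


def predict_with_confidence(model: torch.nn.Module, batch: torch.Tensor) -> List[Tuple[str, float]]:
    """Return (label, probability) for each image in a preprocessed batch of shape (N, 3, H, W)."""
    model.eval()
    with torch.no_grad():
        logits = model(batch)              # (N, 2) raw scores
        probs = F.softmax(logits, dim=1)   # each row now sums to 1
        conf, idx = probs.max(dim=1)       # highest probability and its class index
    return [(CLASSES[i], c) for i, c in zip(idx.tolist(), conf.tolist())]
```

With two classes the reported probability is always at least 0.5; values close to 0.5 indicate the model is unsure about that image.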
The model reaches approximately 85-90% accuracy on the validation set; the exact figure varies between training runs.
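The card does not describe how the validation split was built or evaluated. As a sketch, accuracy can be measured as the fraction of correctly predicted labels, assuming a `DataLoader` that yields batches of preprocessed image tensors (same transform as above) and integer labels (0 = cat, 1 = dog). `evaluate_accuracy` is a hypothetical helper.

```python
import torch
from torch.utils.data import DataLoader


def evaluate_accuracy(model: torch.nn.Module, loader: DataLoader, device: str = "cpu") -> float:
    """Fraction of correctly classified samples in `loader`, which yields (images, labels)."""
    model.eval()
    model.to(device)
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)          # predicted class index per sample
            correct += (preds == labels).sum().item()    # count matches in this batch
            total += labels.size(0)
    return correct / total if total else 0.0
```

Because the result depends on which images land in the validation split and on the random seed used for training, figures from different runs are not directly comparable unless the split is fixed.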
MIT License
Created as an educational project for learning image classification with PyTorch.