🐱🐢 Cats vs Dogs Classification (Transfer Learning CNN | PyTorch)

A PyTorch image classifier for Cats vs Dogs, trained with transfer learning using a pretrained ResNet-18 backbone.

This Hugging Face repo hosts the trained weights and the small config files used by the Streamlit inference app.

Training → Model → Inference

What’s in this repo

Stored under artifacts/:

  • model_state.pt: full PyTorch state_dict for ResNet-18 (pretrained backbone plus the trained final FC layer)
  • config.json: image size and normalization stats
  • class_names.json: class labels, in the model's output index order
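Before building the model, it can help to check that the two JSON files agree with each other. The sketch below is a minimal example, assuming the key names the Quickstart code reads (`image_size`, `mean`, `std`, and optionally `num_classes`); the helper `load_artifacts` is hypothetical and not part of this repo.

```python
import json
from pathlib import Path
from typing import List, Tuple

# Keys the Quickstart code reads from config.json (assumed to be present)
REQUIRED_CONFIG_KEYS = ("image_size", "mean", "std")


def load_artifacts(artifacts_dir: str) -> Tuple[dict, List[str]]:
    """Load config.json and class_names.json and run basic consistency checks."""
    root = Path(artifacts_dir)
    with open(root / "config.json", "r", encoding="utf-8") as f:
        config = json.load(f)
    with open(root / "class_names.json", "r", encoding="utf-8") as f:
        class_names = json.load(f)

    missing = [k for k in REQUIRED_CONFIG_KEYS if k not in config]
    if missing:
        raise ValueError(f"config.json is missing keys: {missing}")
    # One normalization value per RGB channel
    if len(config["mean"]) != 3 or len(config["std"]) != 3:
        raise ValueError("mean/std must have one value per RGB channel")
    # If num_classes is stored, it must match the label list length
    if "num_classes" in config and int(config["num_classes"]) != len(class_names):
        raise ValueError("num_classes does not match class_names.json")
    return config, class_names
```

Since `hf_hub_download` keeps the `artifacts/` subfolder in its local cache, `load_artifacts(os.path.dirname(cfg_path))` should work after downloading the files as in the Quickstart.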

Model details (high level)

  • Backbone: ResNet-18 (ImageNet pretrained)
  • Training: backbone frozen, only the final fully-connected layer trained
  • Input size: 224×224 RGB
  • Normalization: ImageNet mean/std (stored in artifacts/config.json)

Inputs

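  • One RGB image (any size). At inference time:
    • resize so the shorter side is 256 pixels (aspect ratio preserved)
    • center-crop to 224×224 (image_size in config.json)
    • convert to a float tensor scaled to [0, 1]
    • normalize with mean/std from config.json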
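The size arithmetic of these steps can be worked through without any image library. The pure-Python sketch below mirrors my understanding of how torchvision's `Resize(256)` and `CenterCrop(224)` compute output sizes (exact rounding could differ by a pixel across versions). The mean/std constants are the standard ImageNet values and are used here only for illustration; the model itself uses whatever `artifacts/config.json` contains.

```python
from typing import List, Sequence, Tuple

# Standard ImageNet statistics, for illustration only
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def resize_shorter_side(width: int, height: int, size: int = 256) -> Tuple[int, int]:
    """(width, height) after scaling the shorter side to `size`, keeping aspect ratio."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width: int, height: int, crop: int = 224) -> Tuple[int, int, int, int]:
    """(left, top, right, bottom) of a centered crop x crop window."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


def normalize_pixel(rgb: Sequence[int], mean: Sequence[float], std: Sequence[float]) -> List[float]:
    """ToTensor maps 0..255 to [0, 1]; Normalize then applies (x - mean) / std per channel."""
    return [(v / 255.0 - m) / s for v, m, s in zip(rgb, mean, std)]


# A 640x480 photo: shorter side 480 -> 256, longer side 640 -> int(341.33) = 341
w, h = resize_shorter_side(640, 480)
print((w, h))                 # (341, 256)
print(center_crop_box(w, h))  # (58, 16, 282, 240)
```

A pure-white pixel, for example, maps to roughly (2.25, 2.43, 2.64) under these statistics, which is why normalized inputs are not confined to [0, 1].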

Output

  • Predicted label: one of ["cats", "dogs"] (from class_names.json)
  • Class probabilities: apply softmax to the model's output logits.
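The mapping from logits to a label and probabilities is small enough to show in plain Python. This is a sketch with hypothetical helper names; the Quickstart does the same thing with `torch.softmax` and `argmax`. With two classes, softmax reduces to a sigmoid of the logit difference, so logits of 2.0 and -1.0 give p(cats) = 1 / (1 + e^-3), about 0.95.

```python
import math
from typing import Dict, List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtracting the max logit keeps exp() from overflowing."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_label(logits: Sequence[float], class_names: Sequence[str]) -> Tuple[str, Dict[str, float]]:
    """Return the argmax label (index order from class_names.json) and per-class probabilities."""
    probs = softmax(logits)
    idx = max(range(len(probs)), key=probs.__getitem__)
    return class_names[idx], dict(zip(class_names, probs))


label, probs = predict_label([2.0, -1.0], ["cats", "dogs"])
print(label)  # cats
```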

Quickstart (load + predict)

```python
import json

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import models, transforms

REPO_ID = "ash001/cats-dogs-transferlearning-cnn"

# Download artifacts (cached locally by huggingface_hub)
cfg_path   = hf_hub_download(REPO_ID, "artifacts/config.json")
names_path = hf_hub_download(REPO_ID, "artifacts/class_names.json")
pt_path    = hf_hub_download(REPO_ID, "artifacts/model_state.pt")

with open(cfg_path, "r", encoding="utf-8") as f:
    config = json.load(f)
with open(names_path, "r", encoding="utf-8") as f:
    class_names = json.load(f)

# Rebuild the architecture; all weights (backbone + head) come from the state_dict
model = models.resnet18(weights=None)
# Size the head from the label list so it always matches class_names.json
model.fc = nn.Linear(model.fc.in_features, len(class_names))
# weights_only=True restricts unpickling to tensors and plain containers
state = torch.load(pt_path, map_location="cpu", weights_only=True)
model.load_state_dict(state)
model.eval()  # BatchNorm uses running stats in eval mode

# Preprocess: resize shorter side to 256, center-crop, to tensor, normalize
tfm = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(int(config["image_size"])),
    transforms.ToTensor(),
    transforms.Normalize(mean=config["mean"], std=config["std"]),
])

img = Image.open("your_image.jpg").convert("RGB")
x = tfm(img).unsqueeze(0)  # add batch dimension: (1, 3, H, W)

with torch.no_grad():
    logits = model(x)
    probs = torch.softmax(logits, dim=1)[0].tolist()
    pred_idx = int(logits.argmax(dim=1).item())

print("Prediction:", class_names[pred_idx])
print("Probabilities:", dict(zip(class_names, probs)))
```

License: Apache-2.0
