Corrosion Classifier (ResNet50)
This repository contains a ResNet50 image classifier trained to assign an image to one of nine classes: eight corrosion types plus no_corrosion.
Labels
- crevice_corrosion
- erosion_corrosion
- galvanic_corrosion
- mic_corrosion
- no_corrosion
- pitting_corrosion
- stress_corrosion
- under_insulation_corrosion
- uniform_corrosion
Usage (PyTorch)
import torch
from PIL import Image
from torchvision import transforms
import timm
# Class labels (the order must match the class indices used in training)
labels = ['crevice_corrosion', 'erosion_corrosion', 'galvanic_corrosion', 'mic_corrosion', 'no_corrosion', 'pitting_corrosion', 'stress_corrosion', 'under_insulation_corrosion', 'uniform_corrosion']
# Create model
model = timm.create_model('resnet50', pretrained=False, num_classes=len(labels))
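state = torch.load('resnet50-corrosion-classifier-v1.pth', map_location='cpu')
# strict=False ignores mismatched keys instead of raising, so report any that occur
missing, unexpected = model.load_state_dict(state, strict=False)
if missing or unexpected:
    print('Missing keys:', missing, '| Unexpected keys:', unexpected)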
model.eval()
# Preprocess (ImageNet)
transform = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
img = Image.open('test.jpg').convert('RGB')
x = transform(img).unsqueeze(0)
with torch.no_grad():
    logits = model(x)
    probs = logits.softmax(dim=1).squeeze(0).tolist()
idx = probs.index(max(probs))  # index of the most probable class
print(labels[idx], probs[idx])
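To inspect more than the single best class, the probabilities can be ranked. The following is a minimal sketch in plain Python, assuming `probs` is the list of per-class probabilities and `labels` is the label list from the snippet above. `top_k` is a hypothetical helper name, not part of the checkpoint or of any library, and the example values are made up for illustration.

```python
def top_k(labels, probs, k=3):
    """Return the k (label, probability) pairs with the highest probability."""
    if len(labels) != len(probs):
        raise ValueError("labels and probs must have the same length")
    # Sort label/probability pairs from most to least probable.
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Illustrative values only, not real model output.
example_labels = ["no_corrosion", "pitting_corrosion", "uniform_corrosion"]
example_probs = [0.10, 0.65, 0.25]

for label, prob in top_k(example_labels, example_probs, k=2):
    print(f"{label}: {prob:.2f}")
```

With the real model, the call would be `top_k(labels, probs)` right after the inference step.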
Note: This is a generic PyTorch checkpoint (.pth). The public Inference API on the Hub does not execute arbitrary PyTorch code. To call this model via the Inference API, you must convert it to a supported library format (e.g. transformers image-classification), or use your existing Space and call it via the Gradio API. See below.
Call via your existing Space (recommended now)
If your Space works, you can call it programmatically using the Gradio JS Client from Node:
import { Client, handle_file } from "@gradio/client";
const app = await Client.connect("jacopo22295/RESNET50-CORROSION_CLASSIFIER_V1"); // your Space id
const res = await fetch("https://example.com/image.jpg");
const blob = await res.blob();
const out = await app.predict("/predict", [handle_file(blob)]);
console.log(out.data);
Convert to Transformers (optional, to use Inference API)
If you later want to enable the one-click Inference API, consider exporting to a transformers ImageClassification model (e.g. ResNetForImageClassification) and pushing weights + preprocessor_config.json. This requires a small conversion script.
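One small part of such a conversion script is the label mapping. The sketch below assumes that the class order listed under Labels matches the training indices. It builds the `id2label` and `label2id` dictionaries that transformers model configs carry. The helper name `build_label_maps` is hypothetical, and remapping the timm weight names to the transformers ResNet layout is not covered here.

```python
import json

labels = [
    "crevice_corrosion", "erosion_corrosion", "galvanic_corrosion",
    "mic_corrosion", "no_corrosion", "pitting_corrosion",
    "stress_corrosion", "under_insulation_corrosion", "uniform_corrosion",
]


def build_label_maps(labels):
    """Build index-to-name and name-to-index mappings from an ordered label list."""
    id2label = {i: name for i, name in enumerate(labels)}
    label2id = {name: i for i, name in id2label.items()}
    return id2label, label2id


id2label, label2id = build_label_maps(labels)

# JSON object keys are always strings, so the integer ids are serialized as "0", "1", ...
print(json.dumps({"id2label": id2label, "label2id": label2id}, indent=2))
```

These two dictionaries could then be passed as `id2label` and `label2id` when constructing the transformers config, so that predictions come back with class names rather than bare indices.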