metadata
language: en
license: mit
tags:
  - image-classification
  - resnet
  - corrosion
library_name: pytorch
pipeline_tag: image-classification
task_categories:
  - image-classification
dataset: custom

Corrosion Classifier (ResNet50)

This repository contains a ResNet50 image classifier trained to assign an image to one of nine classes: eight corrosion types plus no_corrosion.

Labels

  • crevice_corrosion
  • erosion_corrosion
  • galvanic_corrosion
  • mic_corrosion
  • no_corrosion
  • pitting_corrosion
  • stress_corrosion
  • under_insulation_corrosion
  • uniform_corrosion

Usage (PyTorch)

import torch
from PIL import Image
from torchvision import transforms
import timm

# Load labels
labels = ['crevice_corrosion', 'erosion_corrosion', 'galvanic_corrosion', 'mic_corrosion', 'no_corrosion', 'pitting_corrosion', 'stress_corrosion', 'under_insulation_corrosion', 'uniform_corrosion']

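# Build the architecture, then load the fine-tuned weights
model = timm.create_model('resnet50', pretrained=False, num_classes=len(labels))
state = torch.load('resnet50-corrosion-classifier-v1.pth', map_location='cpu')
# strict=False tolerates key mismatches, so report them instead of silently ignoring them
missing, unexpected = model.load_state_dict(state, strict=False)
if missing or unexpected:
    print('Missing keys:', missing, 'Unexpected keys:', unexpected)
model.eval()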

# Preprocess (ImageNet)
transform = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img = Image.open('test.jpg').convert('RGB')
x = transform(img).unsqueeze(0)

with torch.no_grad():
    logits = model(x)
    probs = logits.softmax(dim=1).squeeze(0)  # shape: (num_classes,)

# Report the most likely class and its probability
idx = int(probs.argmax())
print(labels[idx], float(probs[idx]))
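
The snippet above reports only the single most likely class. When several corrosion types look alike, it can help to see the runner-up classes as well. The following is a minimal sketch in plain Python (no torch required) of how a vector of logits becomes a ranked list of labels; the helper names `softmax` and `top_k` and the example logits are illustrative assumptions, not part of this repository.

```python
import math

LABELS = [
    'crevice_corrosion', 'erosion_corrosion', 'galvanic_corrosion',
    'mic_corrosion', 'no_corrosion', 'pitting_corrosion',
    'stress_corrosion', 'under_insulation_corrosion', 'uniform_corrosion',
]


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, labels, k=3):
    """Return the k most probable (label, probability) pairs, best first."""
    if len(logits) != len(labels):
        raise ValueError('expected one logit per label')
    probs = softmax(logits)
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Made-up logits for illustration only; real values come from model(x)[0].tolist()
example_logits = [0.1, -1.2, 0.3, 0.0, 2.5, 1.9, -0.4, 0.2, 0.7]

for label, prob in top_k(example_logits, LABELS, k=3):
    print(f'{label}: {prob:.3f}')
```

With the model from the usage example, you would pass `logits.squeeze(0).tolist()` in place of `example_logits`.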

Note: This is a generic PyTorch checkpoint (.pth), and the public Inference API on the Hub does not execute arbitrary PyTorch code. To serve the model remotely, either call your existing Space through the Gradio API or convert the checkpoint to a supported library format (e.g. transformers image-classification). Both options are described below.

Call via your existing Space (recommended for now)

If your Space works, you can call it programmatically using the Gradio JS Client from Node:

import { Client, handle_file } from "@gradio/client";

const app = await Client.connect("jacopo22295/RESNET50-CORROSION_CLASSIFIER_V1"); // your Space id
const res = await fetch("https://example.com/image.jpg");
const blob = await res.blob();
const out = await app.predict("/predict", [handle_file(blob)]);
console.log(out.data);

Convert to Transformers (optional, to use Inference API)

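If you later want to enable the hosted Inference API, consider exporting the checkpoint to a transformers image-classification model (e.g. ResNetForImageClassification) and pushing the converted weights together with config.json and preprocessor_config.json. This requires a small conversion script, because the parameter names in a timm ResNet50 state dict differ from those used by transformers.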
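
One piece of such a conversion that does not depend on the weights is the label metadata. A transformers model config stores `num_labels`, `id2label` and `label2id`, so a conversion script would need to produce them from the label list in this README. The sketch below shows only that step, assuming the index order of the list matches the checkpoint; the helper name `build_label_maps` is hypothetical, and the weight-name remapping itself is not covered here.

```python
import json

LABELS = [
    'crevice_corrosion', 'erosion_corrosion', 'galvanic_corrosion',
    'mic_corrosion', 'no_corrosion', 'pitting_corrosion',
    'stress_corrosion', 'under_insulation_corrosion', 'uniform_corrosion',
]


def build_label_maps(labels):
    """Build the index-to-name and name-to-index mappings for a model config."""
    id2label = {index: name for index, name in enumerate(labels)}
    label2id = {name: index for index, name in id2label.items()}
    return id2label, label2id


id2label, label2id = build_label_maps(LABELS)

# Fragment to merge into the model config (assumption: the rest of the
# config comes from the ResNet50 architecture you convert to)
config_fragment = {
    'num_labels': len(LABELS),
    'id2label': id2label,
    'label2id': label2id,
}

# json.dumps writes the integer keys of id2label as strings,
# which is how they appear in a saved config.json
print(json.dumps(config_fragment, indent=2))
```

The same ImageNet mean and standard deviation used in the PyTorch usage example would go into preprocessor_config.json so that hosted inference normalizes images the same way.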