---
language: en
license: mit
library_name: pytorch
pipeline_tag: image-classification
tags:
- image-classification
- resnet
- corrosion
datasets:
- custom
---

# Corrosion Classifier (ResNet50)

This repository contains a ResNet50 image classifier trained to classify corrosion types (including a `no_corrosion` class).

## Labels

- crevice_corrosion
- erosion_corrosion
- galvanic_corrosion
- mic_corrosion
- no_corrosion
- pitting_corrosion
- stress_corrosion
- under_insulation_corrosion
- uniform_corrosion
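For downstream tooling it can help to persist this label order as an explicit index-to-label mapping; a minimal sketch (the `labels.json` filename is an assumption, and the index order must match training):

```python
import json

labels = ['crevice_corrosion', 'erosion_corrosion', 'galvanic_corrosion',
          'mic_corrosion', 'no_corrosion', 'pitting_corrosion',
          'stress_corrosion', 'under_insulation_corrosion', 'uniform_corrosion']

# Map class index -> label name, matching the order used at training time
id2label = {i: name for i, name in enumerate(labels)}

with open('labels.json', 'w') as f:
    json.dump(id2label, f, indent=2)
```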

## Usage (PyTorch)

```python
import torch
from PIL import Image
from torchvision import transforms
import timm

# Class labels (index order must match training)
labels = ['crevice_corrosion', 'erosion_corrosion', 'galvanic_corrosion',
          'mic_corrosion', 'no_corrosion', 'pitting_corrosion',
          'stress_corrosion', 'under_insulation_corrosion', 'uniform_corrosion']

# Create the model and load the checkpoint
model = timm.create_model('resnet50', pretrained=False, num_classes=len(labels))
state = torch.load('resnet50-corrosion-classifier-v1.pth', map_location='cpu')
missing, unexpected = model.load_state_dict(state, strict=False)
if missing or unexpected:
    print(f'Warning: missing keys {missing}, unexpected keys {unexpected}')
model.eval()

# ImageNet preprocessing
transform = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img = Image.open('test.jpg').convert('RGB')
x = transform(img).unsqueeze(0)

with torch.no_grad():
    logits = model(x)
probs = logits.softmax(dim=1).squeeze(0)

idx = int(probs.argmax())
print(labels[idx], float(probs[idx]))
```
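Beyond the argmax, it is often useful to report the top few classes. A minimal sketch that ranks the class probabilities (the `probs` values below are dummy numbers standing in for the softmax output above):

```python
labels = ['crevice_corrosion', 'erosion_corrosion', 'galvanic_corrosion',
          'mic_corrosion', 'no_corrosion', 'pitting_corrosion',
          'stress_corrosion', 'under_insulation_corrosion', 'uniform_corrosion']
probs = [0.01, 0.02, 0.01, 0.02, 0.07, 0.70, 0.10, 0.04, 0.03]  # dummy softmax output

# Sort (label, probability) pairs by descending probability and keep the top 3
top3 = sorted(zip(labels, probs), key=lambda lp: lp[1], reverse=True)[:3]
for name, p in top3:
    print(f'{name}: {p:.2f}')
```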

> Note: This is a **generic PyTorch checkpoint** (`.pth`). The public Inference API on the Hub does **not** execute arbitrary PyTorch code. If you want to call this model via the **Inference API**, you must convert it to a supported library format (e.g. `transformers` image-classification) or use your existing **Space** and call it via the Gradio API. See below.

## Call via your existing Space (recommended for now)

If your Space is running, you can call it programmatically with the Gradio JS client from Node:

```js
import { Client, handle_file } from "@gradio/client";

// Connect to the Space (your Space id)
const app = await Client.connect("jacopo22295/RESNET50-CORROSION_CLASSIFIER_V1");

// Fetch an image and send it to the /predict endpoint
const res = await fetch("https://example.com/image.jpg");
const blob = await res.blob();
const out = await app.predict("/predict", [handle_file(blob)]);
console.log(out.data);
```

## Convert to Transformers (optional, to use the Inference API)

If you later want to enable the one-click Inference API, consider exporting to a `transformers` image-classification model (e.g. `ResNetForImageClassification`) and pushing the weights together with a `preprocessor_config.json`. This requires a small conversion script.
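As scaffolding for such a conversion, here is a hedged sketch of the label-related fields a `transformers`-style `config.json` would carry (field names follow the `transformers` convention; remapping the timm parameter names onto `ResNetForImageClassification` is the nontrivial part and is not shown):

```python
import json

labels = ['crevice_corrosion', 'erosion_corrosion', 'galvanic_corrosion',
          'mic_corrosion', 'no_corrosion', 'pitting_corrosion',
          'stress_corrosion', 'under_insulation_corrosion', 'uniform_corrosion']

# Minimal label metadata for a transformers-style config.json
config = {
    'architectures': ['ResNetForImageClassification'],
    'model_type': 'resnet',
    'id2label': {str(i): name for i, name in enumerate(labels)},
    'label2id': {name: i for i, name in enumerate(labels)},
}

with open('config.json', 'w') as f:
    json.dump(config, f, indent=2)
```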
|