cervical_lesion / example_inference.py
toderian's picture
Upload folder using huggingface_hub
17daa0b verified
"""
Example inference script for the Cervical Cancer Classification model.

Usage:
    # From a local directory:
    python example_inference.py --image path/to/image.jpg --model ./

    # From the Hugging Face Hub:
    python example_inference.py --image path/to/image.jpg --model toderian/cerviguard_lesion
"""
import argparse
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as T
from pathlib import Path
import json
class CervicalCancerCNN(nn.Module):
    """Configurable CNN classifier for cervical lesion images.

    The network is a stack of conv stages (3x3 Conv2d with padding 1,
    BatchNorm2d, ReLU, 2x2 MaxPool2d), a global average pool, a stack of
    fully connected stages (Linear, ReLU, Dropout) and a final linear
    classifier that emits raw logits.

    Recognised ``config`` keys (all optional):
        conv_layers: output channels of each conv stage
            (default ``[32, 64, 128, 256]``).
        fc_layers: width of each fully connected stage
            (default ``[256, 128]``).
        dropout: dropout probability after each FC stage (default ``0.5``).
        num_classes: number of output logits (default ``4``).
    """

    def __init__(self, config=None):
        super().__init__()
        cfg = config or {}
        channel_plan = cfg.get("conv_layers", [32, 64, 128, 256])
        hidden_sizes = cfg.get("fc_layers", [256, 128])
        drop_rate = cfg.get("dropout", 0.5)
        n_classes = cfg.get("num_classes", 4)

        # Feature extractor: the first stage consumes 3 input channels,
        # every later stage consumes the previous stage's output width.
        stage_inputs = [3, *channel_plan[:-1]]
        conv_modules = []
        for c_in, c_out in zip(stage_inputs, channel_plan):
            conv_modules += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ]
        self.conv_layers = nn.Sequential(*conv_modules)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # Fully connected head: starts from the last conv width because the
        # global average pool leaves one value per channel.
        head_width = channel_plan[-1]
        head_modules = []
        for hidden in hidden_sizes:
            head_modules += [
                nn.Linear(head_width, hidden),
                nn.ReLU(inplace=True),
                nn.Dropout(drop_rate),
            ]
            head_width = hidden
        self.fc_layers = nn.Sequential(*head_modules)
        self.classifier = nn.Linear(head_width, n_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for ``x``."""
        pooled = self.avgpool(self.conv_layers(x))
        # Collapse the trailing 1x1 spatial dims into a flat feature vector.
        flat = pooled.view(pooled.size(0), -1)
        return self.classifier(self.fc_layers(flat))
def load_model_local(model_dir, device="cpu"):
    """Load the model and its config from a local directory.

    Args:
        model_dir: Directory containing ``model.safetensors`` or
            ``pytorch_model.bin``, and optionally ``config.json``.
        device: Device the loaded model is moved to.

    Returns:
        A ``(model, config)`` tuple. The model is in eval mode; ``config`` is
        an empty dict when the directory has no ``config.json``.

    Raises:
        FileNotFoundError: If neither weights file exists in ``model_dir``.
    """
    model_dir = Path(model_dir)

    # Load config (optional; the model falls back to its built-in defaults)
    config = {}
    config_path = model_dir / "config.json"
    if config_path.exists():
        # JSON is UTF-8; do not rely on the platform's default encoding,
        # which would mangle non-ASCII label names on some systems.
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)

    # Create model
    model = CervicalCancerCNN(config)

    # Load weights, preferring safetensors over the pickle-based checkpoint
    safetensors_path = model_dir / "model.safetensors"
    bin_path = model_dir / "pytorch_model.bin"
    if safetensors_path.exists():
        from safetensors.torch import load_file
        state_dict = load_file(str(safetensors_path))
    elif bin_path.exists():
        state_dict = torch.load(bin_path, map_location=device, weights_only=True)
    else:
        raise FileNotFoundError(f"No model weights found in {model_dir}")

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model, config
def load_model_hub(repo_id, device="cpu"):
    """Load the model from a Hugging Face Hub repository.

    Args:
        repo_id: Hub repository ID to download.
        device: Device the loaded model is moved to.

    Returns:
        The ``(model, config)`` tuple produced by ``load_model_local`` for
        the downloaded snapshot directory.
    """
    # Imported lazily so purely local use does not require huggingface_hub.
    from huggingface_hub import snapshot_download

    # Download model files and load them like any local model directory
    model_dir = snapshot_download(repo_id=repo_id)
    return load_model_local(model_dir, device)
def load_model(model_path, device="cpu"):
    """Load the model from a local path or, failing that, the Hugging Face Hub.

    Args:
        model_path: Existing local model directory, or a Hub repository ID.
        device: Device the loaded model is moved to.

    Returns:
        The ``(model, config)`` tuple from the selected loader.
    """
    local_path = Path(model_path)
    if local_path.exists():
        return load_model_local(local_path, device)

    # Not on disk: assume it's a Hugging Face repo ID. Pass the caller's
    # original value rather than str(Path(...)), which rewrites "/" as "\\"
    # on Windows and would corrupt an "org/name" repo ID.
    return load_model_hub(str(model_path), device)
def get_preprocessor(config):
    """Build the image preprocessing pipeline described by ``config``.

    The target size comes from ``config["input_size"]`` (keys ``height`` and
    ``width``), defaulting to 224 x 298. Images are resized, converted to a
    tensor and normalized with the mean/std constants below.
    """
    fallback_size = {"height": 224, "width": 298}
    size_cfg = config.get("input_size", fallback_size)
    target_hw = (size_cfg.get("height", 224), size_cfg.get("width", 298))

    steps = [
        T.Resize(target_hw),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return T.Compose(steps)
def predict(model, image_tensor, config):
    """Classify a single preprocessed image.

    Args:
        model: Callable returning logits of shape ``(1, num_classes)``.
        image_tensor: Batched input passed straight to ``model``.
        config: Dict that may carry an ``id2label`` mapping keyed by the
            class index as a string.

    Returns:
        Dict with ``class_id`` (int), ``class_name``, ``probabilities``
        (label -> percentage string) and ``confidence`` (percentage string
        of the predicted class).
    """
    default_labels = {
        "0": "Normal",
        "1": "LSIL",
        "2": "HSIL",
        "3": "Cancer"
    }
    id2label = config.get("id2label", default_labels)

    def label_for(index):
        # Config keys are strings; unknown indices get a generic name.
        return id2label.get(str(index), f"Class {index}")

    with torch.no_grad():
        logits = model(image_tensor)
        class_probs = torch.softmax(logits, dim=1)[0].tolist()
        top_class = logits.argmax(dim=1).item()

    formatted = {}
    for index, prob in enumerate(class_probs):
        formatted[label_for(index)] = f"{prob:.2%}"

    return {
        "class_id": top_class,
        "class_name": label_for(top_class),
        "probabilities": formatted,
        "confidence": f"{class_probs[top_class]:.2%}",
    }
def main():
    """Parse CLI arguments, classify one image and print the result."""
    cli = argparse.ArgumentParser(description="Cervical Cancer Classification")
    cli.add_argument("--image", required=True, help="Path to input image")
    cli.add_argument("--model", default="./", help="Path to model dir or HF repo ID")
    cli.add_argument("--device", default="cpu", help="Device (cpu/cuda)")
    options = cli.parse_args()

    print(f"Loading model from {options.model}...")
    model, config = load_model(options.model, options.device)

    print(f"Processing image: {options.image}")
    preprocess = get_preprocessor(config)
    rgb_image = Image.open(options.image).convert("RGB")
    # Add the batch dimension the model expects, then move to the device.
    batch = preprocess(rgb_image).unsqueeze(0).to(options.device)

    outcome = predict(model, batch, config)

    rule = "=" * 50
    print("\n" + rule)
    print("PREDICTION RESULT")
    print(rule)
    print(f"Class: {outcome['class_name']}")
    print(f"Confidence: {outcome['confidence']}")
    print("\nAll probabilities:")
    for label, share in outcome["probabilities"].items():
        print(f" {label}: {share}")


if __name__ == "__main__":
    main()