"""
Example inference script for the Cervical Cancer Classification model.

Usage:
    # From a local directory:
    python example_inference.py --image path/to/image.jpg --model ./

    # From the Hugging Face Hub:
    python example_inference.py --image path/to/image.jpg --model toderian/cerviguard_lesion
"""
|
|
|
|
|
import argparse |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from PIL import Image |
|
|
import torchvision.transforms as T |
|
|
from pathlib import Path |
|
|
import json |
|
|
|
|
|
|
|
|
class CervicalCancerCNN(nn.Module):
    """CNN for cervical cancer classification.

    Architecture: a stack of ``Conv3x3 -> BatchNorm -> ReLU -> MaxPool2x2``
    stages, a global average pool, a stack of ``Linear -> ReLU -> Dropout``
    blocks, and a final ``Linear`` classifier that returns raw logits.

    Args:
        config: Optional dict of hyper-parameters. Keys read (and the
            defaults used when a key is absent):
            ``conv_layers`` ([32, 64, 128, 256]) - output channels per conv stage;
            ``fc_layers`` ([256, 128]) - hidden sizes of the fully-connected head;
            ``dropout`` (0.5) - dropout probability used in the head;
            ``num_classes`` (4) - number of output logits.
    """

    def __init__(self, config=None):
        super().__init__()

        config = config or {}
        conv_channels = config.get("conv_layers", [32, 64, 128, 256])
        fc_sizes = config.get("fc_layers", [256, 128])
        dropout = config.get("dropout", 0.5)
        num_classes = config.get("num_classes", 4)

        # Feature extractor. The attribute names and the ordering inside the
        # Sequential define the state_dict keys ("conv_layers.<i>.<param>"),
        # so they must stay in sync with saved checkpoints.
        layers = []
        in_channels = 3  # RGB input
        for out_channels in conv_channels:
            layers.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ])
            in_channels = out_channels

        self.conv_layers = nn.Sequential(*layers)
        # Pools each channel to 1x1, so the head does not depend on the
        # input's spatial size.
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # Start from the running channel count rather than conv_channels[-1]:
        # an empty "conv_layers" list would otherwise raise IndexError.
        in_features = in_channels
        fc_blocks = []
        for fc_size in fc_sizes:
            fc_blocks.extend([
                nn.Linear(in_features, fc_size),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
            ])
            in_features = fc_size

        self.fc_layers = nn.Sequential(*fc_blocks)
        self.classifier = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Image batch; assumed to be (N, 3, H, W) float tensor.

        Returns:
            Logits tensor of shape (N, num_classes).
        """
        x = self.conv_layers(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # (N, C, 1, 1) -> (N, C)
        x = self.fc_layers(x)
        return self.classifier(x)
|
|
|
|
|
|
|
|
def load_model_local(model_dir, device="cpu"):
    """Load the model and its config from a local directory.

    Args:
        model_dir: Directory (str or Path) that may contain ``config.json``
            and must contain either ``model.safetensors`` or
            ``pytorch_model.bin``. The safetensors file wins if both exist.
        device: Torch device string the model is moved to.

    Returns:
        ``(model, config)``: the model in eval mode on ``device``, and the
        parsed config dict (empty if ``config.json`` is absent).

    Raises:
        FileNotFoundError: If neither weights file exists in ``model_dir``.
    """
    model_dir = Path(model_dir)

    config_path = model_dir / "config.json"
    config = {}
    if config_path.exists():
        # JSON is UTF-8; relying on the platform default encoding (e.g.
        # cp1252 on Windows) could mis-decode non-ASCII label names.
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)

    safetensors_path = model_dir / "model.safetensors"
    bin_path = model_dir / "pytorch_model.bin"

    # Resolve the weights before building the network so a missing
    # checkpoint fails fast.
    if safetensors_path.exists():
        from safetensors.torch import load_file

        state_dict = load_file(str(safetensors_path))
    elif bin_path.exists():
        # weights_only=True restricts unpickling to tensors and plain containers.
        state_dict = torch.load(bin_path, map_location=device, weights_only=True)
    else:
        raise FileNotFoundError(f"No model weights found in {model_dir}")

    model = CervicalCancerCNN(config)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model, config
|
|
|
|
|
|
|
|
def load_model_hub(repo_id, device="cpu"):
    """Load the model and its config from a Hugging Face Hub repository.

    Fetches a snapshot of ``repo_id`` with
    ``huggingface_hub.snapshot_download`` and delegates to
    ``load_model_local``. Only the files that ``load_model_local`` reads are
    requested, so unrelated large artifacts in the repo are not downloaded.

    Args:
        repo_id: Hub repository id, e.g. ``"user/repo"``.
        device: Torch device string the model is moved to.

    Returns:
        ``(model, config)`` as returned by ``load_model_local``.
    """
    from huggingface_hub import snapshot_download

    model_dir = snapshot_download(
        repo_id=repo_id,
        allow_patterns=["config.json", "model.safetensors", "pytorch_model.bin"],
    )
    return load_model_local(model_dir, device)
|
|
|
|
|
|
|
|
def load_model(model_path, device="cpu"):
    """Load the model from a local path if it exists, otherwise from the Hub.

    Args:
        model_path: Local directory, or a Hugging Face Hub repo id.
        device: Torch device string the model is moved to.

    Returns:
        ``(model, config)`` from ``load_model_local`` / ``load_model_hub``.
    """
    if Path(model_path).exists():
        return load_model_local(model_path, device)

    # Pass the caller's string through unchanged. Round-tripping a repo id
    # through Path would turn "user/repo" into "user\\repo" on Windows.
    return load_model_hub(str(model_path), device)
|
|
|
|
|
|
|
|
def get_preprocessor(config):
    """Build the transform that turns a PIL image into a normalized tensor.

    The target size comes from ``config["input_size"]`` (keys ``height`` and
    ``width``), falling back to 224 x 298 for the whole entry or for any
    missing key.
    """
    size_cfg = config.get("input_size", {"height": 224, "width": 298})
    target_hw = (size_cfg.get("height", 224), size_cfg.get("width", 298))

    # These are the widely used ImageNet mean/std values; presumably the
    # model was trained with the same normalization (confirm).
    steps = [
        T.Resize(target_hw),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return T.Compose(steps)
|
|
|
|
|
|
|
|
def predict(model, image_tensor, config):
    """Run the model and summarise the prediction for the first image.

    Args:
        model: Callable returning logits of shape (N, num_classes).
        image_tensor: Model input; ``.item()`` on the argmax means N must be 1.
        config: Dict that may carry an ``id2label`` mapping keyed by
            stringified class index. Indices missing from it are rendered
            as ``"Class <i>"``.

    Returns:
        Dict with ``class_id`` (int), ``class_name`` (str), ``probabilities``
        (label -> percentage string) and ``confidence`` (percentage string).
    """
    default_labels = {"0": "Normal", "1": "LSIL", "2": "HSIL", "3": "Cancer"}
    labels = config.get("id2label", default_labels)

    def label_for(idx):
        return labels.get(str(idx), f"Class {idx}")

    with torch.no_grad():
        logits = model(image_tensor)
        probs = torch.softmax(logits, dim=1)[0]
        # argmax over the raw logits, not the probabilities.
        top = logits.argmax(dim=1).item()

    prob_list = probs.tolist()
    per_class = {}
    for idx, p in enumerate(prob_list):
        per_class[label_for(idx)] = f"{p:.2%}"

    return {
        "class_id": top,
        "class_name": label_for(top),
        "probabilities": per_class,
        "confidence": f"{prob_list[top]:.2%}",
    }
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point: load the model, classify one image, print the result."""
    parser = argparse.ArgumentParser(description="Cervical Cancer Classification")
    parser.add_argument("--image", required=True, help="Path to input image")
    parser.add_argument("--model", default="./", help="Path to model dir or HF repo ID")
    parser.add_argument("--device", default="cpu", help="Device (cpu/cuda)")
    args = parser.parse_args()

    print(f"Loading model from {args.model}...")
    model, config = load_model(args.model, args.device)

    print(f"Processing image: {args.image}")
    transform = get_preprocessor(config)
    # Image.open reads lazily and keeps the file handle open. convert()
    # returns a new image, so the handle can be closed right after.
    with Image.open(args.image) as img:
        image = img.convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(args.device)

    result = predict(model, image_tensor, config)

    print("\n" + "=" * 50)
    print("PREDICTION RESULT")
    print("=" * 50)
    print(f"Class: {result['class_name']}")
    print(f"Confidence: {result['confidence']}")
    print("\nAll probabilities:")
    for cls, prob in result['probabilities'].items():
        print(f" {cls}: {prob}")


if __name__ == "__main__":
    main()
|
|
|