"""
Example inference script for Cervical Cancer Classification model.

Usage:
    # From local directory:
    python example_inference.py --image path/to/image.jpg --model ./

    # From Hugging Face Hub:
    python example_inference.py --image path/to/image.jpg --model toderian/cerviguard_lesion
"""

import argparse
import json
from pathlib import Path

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image


class CervicalCancerCNN(nn.Module):
    """CNN for cervical cancer classification.

    Layout: a stack of [Conv3x3 -> BatchNorm -> ReLU -> MaxPool2x2] stages,
    global average pooling, a stack of [Linear -> ReLU -> Dropout] stages,
    and a final linear classifier.

    Args:
        config: Optional dict. Keys read (with defaults):
            "conv_layers" -- output channels per conv stage ([32, 64, 128, 256]),
            "fc_layers"   -- hidden sizes of the FC stages ([256, 128]),
            "dropout"     -- dropout probability in the FC stages (0.5),
            "num_classes" -- number of output logits (4).
    """

    def __init__(self, config=None):
        super().__init__()
        config = config or {}
        conv_channels = config.get("conv_layers", [32, 64, 128, 256])
        fc_sizes = config.get("fc_layers", [256, 128])
        dropout = config.get("dropout", 0.5)
        num_classes = config.get("num_classes", 4)

        # Convolutional feature extractor. Attribute names and layer order
        # determine the state_dict keys, so they must match the saved weights.
        layers = []
        in_channels = 3  # RGB input
        for out_channels in conv_channels:
            layers.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ])
            in_channels = out_channels
        self.conv_layers = nn.Sequential(*layers)
        # Pool each channel to a single value so the head sees a fixed-size
        # vector of length in_channels.
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # Fully connected head. Use the running channel count rather than
        # conv_channels[-1] so an empty "conv_layers" list does not raise.
        fc_blocks = []
        in_features = in_channels
        for fc_size in fc_sizes:
            fc_blocks.extend([
                nn.Linear(in_features, fc_size),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
            ])
            in_features = fc_size
        self.fc_layers = nn.Sequential(*fc_blocks)
        self.classifier = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes)."""
        x = self.conv_layers(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # (batch, C, 1, 1) -> (batch, C)
        x = self.fc_layers(x)
        x = self.classifier(x)
        return x


def load_model_local(model_dir, device="cpu"):
    """Load the model from a local directory.

    Reads ``config.json`` if present (otherwise model defaults are used). It then
    loads weights from ``model.safetensors`` or, failing that,
    ``pytorch_model.bin``.

    Args:
        model_dir: Directory containing the model files.
        device: Torch device string the model is moved to.

    Returns:
        ``(model, config)``, with the model in eval mode on ``device``.

    Raises:
        FileNotFoundError: If neither weights file exists.
    """
    model_dir = Path(model_dir)

    config = {}
    config_path = model_dir / "config.json"
    if config_path.exists():
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)

    model = CervicalCancerCNN(config)

    safetensors_path = model_dir / "model.safetensors"
    bin_path = model_dir / "pytorch_model.bin"
    if safetensors_path.exists():
        # Imported lazily so safetensors is only required when it is used.
        from safetensors.torch import load_file
        state_dict = load_file(str(safetensors_path))
    elif bin_path.exists():
        # weights_only=True restricts unpickling to tensors and plain
        # containers, avoiding arbitrary code execution from the file.
        state_dict = torch.load(bin_path, map_location=device, weights_only=True)
    else:
        raise FileNotFoundError(f"No model weights found in {model_dir}")
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()
    return model, config


def load_model_hub(repo_id, device="cpu"):
    """Download a Hugging Face Hub repo snapshot and load it like a local dir."""
    # Imported lazily so huggingface_hub is only required for Hub loading.
    from huggingface_hub import snapshot_download

    model_dir = snapshot_download(repo_id=repo_id)
    return load_model_local(model_dir, device)


def load_model(model_path, device="cpu"):
    """Load a model from a local path if it exists, else from the HF Hub.

    Args:
        model_path: Local directory, or a Hub repo ID such as "owner/name".
        device: Torch device string.

    Returns:
        ``(model, config)`` as returned by :func:`load_model_local`.
    """
    local_path = Path(model_path)
    if local_path.exists():
        return load_model_local(local_path, device)
    # Not a local path: treat it as a Hugging Face repo ID. Pass the original
    # string rather than str(Path(...)). On Windows, Path rewrites "/" to "\\"
    # and would corrupt IDs like "owner/name".
    return load_model_hub(str(model_path), device)


def get_preprocessor(config):
    """Build the image transform: resize, convert to tensor, normalize.

    The target size comes from ``config["input_size"]`` ("height"/"width"),
    defaulting to 224x298. The mean/std values are the standard ImageNet
    channel statistics (presumably those used in training; confirm).
    """
    input_size = config.get("input_size", {"height": 224, "width": 298})
    height = input_size.get("height", 224)
    width = input_size.get("width", 298)

    return T.Compose([
        T.Resize((height, width)),  # exact size; aspect ratio is not preserved
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


def predict(model, image_tensor, config):
    """Run inference and describe the prediction for the first batch item.

    Args:
        model: Model returning logits of shape (batch, num_classes).
        image_tensor: Input batch. Only item 0 is reported.
        config: Dict optionally providing "id2label" (string class ids to
            names, as JSON object keys are strings).

    Returns:
        Dict with "class_id" (int), "class_name" (str), "probabilities"
        (label -> percentage string) and "confidence" (percentage string).
    """
    id2label = config.get("id2label", {
        "0": "Normal",
        "1": "LSIL",
        "2": "HSIL",
        "3": "Cancer"
    })

    with torch.no_grad():
        output = model(image_tensor)
        probabilities = torch.softmax(output, dim=1)[0]
    # Take the argmax from the same sample the probabilities describe, so a
    # batch larger than one does not make .item() raise.
    prediction = int(probabilities.argmax().item())
    probs = probabilities.tolist()

    return {
        "class_id": prediction,
        "class_name": id2label.get(str(prediction), f"Class {prediction}"),
        "probabilities": {
            id2label.get(str(i), f"Class {i}"): f"{prob:.2%}"
            for i, prob in enumerate(probs)
        },
        "confidence": f"{probs[prediction]:.2%}"
    }


def main():
    """CLI entry point: load the model, classify one image, print results."""
    parser = argparse.ArgumentParser(description="Cervical Cancer Classification")
    parser.add_argument("--image", required=True, help="Path to input image")
    parser.add_argument("--model", default="./", help="Path to model dir or HF repo ID")
    parser.add_argument("--device", default="cpu", help="Device (cpu/cuda)")
    args = parser.parse_args()

    print(f"Loading model from {args.model}...")
    model, config = load_model(args.model, args.device)

    print(f"Processing image: {args.image}")
    transform = get_preprocessor(config)
    # Close the underlying file handle. convert() returns an independent image.
    with Image.open(args.image) as img:
        image = img.convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(args.device)

    result = predict(model, image_tensor, config)

    print("\n" + "=" * 50)
    print("PREDICTION RESULT")
    print("=" * 50)
    print(f"Class: {result['class_name']}")
    print(f"Confidence: {result['confidence']}")
    print("\nAll probabilities:")
    for cls, prob in result['probabilities'].items():
        print(f"  {cls}: {prob}")


if __name__ == "__main__":
    main()