Spaces:
Sleeping
Sleeping
| import argparse | |
| import json | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from model import load_model | |
| def load_labels(labels_path: str = "labels.json") -> dict[int, str]: | |
| labels_file = Path(labels_path) | |
| if not labels_file.exists(): | |
| raise FileNotFoundError(f"labels.json not found at: {labels_file}") | |
| with labels_file.open("r", encoding="utf-8") as f: | |
| raw = json.load(f) | |
| return {int(k): v for k, v in raw.items()} | |
| def preprocess_image(image_path: str) -> torch.Tensor: | |
| img_file = Path(image_path) | |
| if not img_file.exists(): | |
| raise FileNotFoundError(f"Image not found at: {img_file}") | |
| img = Image.open(img_file).convert("RGB") | |
| img = img.resize((28, 28)) | |
| arr = np.array(img).astype("float32") / 255.0 # [H, W, C] in [0,1] | |
| arr = np.transpose(arr, (2, 0, 1)) # [C, H, W] | |
| tensor = torch.from_numpy(arr).unsqueeze(0) # [1, 3, 28, 28] | |
| return tensor | |
| def predict( | |
| image_path: str, | |
| weights_path: str = "model.pth", | |
| labels_path: str = "labels.json" | |
| ): | |
| model, device = load_model(weights_path) | |
| id2label = load_labels(labels_path) | |
| x = preprocess_image(image_path).to(device) | |
| with torch.no_grad(): | |
| logits = model(x) | |
| probs = torch.softmax(logits, dim=1)[0] | |
| pred_idx = int(torch.argmax(probs).item()) | |
| pred_label = id2label.get(pred_idx, str(pred_idx)) | |
| probs_list = probs.cpu().tolist() | |
| return pred_idx, pred_label, probs_list | |
| def main(): | |
| parser = argparse.ArgumentParser( | |
| description="Run inference with SkinCNN on a dermatoscopic image." | |
| ) | |
| parser.add_argument("image", type=str, help="Path to input dermatoscopic image.") | |
| parser.add_argument( | |
| "--weights", | |
| type=str, | |
| default="model.pth", | |
| help="Path to model weights (.pth).", | |
| ) | |
| parser.add_argument( | |
| "--labels", | |
| type=str, | |
| default="labels.json", | |
| help="Path to labels.json.", | |
| ) | |
| args = parser.parse_args() | |
| idx, label, probs = predict( | |
| image_path=args.image, | |
| weights_path=args.weights, | |
| labels_path=args.labels, | |
| ) | |
| print(f"Predicted class index: {idx}") | |
| print(f"Predicted label : {label}") | |
| print(f"Probabilities : {probs}") | |
| if __name__ == "__main__": | |
| main() | |