Spaces:
Sleeping
Sleeping
| import torch | |
| from torchvision import transforms | |
| from PIL import Image | |
| import os | |
| from src.model import TrashNetClassifier | |
| from src import config | |
def load_model(model_path, device):
    """Build a TrashNetClassifier, restore saved weights, and prepare it for inference.

    Args:
        model_path: Path to a checkpoint holding the model's ``state_dict``.
        device: Device the checkpoint tensors are mapped to and the model is
            moved onto.

    Returns:
        The classifier in eval mode, located on ``device``.
    """
    net = TrashNetClassifier()
    # NOTE(review): torch.load unpickles the file; only point this at trusted
    # checkpoints (consider weights_only=True on torch versions that support it).
    weights = torch.load(model_path, map_location=device)
    net.load_state_dict(weights)
    # Inference mode: affects layers such as dropout / batch norm, if present.
    net.eval()
    net.to(device)
    return net
def preprocess_image(image_path, image_size):
    """Load an image file and turn it into a normalized single-image batch.

    Args:
        image_path: Path to the image file (anything ``PIL.Image.open`` accepts).
        image_size: Target side length; the image is resized to
            ``(image_size, image_size)`` without preserving aspect ratio.

    Returns:
        A float tensor of shape ``(1, 3, image_size, image_size)``. ToTensor
        scales pixels to [0, 1], and Normalize(mean=0.5, std=0.5) maps them to
        [-1, 1].
    """
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ])
    # Image.open is lazy and keeps the file handle open; convert() loads the
    # pixel data into a new image, so the file can be closed right after.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    # Add a leading batch dimension: (3, H, W) -> (1, 3, H, W).
    return transform(image).unsqueeze(0)
def predict_image(model, image_tensor, class_names, device):
    """Run ``model`` on a single-image batch and return its top prediction.

    Args:
        model: Callable mapping the input batch to per-class scores. Softmax is
            taken over dim 1, so the output is treated as ``(batch, num_classes)``.
        image_tensor: Input batch; moved to ``device`` before the forward pass.
        class_names: Sequence indexed by class position to obtain the label.
        device: Target device for the input.

    Returns:
        ``(label, confidence)``: the label taken from ``class_names`` and the
        softmax probability (Python float) of that class for the first sample.

    Note:
        ``.item()`` on the argmax only works when the batch holds exactly one
        image.
    """
    batch = image_tensor.to(device)
    with torch.no_grad():
        scores = model(batch)
        distribution = torch.softmax(scores, dim=1)
        top = distribution.argmax(dim=1).item()
        chosen = class_names[top]
        score = distribution[0][top].item()
    return chosen, score
def run_inference(image_path):
    """Classify one image file using the saved model and project configuration.

    Reads ``config.DEVICE``, ``config.DATA_DIR``, ``config.MODEL_SAVE_PATH``
    and ``config.IMAGE_SIZE``, prints the prediction, and returns it.

    Args:
        image_path: Path to the image to classify.

    Returns:
        ``(label, confidence)`` as produced by ``predict_image``. Confidence is
        the softmax probability of the predicted class.
    """
    device = config.DEVICE
    train_dir = os.path.join(config.DATA_DIR, "train")
    # Only sub-directories are classes. Stray files (e.g. .DS_Store, README)
    # would otherwise be sorted in among the class names and shift the
    # index -> label mapping. This mirrors torchvision ImageFolder's class
    # discovery (sorted directory names), presumably what training used;
    # verify against the training code.
    with os.scandir(train_dir) as entries:
        class_names = sorted(entry.name for entry in entries if entry.is_dir())
    model = load_model(config.MODEL_SAVE_PATH, device)
    image_tensor = preprocess_image(image_path, config.IMAGE_SIZE)
    label, confidence = predict_image(model, image_tensor, class_names, device)
    print(f"Prediction: {label} ({confidence*100:.2f}%)")
    return label, confidence