import json import os import torch import torchvision.models as models from PIL import Image from torchvision import transforms DEFAULT_CLASSES = [ "Battery", "Cardboard", "Clothes", "Glass", "Metal", "Paper", "Plastic", ] class GarbageClassifier: def __init__( self, model_path="best_model.pth", classes_path="classes.json", device="cpu", ): if device == "cuda" and not torch.cuda.is_available(): device = "cpu" self.device = torch.device(device) self.classes = self._load_classes(classes_path) class TransferLearningModel(torch.nn.Module): def __init__(self, num_classes): super().__init__() self.backbone = models.resnet18(weights=None) num_features = self.backbone.fc.in_features self.backbone.fc = torch.nn.Sequential( torch.nn.Linear(num_features, 256), torch.nn.ReLU(), torch.nn.Dropout(0.5), torch.nn.Linear(256, num_classes) ) def forward(self, x): return self.backbone(x) self.model = TransferLearningModel( num_classes=len(self.classes) ) state = torch.load(model_path, map_location=self.device) self.model.load_state_dict(state) self.model.to(self.device) self.model.eval() self.transform = transforms.Compose( [ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], ), ] ) def _load_classes(self, classes_path): if classes_path and os.path.exists(classes_path): with open(classes_path, "r", encoding="utf-8") as f: return json.load(f) return DEFAULT_CLASSES def predict(self, image_path): """Predict garbage class from image path.""" image = Image.open(image_path).convert("RGB") image_tensor = self.transform(image).unsqueeze(0).to(self.device) with torch.no_grad(): outputs = self.model(image_tensor) probabilities = torch.softmax(outputs, dim=1) confidence, prediction = torch.max(probabilities, 1) return { "class": self.classes[prediction.item()], "confidence": float(confidence.item()), "all_probabilities": { self.classes[i]: float(probabilities[0, i].item()) for i in range(len(self.classes)) }, } def predict_batch(self, image_paths): """Batch prediction for a list of image paths.""" return [self.predict(path) for path in image_paths] if __name__ == "__main__": classifier = GarbageClassifier("garbage_model.pth") result = classifier.predict("test_image.jpg") print(result)