# Page header captured with this file: author Borhan72; commit 6665c98
# "Update model.py" (verified); file size 3.17 kB.
import json
import os
import torch
import torchvision.models as models
from PIL import Image
from torchvision import transforms
# Fallback class names, listed in model output-index order; used when no
# classes JSON file is available.
DEFAULT_CLASSES = ["Battery", "Cardboard", "Clothes", "Glass", "Metal", "Paper", "Plastic"]
class _TransferLearningModel(torch.nn.Module):
    """ResNet-18 backbone whose ``fc`` layer is replaced by a small MLP head.

    Defined at module level rather than inside ``GarbageClassifier.__init__``
    so the class is created once and can be referenced by qualified name
    (e.g. for pickling). Attribute names and the ``Sequential`` layout must
    stay exactly as they are: they determine the ``state_dict`` keys
    (``backbone.fc.0.weight``, ...) that the checkpoint is matched against.
    """

    def __init__(self, num_classes):
        super().__init__()
        # weights=None: architecture only; trained weights come from the checkpoint.
        self.backbone = models.resnet18(weights=None)
        num_features = self.backbone.fc.in_features
        self.backbone.fc = torch.nn.Sequential(
            torch.nn.Linear(num_features, 256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.5),
            torch.nn.Linear(256, num_classes),
        )

    def forward(self, x):
        return self.backbone(x)


class GarbageClassifier:
    """Classify an image into one of a fixed list of garbage classes.

    Wraps a ResNet-18 transfer-learning model loaded from a ``state_dict``
    checkpoint. Before inference, images are resized to 224x224 and normalized
    with the ImageNet mean/std.
    """

    def __init__(
        self,
        model_path="best_model.pth",
        classes_path="classes.json",
        device="cpu",
    ):
        """Build the network, load the trained weights and set up preprocessing.

        Args:
            model_path: Path to a checkpoint containing the model ``state_dict``.
            classes_path: Optional JSON file holding a list of class names in
                output-index order. ``DEFAULT_CLASSES`` is used when this is
                falsy or the path does not exist.
            device: Torch device spec (``"cpu"``, ``"cuda"``, ``"cuda:1"`` or a
                ``torch.device``). Any CUDA device falls back to CPU when CUDA
                is unavailable.

        Raises:
            ValueError: If the classes file does not contain a JSON list.
            RuntimeError: If the checkpoint does not match the architecture
                (e.g. a different number of classes).
        """
        self.device = self._resolve_device(device)
        self.classes = self._load_classes(classes_path)

        self.model = _TransferLearningModel(num_classes=len(self.classes))
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load trusted checkpoints; on torch >= 1.13,
        # consider passing weights_only=True.
        state = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state)
        self.model.to(self.device)
        # eval(): disables Dropout and makes BatchNorm use running statistics,
        # so predictions are deterministic.
        self.model.eval()

        self.transform = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]
        )

    @staticmethod
    def _resolve_device(device):
        """Return a ``torch.device``, falling back to CPU if CUDA is unavailable."""
        resolved = torch.device(device)
        # Check .type so "cuda:N" and torch.device inputs fall back as well,
        # not only the exact string "cuda".
        if resolved.type == "cuda" and not torch.cuda.is_available():
            resolved = torch.device("cpu")
        return resolved

    def _load_classes(self, classes_path):
        """Return class names from *classes_path*, or a copy of ``DEFAULT_CLASSES``.

        A fresh list is returned for the defaults so that mutating
        ``self.classes`` cannot alter the module-level constant.

        Raises:
            ValueError: If the file's JSON content is not a list.
        """
        if classes_path and os.path.exists(classes_path):
            with open(classes_path, "r", encoding="utf-8") as f:
                classes = json.load(f)
            # Labels are looked up by integer output index, so only a list works.
            if not isinstance(classes, list):
                raise ValueError(
                    f"{classes_path} must contain a JSON list of class names, "
                    f"got {type(classes).__name__}"
                )
            return classes
        return list(DEFAULT_CLASSES)

    def predict(self, image_path):
        """Predict the garbage class of a single image.

        Args:
            image_path: Anything accepted by ``PIL.Image.open`` (typically a path).

        Returns:
            dict with ``"class"`` (top label), ``"confidence"`` (its softmax
            probability) and ``"all_probabilities"`` (label -> probability).
        """
        # The context manager closes the file handle. convert() returns a new,
        # fully loaded image that remains valid after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            outputs = self.model(image_tensor)
            # Batch of one: keep the single row of per-class probabilities.
            probabilities = torch.softmax(outputs, dim=1)[0]
        return self._format_result(probabilities)

    def _format_result(self, probabilities):
        """Convert a 1-D tensor of per-class probabilities into the result dict."""
        confidence, index = torch.max(probabilities, dim=0)
        scores = probabilities.tolist()
        return {
            "class": self.classes[index.item()],
            "confidence": float(confidence.item()),
            "all_probabilities": {
                name: float(score) for name, score in zip(self.classes, scores)
            },
        }

    def predict_batch(self, image_paths):
        """Run ``predict`` on each path in turn and return the results in order."""
        return [self.predict(path) for path in image_paths]
if __name__ == "__main__":
    import sys

    # Usage: python model.py [model_path] [image_path]
    # With no arguments, the original demo paths are used.
    model_path = sys.argv[1] if len(sys.argv) > 1 else "garbage_model.pth"
    image_path = sys.argv[2] if len(sys.argv) > 2 else "test_image.jpg"
    classifier = GarbageClassifier(model_path)
    result = classifier.predict(image_path)
    print(result)