ma4389 commited on
Commit
4f3a751
·
verified ·
1 Parent(s): fb6aa1c

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +55 -0
  2. best_model.pth +3 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms, models
5
+ from PIL import Image
6
+
7
+ # Define classes
8
+ classes = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
9
+
10
+ # Load model architecture
11
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
12
+
13
+ base_model = models.resnet50(weights=None) # No pretrained weights here since we load ours
14
+ in_features = base_model.fc.in_features
15
+ base_model.fc = nn.Sequential(
16
+ nn.Linear(in_features, 512),
17
+ nn.ReLU(),
18
+ nn.Dropout(0.4),
19
+ nn.Linear(512, len(classes))
20
+ )
21
+
22
+ # Load trained weights
23
+ base_model.load_state_dict(torch.load("best_model.pth", map_location=device))
24
+ base_model = base_model.to(device)
25
+ base_model.eval()
26
+
27
+ # Define transforms (same as training but without augmentation)
28
+ transform = transforms.Compose([
29
+ transforms.Lambda(lambda x: x.convert('RGB')),
30
+ transforms.Resize((224, 224)),
31
+ transforms.ToTensor(),
32
+ transforms.Normalize(mean=[0.5, 0.5, 0.5],
33
+ std=[0.5, 0.5, 0.5])
34
+ ])
35
+
36
+ # Prediction function
37
+ def predict(img):
38
+ img = transform(img).unsqueeze(0).to(device)
39
+ with torch.no_grad():
40
+ outputs = base_model(img)
41
+ probs = torch.softmax(outputs, dim=1).cpu().numpy()[0]
42
+ # Create a dictionary of class probabilities
43
+ return {classes[i]: float(probs[i]) for i in range(len(classes))}
44
+
45
+ # Create Gradio interface
46
+ demo = gr.Interface(
47
+ fn=predict,
48
+ inputs=gr.Image(type="pil"),
49
+ outputs=gr.Label(num_top_classes=6),
50
+ title="Garbage Classification",
51
+ description="Upload an image of garbage and the model will predict its type (cardboard, glass, metal, paper, plastic, trash)."
52
+ )
53
+
54
+ if __name__ == "__main__":
55
+ demo.launch(debug=True)
best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2fef4553274b5cbbb7cb757187b196095cb7bd28365d8a5d57b07a1f3877314d
3
+ size 98560064
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio
4
+ Pillow