November-2K25 commited on
Commit
ad3ed30
·
verified ·
1 Parent(s): 37a44cf

Upload 2 files

Browse files
Files changed (2) hide show
  1. requirements.txt +4 -0
  2. waste_sorting.py +56 -0
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio
4
+ transformers
waste_sorting.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ from transformers import AutoModelForImageClassification, AutoImageProcessor
6
+ import gradio as gr
7
+
8
+ # Load model and image processor
9
+ model_name = "watersplash/waste-classification" # Change to a valid model
10
+ model = AutoModelForImageClassification.from_pretrained(model_name)
11
+ image_processor = AutoImageProcessor.from_pretrained(model_name)
12
+
13
+ # Define preprocessing function
14
+ def preprocess_image(image):
15
+ transform = transforms.Compose([
16
+ transforms.Resize((224, 224)),
17
+ transforms.ToTensor(),
18
+ transforms.Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
19
+ ])
20
+ return transform(image).unsqueeze(0) # Add batch dimension
21
+
22
+ # Define multi-label prediction function
23
+ def predict_waste(image):
24
+ image = Image.fromarray(image) # Convert NumPy array to PIL image
25
+ input_tensor = preprocess_image(image)
26
+
27
+ # Get model predictions
28
+ with torch.no_grad():
29
+ outputs = model(input_tensor)
30
+
31
+ # Apply sigmoid activation for multi-label classification
32
+ probabilities = torch.sigmoid(outputs.logits)[0] # Convert logits to probabilities
33
+
34
+ # Set a threshold to select labels (e.g., >= 50%)
35
+ threshold = 0.5
36
+ predicted_labels = [label for idx, label in model.config.id2label.items() if probabilities[idx] >= threshold]
37
+ confidence_scores = [f"{probabilities[idx] * 100:.2f}%" for idx in range(len(probabilities)) if probabilities[idx] >= threshold]
38
+
39
+ if predicted_labels:
40
+ result = "\n".join([f"{label}: {score}" for label, score in zip(predicted_labels, confidence_scores)])
41
+ else:
42
+ result = "No clear classification (confidence below threshold)"
43
+
44
+ return result
45
+
46
+ # Create Gradio interface
47
+ interface = gr.Interface(
48
+ fn=predict_waste,
49
+ inputs=gr.Image(type="numpy"),
50
+ outputs="text",
51
+ title="Multi-Label Waste Sorting App",
52
+ description="Upload an image of waste. The model will classify it into multiple waste categories with confidence scores."
53
+ )
54
+
55
+ # Launch the app
56
+ interface.launch()