November-2K25 commited on
Commit
37a44cf
·
verified ·
1 Parent(s): 0a329c9

Delete waste_sorting.py

Browse files
Files changed (1) hide show
  1. waste_sorting.py +0 -56
waste_sorting.py DELETED
@@ -1,56 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- from torchvision import transforms
4
- from PIL import Image
5
- from transformers import AutoModelForImageClassification, AutoImageProcessor
6
- import gradio as gr
7
-
8
- # Load model and image processor
9
- model_name = "watersplash/waste-classification" # Change to a valid model
10
- model = AutoModelForImageClassification.from_pretrained(model_name)
11
- image_processor = AutoImageProcessor.from_pretrained(model_name)
12
-
13
- # Define preprocessing function
14
- def preprocess_image(image):
15
- transform = transforms.Compose([
16
- transforms.Resize((224, 224)),
17
- transforms.ToTensor(),
18
- transforms.Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
19
- ])
20
- return transform(image).unsqueeze(0) # Add batch dimension
21
-
22
- # Define multi-label prediction function
23
- def predict_waste(image):
24
- image = Image.fromarray(image) # Convert NumPy array to PIL image
25
- input_tensor = preprocess_image(image)
26
-
27
- # Get model predictions
28
- with torch.no_grad():
29
- outputs = model(input_tensor)
30
-
31
- # Apply sigmoid activation for multi-label classification
32
- probabilities = torch.sigmoid(outputs.logits)[0] # Convert logits to probabilities
33
-
34
- # Set a threshold to select labels (e.g., >= 50%)
35
- threshold = 0.5
36
- predicted_labels = [label for idx, label in model.config.id2label.items() if probabilities[idx] >= threshold]
37
- confidence_scores = [f"{probabilities[idx] * 100:.2f}%" for idx in range(len(probabilities)) if probabilities[idx] >= threshold]
38
-
39
- if predicted_labels:
40
- result = "\n".join([f"{label}: {score}" for label, score in zip(predicted_labels, confidence_scores)])
41
- else:
42
- result = "No clear classification (confidence below threshold)"
43
-
44
- return result
45
-
46
- # Create Gradio interface
47
- interface = gr.Interface(
48
- fn=predict_waste,
49
- inputs=gr.Image(type="numpy"),
50
- outputs="text",
51
- title="Multi-Label Waste Sorting App",
52
- description="Upload an image of waste. The model will classify it into multiple waste categories with confidence scores."
53
- )
54
-
55
- # Launch the app
56
- interface.launch()