# YOLO_Waste_Vision_System / inference.py
# Uploaded by padmanabhbosamia
# Upload inference.py (commit 8944108, verified)
"""
Inference Script for Waste Classification
Supports single image, batch processing, and real-time inference
"""
import argparse
import json
import sys
from collections import Counter
from pathlib import Path

import cv2
import numpy as np
import torch
from PIL import Image
from ultralytics import YOLO
class WasteClassifier:
    """Wrapper class for waste classification inference.

    Every ``predict*`` method returns plain-Python dictionaries with the keys
    ``'class'``, ``'confidence'``, ``'all_probabilities'`` and ``'class_index'``.
    """

    def __init__(self, model_path):
        """
        Initialize the classifier.

        Args:
            model_path: Path to trained model weights (.pt file)
        """
        self.model = YOLO(model_path)
        # Mapping of class index -> class name, as exposed by the model.
        self.class_names = self.model.names
        print(f"Model loaded from: {model_path}")
        print(f"Classes: {list(self.class_names.values())}")

    def _parse_result(self, result):
        """Convert one model result into a prediction dictionary.

        Args:
            result: One element of the sequence returned by ``self.model(...)``.

        Returns:
            dict with 'class' (top-1 class name), 'confidence' (top-1 score),
            'all_probabilities' (class name -> score) and 'class_index'.

        Raises:
            ValueError: if ``result.probs`` is None — presumably because the
                loaded weights are not a classification model.
        """
        probs = result.probs
        if probs is None:
            raise ValueError(
                "Model output has no class probabilities; "
                "expected weights for a classification model."
            )
        top1_idx = probs.top1
        # One bulk conversion instead of a separate `.item()` call per class.
        scores = probs.data.tolist()
        return {
            'class': self.class_names[top1_idx],
            'confidence': probs.top1conf.item(),
            'all_probabilities': {
                class_name: scores[idx]
                for idx, class_name in self.class_names.items()
            },
            'class_index': top1_idx,
        }

    def predict(self, image_path, conf_threshold=0.25, return_image=False):
        """
        Predict waste class for a single image.

        Args:
            image_path: Path to image file
            conf_threshold: Passed to the model call as ``conf``.
                NOTE(review): for classification models this may not filter
                the top-1 result — confirm against the Ultralytics docs.
            return_image: Whether to also return the annotated image

        Returns:
            dict with predictions; when ``return_image`` is true it also holds
            'annotated_image', the output of ``result.plot()`` (an image array,
            not JSON-serialisable).

        Raises:
            ValueError: if the model output carries no class probabilities.
        """
        result = self.model(image_path, conf=conf_threshold)[0]
        prediction = self._parse_result(result)
        if return_image:
            prediction['annotated_image'] = result.plot()
        return prediction

    def predict_batch(self, image_paths, conf_threshold=0.25):
        """
        Predict for multiple images in a single model call.

        Args:
            image_paths: List of image paths
            conf_threshold: Passed to the model call as ``conf``

        Returns:
            List of prediction dictionaries, in the same order as the results
            returned by the model. An empty list/tuple input returns ``[]``
            without invoking the model.

        Raises:
            ValueError: if a model output carries no class probabilities.
        """
        if isinstance(image_paths, (list, tuple)) and not image_paths:
            # Nothing to classify; don't hand the model an empty source.
            return []
        results = self.model(image_paths, conf=conf_threshold)
        return [self._parse_result(result) for result in results]

    def predict_from_array(self, image_array, conf_threshold=0.25):
        """
        Predict from an in-memory image (for real-time inference).

        Args:
            image_array: numpy image array. Ultralytics documents numpy inputs
                as HWC, BGR channel order (OpenCV convention); convert RGB
                frames (e.g. from PIL) to BGR first.
            conf_threshold: Passed to the model call as ``conf``

        Returns:
            dict with predictions

        Raises:
            ValueError: if the model output carries no class probabilities.
        """
        results = self.model(image_array, conf=conf_threshold)
        return self._parse_result(results[0])
# Image file suffixes (lower-case) picked up when --source is a directory.
_IMAGE_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp'})


def _ranked_probabilities(prediction):
    """Return (class_name, probability) pairs sorted from most to least likely."""
    return sorted(prediction['all_probabilities'].items(),
                  key=lambda item: item[1], reverse=True)


def _process_image(classifier, source_path, output_path, args):
    """Classify one image, print the result and write the requested outputs."""
    print(f"Processing image: {source_path}")
    prediction = classifier.predict(str(source_path), args.conf, return_image=True)
    ranked = _ranked_probabilities(prediction)

    print("\n" + "=" * 50)
    print("Prediction Results:")
    print("=" * 50)
    print(f"Class: {prediction['class']}")
    print(f"Confidence: {prediction['confidence']:.4f}")
    print("\nAll Probabilities:")
    for class_name, prob in ranked:
        print(f" {class_name}: {prob:.4f}")
    print("=" * 50)

    # The annotated image is written with cv2 and must be removed from the
    # dict before json.dump, which cannot serialise an image array.
    annotated = prediction.pop('annotated_image', None)
    if annotated is not None:
        output_img_path = output_path / f"predicted_{source_path.name}"
        # cv2.imwrite reports failure through its return value, not an exception.
        if cv2.imwrite(str(output_img_path), annotated):
            print(f"\nAnnotated image saved to: {output_img_path}")
        else:
            print(f"\nWarning: could not write annotated image to: {output_img_path}",
                  file=sys.stderr)

    if args.save_txt:
        txt_path = output_path / f"prediction_{source_path.stem}.txt"
        with open(txt_path, 'w', encoding='utf-8') as f:
            f.write(f"Class: {prediction['class']}\n")
            f.write(f"Confidence: {prediction['confidence']:.4f}\n\n")
            f.write("All Probabilities:\n")
            f.writelines(f"{class_name}: {prob:.4f}\n" for class_name, prob in ranked)
        print(f"Results saved to: {txt_path}")

    if args.save_json:
        json_path = output_path / f"prediction_{source_path.stem}.json"
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(prediction, f, indent=2)
        print(f"Results saved to: {json_path}")


def _process_directory(classifier, source_path, output_path, args):
    """Classify every image directly inside *source_path* and print a summary.

    Returns:
        0 on success, 1 if the directory contains no supported image files.
    """
    # Sorted so results are reproducible; iterdir() order is filesystem-dependent.
    image_paths = sorted(p for p in source_path.iterdir()
                         if p.is_file() and p.suffix.lower() in _IMAGE_EXTENSIONS)
    if not image_paths:
        print(f"Error: No supported images found in: {source_path}", file=sys.stderr)
        return 1

    print(f"Processing {len(image_paths)} images...")
    predictions = classifier.predict_batch([str(p) for p in image_paths], args.conf)
    results = [{'image': str(img_path), 'prediction': pred}
               for img_path, pred in zip(image_paths, predictions)]

    if args.save_json:
        json_path = output_path / "batch_predictions.json"
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)
        print(f"\nResults saved to: {json_path}")

    if args.save_txt:
        # One tab-separated line per image: path, class, confidence.
        txt_path = output_path / "batch_predictions.txt"
        with open(txt_path, 'w', encoding='utf-8') as f:
            f.writelines(
                f"{item['image']}\t{item['prediction']['class']}\t"
                f"{item['prediction']['confidence']:.4f}\n"
                for item in results
            )
        print(f"Results saved to: {txt_path}")

    print("\n" + "=" * 50)
    print("Batch Prediction Summary:")
    print("=" * 50)
    class_counts = Counter(pred['class'] for pred in predictions)
    for cls, count in sorted(class_counts.items()):
        print(f"{cls}: {count} images")
    print("=" * 50)
    return 0


def main():
    """Command-line entry point: classify one image or every image in a directory.

    Returns:
        Process exit status: 0 on success, 1 if the source path does not exist
        or a source directory contains no supported images.
    """
    parser = argparse.ArgumentParser(description="Waste Classification Inference")
    parser.add_argument("--model", type=str, required=True,
                        help="Path to trained model weights (.pt file)")
    parser.add_argument("--source", type=str, required=True,
                        help="Path to image file or directory")
    parser.add_argument("--conf", type=float, default=0.25,
                        help="Confidence threshold")
    parser.add_argument("--output", type=str, default="predictions",
                        help="Output directory for results")
    parser.add_argument("--save-txt", action="store_true",
                        help="Save predictions to text file")
    parser.add_argument("--save-json", action="store_true",
                        help="Save predictions to JSON file")
    args = parser.parse_args()

    source_path = Path(args.source)
    # Validate the input before loading the model or creating output directories.
    if not (source_path.is_file() or source_path.is_dir()):
        print(f"Error: Source path does not exist: {source_path}", file=sys.stderr)
        return 1

    classifier = WasteClassifier(args.model)
    output_path = Path(args.output)
    output_path.mkdir(parents=True, exist_ok=True)

    if source_path.is_file():
        _process_image(classifier, source_path, output_path, args)
        return 0
    return _process_directory(classifier, source_path, output_path, args)
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status
    # (SystemExit(None) exits with status 0).
    raise SystemExit(main())