| """ |
| Inference Script for Waste Classification |
| Supports single image, batch processing, and real-time inference |
| """ |
| import torch |
| from ultralytics import YOLO |
| from pathlib import Path |
| import argparse |
| import cv2 |
| import numpy as np |
| from PIL import Image |
| import json |
|
|
class WasteClassifier:
    """Wrapper around an Ultralytics YOLO classification model for waste-type inference."""

    def __init__(self, model_path):
        """
        Initialize the classifier.

        Args:
            model_path: Path to trained model weights (.pt file)
        """
        self.model = YOLO(model_path)
        # Mapping of class index -> class name, e.g. {0: 'cardboard', 1: 'glass', ...}
        self.class_names = self.model.names
        print(f"Model loaded from: {model_path}")
        print(f"Classes: {list(self.class_names.values())}")

    def _extract_prediction(self, result):
        """
        Build a prediction dict from one ultralytics classification result.

        Shared by predict / predict_batch / predict_from_array so the
        top-1 and per-class probability extraction lives in one place.

        Args:
            result: A single result object exposing a `.probs` attribute
                    (as returned by a YOLO classification call).

        Returns:
            dict with keys 'class', 'confidence', 'all_probabilities',
            'class_index'.
        """
        probs = result.probs
        top1_idx = probs.top1
        return {
            'class': self.class_names[top1_idx],
            'confidence': probs.top1conf.item(),
            'all_probabilities': {
                class_name: probs.data[idx].item()
                for idx, class_name in self.class_names.items()
            },
            'class_index': top1_idx,
        }

    def predict(self, image_path, conf_threshold=0.25, return_image=False):
        """
        Predict waste class for a single image.

        Args:
            image_path: Path to image file
            conf_threshold: Confidence threshold passed to the model
            return_image: Whether to include an annotated image
                          (numpy array under 'annotated_image')

        Returns:
            dict with predictions (see _extract_prediction)
        """
        results = self.model(image_path, conf=conf_threshold)
        result = results[0]
        prediction = self._extract_prediction(result)
        if return_image:
            # result.plot() renders the prediction onto the image (BGR array).
            prediction['annotated_image'] = result.plot()
        return prediction

    def predict_batch(self, image_paths, conf_threshold=0.25):
        """
        Predict for multiple images in one model call.

        Args:
            image_paths: List of image paths
            conf_threshold: Confidence threshold passed to the model

        Returns:
            List of prediction dictionaries, one per result
        """
        results = self.model(image_paths, conf=conf_threshold)
        return [self._extract_prediction(result) for result in results]

    def predict_from_array(self, image_array, conf_threshold=0.25):
        """
        Predict from an in-memory image (for real-time inference).

        Args:
            image_array: numpy array image (BGR or RGB)
            conf_threshold: Confidence threshold passed to the model

        Returns:
            dict with predictions (see _extract_prediction)
        """
        results = self.model(image_array, conf=conf_threshold)
        return self._extract_prediction(results[0])
|
|
def _predict_single(classifier, source_path, output_path, args):
    """Run inference on one image, print results, and save outputs per CLI flags."""
    print(f"Processing image: {source_path}")
    prediction = classifier.predict(str(source_path), args.conf, return_image=True)

    print("\n" + "=" * 50)
    print("Prediction Results:")
    print("=" * 50)
    print(f"Class: {prediction['class']}")
    print(f"Confidence: {prediction['confidence']:.4f}")
    print("\nAll Probabilities:")
    for class_name, prob in sorted(prediction['all_probabilities'].items(),
                                   key=lambda x: x[1], reverse=True):
        print(f"  {class_name}: {prob:.4f}")
    print("=" * 50)

    if 'annotated_image' in prediction:
        output_img_path = output_path / f"predicted_{source_path.name}"
        cv2.imwrite(str(output_img_path), prediction['annotated_image'])
        print(f"\nAnnotated image saved to: {output_img_path}")

    if args.save_txt:
        txt_path = output_path / f"prediction_{source_path.stem}.txt"
        with open(txt_path, 'w') as f:
            f.write(f"Class: {prediction['class']}\n")
            f.write(f"Confidence: {prediction['confidence']:.4f}\n\n")
            f.write("All Probabilities:\n")
            for class_name, prob in sorted(prediction['all_probabilities'].items(),
                                           key=lambda x: x[1], reverse=True):
                f.write(f"{class_name}: {prob:.4f}\n")

    if args.save_json:
        # The annotated image is a numpy array and is not JSON serializable;
        # drop it from the serialized payload to avoid a TypeError.
        serializable = {k: v for k, v in prediction.items()
                        if k != 'annotated_image'}
        json_path = output_path / f"prediction_{source_path.stem}.json"
        with open(json_path, 'w') as f:
            json.dump(serializable, f, indent=2)


def _predict_directory(classifier, source_path, output_path, args):
    """Run batch inference over all images in a directory and summarize results."""
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp'}
    image_paths = [p for p in source_path.iterdir()
                   if p.suffix.lower() in image_extensions]

    print(f"Processing {len(image_paths)} images...")
    predictions = classifier.predict_batch([str(p) for p in image_paths], args.conf)

    results = [{'image': str(img_path), 'prediction': pred}
               for img_path, pred in zip(image_paths, predictions)]

    if args.save_json:
        json_path = output_path / "batch_predictions.json"
        with open(json_path, 'w') as f:
            json.dump(results, f, indent=2)
        print(f"\nResults saved to: {json_path}")

    print("\n" + "=" * 50)
    print("Batch Prediction Summary:")
    print("=" * 50)
    class_counts = {}
    for pred in predictions:
        cls = pred['class']
        class_counts[cls] = class_counts.get(cls, 0) + 1

    for cls, count in sorted(class_counts.items()):
        print(f"{cls}: {count} images")
    print("=" * 50)


def main():
    """CLI entry point: classify a single image or every image in a directory."""
    parser = argparse.ArgumentParser(description="Waste Classification Inference")
    parser.add_argument("--model", type=str, required=True,
                        help="Path to trained model weights (.pt file)")
    parser.add_argument("--source", type=str, required=True,
                        help="Path to image file or directory")
    parser.add_argument("--conf", type=float, default=0.25,
                        help="Confidence threshold")
    parser.add_argument("--output", type=str, default="predictions",
                        help="Output directory for results")
    parser.add_argument("--save-txt", action="store_true",
                        help="Save predictions to text file")
    parser.add_argument("--save-json", action="store_true",
                        help="Save predictions to JSON file")

    args = parser.parse_args()

    classifier = WasteClassifier(args.model)

    source_path = Path(args.source)
    output_path = Path(args.output)
    output_path.mkdir(parents=True, exist_ok=True)

    if source_path.is_file():
        _predict_single(classifier, source_path, output_path, args)
    elif source_path.is_dir():
        _predict_directory(classifier, source_path, output_path, args)
    else:
        print(f"Error: Source path does not exist: {source_path}")
|
|
# Run the CLI only when executed as a script (not on import).
if __name__ == "__main__":
    main()
|
|