mdpg4 / predict_task1.py
Mahiruoshi's picture
Update predict_task1.py
2a9001f verified
import cv2
import os
import time
import importlib.util
import supervision as sv
from ultralytics import YOLO
# Load PC_CONFIG.py by explicit file path from the directory that holds this
# script, rather than through a regular `import` statement.
config_dir = os.path.abspath(os.path.dirname(__file__))
config_path = os.path.join(config_dir, "PC_CONFIG.py")
spec = importlib.util.spec_from_file_location(
    name="PC_CONFIG",
    location=config_path,
)
PC_CONFIG = importlib.util.module_from_spec(spec)
spec.loader.exec_module(PC_CONFIG)

# Path of the task-1 weights file under <BASE_DIR>/weights.
# NOTE: `dir` shadows the builtin of the same name; the name is kept because
# Predictor reads this module-level variable.
_weights_location = ("weights", "best_task1.pt")
dir = str(os.path.join(PC_CONFIG.BASE_DIR, *_weights_location))
class Predictor:
    """Run a YOLO model on an image file and pick one detection from it.

    Which detection is picked depends on the task type passed to
    ``predict_id``: for ``"TASK_2"`` the first non-bullseye detection,
    otherwise the detection with the largest bounding box.
    """

    # Class ID treated as the bullseye in TASK_2 (previously hard-coded as 16).
    # NOTE(review): must match the class index of the loaded weights - confirm
    # whenever the model is retrained.
    BULLSEYE_CLASS_ID = 16

    # Edge length (pixels) every image is resized to before inference.
    INPUT_SIZE = 640

    def __init__(self, weights_path=None):
        """Load the YOLO weights.

        Args:
            weights_path: Optional path of the weights file. When omitted, the
                module-level ``dir`` path (best_task1.pt) is used, which is
                what the original no-argument constructor did.
        """
        if weights_path is None:
            weights_path = dir  # module-level default weights path
        print("dir:", weights_path)
        self.model = YOLO(weights_path)  # replace model here

    def predict_id(self, image_file_path, task_type):
        """Run inference on one image and select a single detection.

        Args:
            image_file_path: Path of the image to read with ``cv2.imread``.
            task_type: ``"TASK_2"`` selects the first detection that is not a
                bullseye (falling back to ``"bullseye"`` when only bullseyes
                are found); any other value selects the detection whose
                bounding box has the largest side.

        Returns:
            A ``(class_name, results, detection_id)`` tuple. ``results`` is
            the raw model output and ``detection_id`` is the index of the
            selected box inside ``results[0].boxes``. When nothing is
            detected the tuple is ``(None, results, 0)``. When the image
            cannot be read it is ``(None, None, None)``.
        """
        image = cv2.imread(image_file_path)
        if image is None:
            print(f"Error: Could not read image at {image_file_path}")
            return None, None, None

        # Force the image to INPUT_SIZE x INPUT_SIZE (aspect ratio is not kept).
        target_size = (self.INPUT_SIZE, self.INPUT_SIZE)
        if tuple(image.shape[:2]) != target_size:
            image = cv2.resize(image, target_size)

        results = self.model(image)
        print(results)
        # Printed once per call; it used to be repeated for every detection.
        print(f"task_type is {task_type}")

        first = results[0]
        class_name, detection_id = self._select_detection(
            first.boxes.xyxy,  # bounding boxes as (x1, y1, x2, y2)
            first.boxes.cls,   # class ID of each box
            first.names,       # class ID -> class name
            task_type,
            self.BULLSEYE_CLASS_ID,
        )

        if class_name:
            print("class_name = " + class_name)
        else:
            print("class_name = None")
        return class_name, results, detection_id

    @staticmethod
    def _select_detection(boxes, class_ids, names, task_type, bullseye_id):
        """Pick one detection and return ``(class_name, detection_id)``.

        Args:
            boxes: Sequence of ``(x1, y1, x2, y2)`` boxes.
            class_ids: Sequence of class IDs, parallel to ``boxes``.
            names: Mapping from integer class ID to class name.
            task_type: See ``predict_id``.
            bullseye_id: Class ID treated as the bullseye in TASK_2.

        Returns:
            ``(None, 0)`` when there are no detections, otherwise the chosen
            class name and the index of the chosen detection.
        """
        class_name, detection_id = None, 0

        if task_type == "TASK_2":
            for index, raw_id in enumerate(class_ids):
                class_id = int(raw_id)
                if class_id != bullseye_id:
                    # First non-bullseye detection wins immediately.
                    return names[class_id], index
                print("Bullseye detected")
                class_name, detection_id = "bullseye", index
            return class_name, detection_id

        # Other tasks: keep the box with the largest side. detection_id is
        # updated only together with class_name, so both always describe the
        # same box (it used to end up as the index of the last box).
        largest_size = None
        for index, (box, raw_id) in enumerate(zip(boxes, class_ids)):
            size = max(float(box[2] - box[0]), float(box[3] - box[1]))
            if largest_size is None or size > largest_size:
                largest_size = size
                class_name = names[int(raw_id)]
                detection_id = index
        return class_name, detection_id
if __name__ == "__main__":
    # Manual smoke run: load the model, then classify one bundled sample image
    # using the largest-box selection (any task type other than "TASK_2").
    predictor = Predictor()
    sample_parts = ("image-rec", "sample_images", "IMG_9325.jpg")
    image_file_path = os.path.join(PC_CONFIG.FILE_DIRECTORY, *sample_parts)
    predictor.predict_id(image_file_path, "TASK_1")