# app.py — Hugging Face Space by JakeTurner616 (commit 5120fbc, verified)
import gradio as gr
import torch
from ultralytics import YOLO
import cv2
import numpy as np
from PIL import Image
# Load trained YOLO11 model
# Weights file shipped alongside this Space; loaded once at startup so
# every request reuses the same model instance.
model_path = "best.pt"
model = YOLO(model_path)
# Class names
# Index order must match the class ids this checkpoint was trained with;
# detections are mapped to these labels inside segment_card().
CLASS_NAMES = [
"card_title", "card_art", "card_type",
"card_set_symbol", "card_mana_cost",
"card_oracle_text", "card_power_toughness"
]
# Define inference function
def segment_card(image):
    """Annotate an MTG card image with the model's best detection per class.

    Runs YOLO inference, keeps only the highest-confidence bounding box for
    each class, and draws each box with a filled label background.

    Args:
        image: PIL.Image uploaded through the Gradio UI.

    Returns:
        PIL.Image: copy of the input with boxes and labels drawn on it.
    """
    # Normalize to 3-channel RGB so drawing works even for RGBA / grayscale
    # uploads (cv2 drawing assumes a 3-channel array here).
    image = np.array(image.convert("RGB"))
    results = model(image)  # Run YOLO inference

    annotated_image = image.copy()

    # Track the highest-confidence detection per class id.
    best_detections = {}
    for result in results:
        for box in result.boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0])  # Bounding box coordinates
            class_id = int(box.cls[0])              # Class index
            confidence = box.conf[0].item()         # Confidence score
            # Keep only the best detection seen so far for this class.
            if class_id not in best_detections or confidence > best_detections[class_id]["confidence"]:
                best_detections[class_id] = {
                    "bbox": (x1, y1, x2, y2),
                    "confidence": confidence,
                }

    # Text style is identical for every label — set it once, outside the loop.
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.8       # Larger font for better readability
    font_thickness = 2

    # Draw the winning detection for each class.
    for class_id, detection in best_detections.items():
        x1, y1, x2, y2 = detection["bbox"]
        # Guard against class ids outside our known label list (e.g. a
        # checkpoint trained with more classes than CLASS_NAMES covers).
        if 0 <= class_id < len(CLASS_NAMES):
            label = CLASS_NAMES[class_id]
        else:
            label = f"class_{class_id}"
        confidence = detection["confidence"]

        # Bounding box first, so the label background sits on top of it.
        cv2.rectangle(annotated_image, (x1, y1), (x2, y2), (0, 255, 0), 2)

        label_text = f"{label} ({confidence:.2f})"
        # Measure the text so the background rectangle pads it correctly.
        text_size = cv2.getTextSize(label_text, font, font_scale, font_thickness)[0]
        text_x, text_y = x1, y1 - 10
        # Clamp so the label never goes above the top edge of the image.
        text_y = max(text_y, text_size[1] + 10)

        # Filled green background behind the label for contrast.
        cv2.rectangle(
            annotated_image,
            (text_x, text_y - text_size[1] - 5),
            (text_x + text_size[0] + 5, text_y + 5),
            (0, 255, 0),  # Background color (Green)
            -1,
        )
        # Black label text over the green background.
        cv2.putText(
            annotated_image,
            label_text,
            (text_x, text_y),
            font,
            font_scale,
            (0, 0, 0),  # Text color (Black for contrast)
            font_thickness,
        )

    return Image.fromarray(annotated_image)  # Convert back to PIL Image
# Wire the inference function into a simple image-in / image-out Gradio UI.
iface = gr.Interface(
    fn=segment_card,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="MTG Card Segmentation with YOLO11",
    description=(
        "Upload a Magic: The Gathering card image, and the model will segment "
        "key visual elements with labels. (Works best with card scans)"
    ),
)

# Start the web server for this Space.
iface.launch()