File size: 5,196 Bytes
ea90d24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import gradio as gr
import cv2
import PIL.Image
import numpy as np
from ultralytics import YOLO
import pandas as pd

# --- Model Loading ---
# Load the weights once at import time. A failed load is not fatal: `model`
# stays None and predict() reports "Model not loaded." to the UI instead.
MODEL_PATH = "best.pt"
model = None
try:
    model = YOLO(MODEL_PATH)
except Exception as e:
    print(f"Error loading model: {e}")
else:
    print(f"Model loaded successfully from {MODEL_PATH}")

def _to_bgr(image):
    """Return a contiguous BGR copy of an RGB(A) numpy image; pass anything else through.

    Gradio's ``gr.Image(type="numpy")`` delivers RGB arrays, whereas Ultralytics
    interprets numpy arrays as BGR (OpenCV channel order). PIL images are left
    untouched because Ultralytics converts those from RGB itself.
    """
    if isinstance(image, np.ndarray) and image.ndim == 3 and image.shape[2] in (3, 4):
        # Channels 2, 1, 0 -> B, G, R; this also drops an alpha channel if present.
        return np.ascontiguousarray(image[..., 2::-1])
    return image


def _message_frame(message):
    """Wrap a status/error message in a one-cell DataFrame.

    The second output is bound to a ``gr.Dataframe``; a bare string there may be
    treated by Gradio as a CSV file path rather than displayed as text.
    """
    return pd.DataFrame({"Message": [message]})


def predict(image, conf_threshold, iou_threshold):
    """
    Runs YOLO inference on the input image.

    Args:
        image: Input image (RGB numpy array as produced by gr.Image(type="numpy"),
            or a PIL Image). None when nothing has been uploaded.
        conf_threshold: Confidence threshold for detection.
        iou_threshold: IoU threshold for NMS.

    Returns:
        A 3-tuple ``(annotated_image, class_counts, details)``:
        - annotated_image: RGB numpy array with boxes drawn, or None on failure.
        - class_counts: DataFrame with columns ``Class``/``Count``; on failure a
          one-column ``Message`` DataFrame describing the problem.
        - details: DataFrame with columns ``Class``/``Confidence``/``Coordinates``
          (one row per detection, possibly empty), or None on failure.
    """
    if model is None:
        return None, _message_frame("Model not loaded."), None
    if image is None:
        # Without this guard model.predict(None) may silently run on the
        # library's bundled sample images instead of the user's input.
        return None, _message_frame("Please upload an image."), None

    try:
        # Run inference on BGR data, which is what Ultralytics expects for arrays.
        results = model.predict(_to_bgr(image), conf=conf_threshold, iou=iou_threshold)
        result = results[0]

        # plot() returns the annotated image in BGR; Gradio displays RGB.
        res_image = np.ascontiguousarray(result.plot()[..., ::-1])

        # Count classes and collect per-box details in a single pass.
        class_counts = {}
        box_data = []
        boxes = result.boxes if result.boxes is not None else []

        for box in boxes:
            cls_name = model.names[int(box.cls[0])]
            class_counts[cls_name] = class_counts.get(cls_name, 0) + 1
            box_data.append({
                "Class": cls_name,
                "Confidence": float(box.conf[0]),
                "Coordinates": [round(x, 1) for x in box.xyxy[0].tolist()],
            })

        counts_summary = pd.DataFrame(list(class_counts.items()), columns=["Class", "Count"])
        # Explicit columns keep the table header visible even with zero detections.
        df = pd.DataFrame(box_data, columns=["Class", "Confidence", "Coordinates"])

        return res_image, counts_summary, df

    except Exception as e:
        # Top-level UI boundary: surface the error in the table rather than crash.
        return None, _message_frame(f"Error: {e}"), None

# --- Gradio UI ---
def create_interface():
    """Build the Gradio Blocks UI for the detector.

    Wires an image input plus confidence/IoU sliders to ``predict`` and shows the
    annotated image, per-class counts and per-box details.

    Returns:
        gr.Blocks: the un-launched demo; the caller is responsible for launch().
    """
    # Derive the class list from the loaded weights instead of hard-coding it,
    # so the footer cannot drift out of sync with the model in MODEL_PATH.
    if model is not None:
        names = model.names
        name_values = names.values() if isinstance(names, dict) else names
        class_list = ", ".join(str(name) for name in name_values)
    else:
        class_list = "unavailable (model not loaded)"

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # Cat and Dog Detection Pro
            Upload an image to detect cats and dogs.
            """
        )

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Original Image", type="numpy")
                conf_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.05, label="Confidence Threshold")
                iou_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.45, step=0.05, label="IoU Threshold")
                run_btn = gr.Button("🔍 Run Detection", variant="primary")

            with gr.Column():
                output_image = gr.Image(label="Detected Output")
                gr.Markdown("### 📊 Detection Statistics")
                output_counts = gr.Dataframe(label="Class Counts")
                output_details = gr.Dataframe(label="Detailed Detection Data")

        run_btn.click(
            fn=predict,
            inputs=[input_image, conf_slider, iou_slider],
            outputs=[output_image, output_counts, output_details],
        )

        gr.Markdown("---")
        gr.Markdown(f"Model: standard YOLOv8n (Custom Trained) | Classes: {class_list}")

    return demo

if __name__ == "__main__":
    demo = create_interface()
    # Gradio 6 moved app-level options such as `theme` from the gr.Blocks()
    # constructor to launch(); passing it to Blocks() there emits a UserWarning,
    # so the theme is supplied here instead.
    # NOTE(review): older Gradio releases expect `theme` on gr.Blocks(...) and
    # may reject it in launch() - pin gradio>=6 or move it back for those.
    demo.launch(theme=gr.themes.Soft())