# File size: 4,098 Bytes
# 3dffe7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import gradio as gr
from transformers import AutoImageProcessor, AutoModelForObjectDetection
import torch
from PIL import Image, ImageDraw, ImageFont
import numpy as np

# Model identifier handed to both `from_pretrained` calls below.
MODEL_ID = "Meenu047/RGTB_Aerial_view_detection"

# Load the image processor and the detection model once, at import time, so
# that every `predict` call reuses the same module-level objects. The prints
# bracket the two loads so startup progress shows up in the process output.
print("Loading model...")
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForObjectDetection.from_pretrained(MODEL_ID)
print("Model loaded successfully!")

def _load_label_font(size=20):
    """Return a bold DejaVu font at *size*, or PIL's default font if unavailable."""
    try:
        return ImageFont.truetype(
            "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", size
        )
    except (OSError, ImportError):
        # OSError: font file missing or unreadable; ImportError: Pillow was
        # built without FreeType. Either way, labels are still drawn, just
        # with the built-in bitmap font.
        return ImageFont.load_default()


def _format_detections(detections, threshold):
    """Render the list of detection dicts as the Markdown summary text."""
    if not detections:
        return f"No objects detected with confidence > {threshold:.0%}"

    parts = [f"**Detected {len(detections)} object(s):**\n\n"]
    for i, det in enumerate(detections, 1):
        parts.append(f"**{i}. {det['Label']}**\n")
        parts.append(f"   - Confidence: {det['Confidence']}\n")
        parts.append(f"   - Location: {det['Box']}\n\n")
    return "".join(parts)


def predict(image, threshold=0.5):
    """
    Run object detection on the input image.

    Args:
        image: PIL image to analyse, or ``None`` when nothing was uploaded.
        threshold: Minimum confidence score (0-1) a detection needs to be
            kept. Defaults to 0.5.

    Returns:
        A ``(annotated_image, results_text)`` tuple. ``annotated_image`` is an
        RGB copy of the input with boxes and labels drawn on it (the input
        itself is left untouched); ``results_text`` is a Markdown summary of
        the detections. When ``image`` is ``None`` the tuple is
        ``(None, "Please upload an image")``.
    """
    if image is None:
        return None, "Please upload an image"

    # Work on an RGB copy: this keeps the caller's image unmodified and makes
    # palette / grayscale / RGBA inputs safe for both the processor and the
    # named colours used when drawing.
    annotated = image.convert("RGB")

    # Prepare image
    inputs = processor(images=annotated, return_tensors="pt")

    # Run inference
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process results. PIL reports size as (width, height); it is
    # reversed here to pass (height, width) as the target size.
    target_sizes = torch.tensor([annotated.size[::-1]])
    results = processor.post_process_object_detection(
        outputs,
        target_sizes=target_sizes,
        threshold=threshold,
    )[0]

    # Draw bounding boxes
    draw = ImageDraw.Draw(annotated)
    font = _load_label_font()

    detections = []
    colors = ['red', 'blue', 'green', 'yellow', 'purple', 'orange', 'pink', 'cyan']
    id2label = model.config.id2label

    for idx, (score, label, box) in enumerate(
        zip(results["scores"], results["labels"], results["boxes"])
    ):
        box = [round(coord, 2) for coord in box.tolist()]
        confidence = round(score.item(), 3)
        class_id = label.item()
        # Fall back to the numeric id rather than failing the whole request
        # if the config has no name for this class.
        label_name = id2label.get(class_id, str(class_id))

        # Draw rectangle; colours cycle so neighbouring boxes stay distinguishable
        color = colors[idx % len(colors)]
        draw.rectangle(box, outline=color, width=3)

        # Draw label on a filled background anchored at the box's top-left corner
        text = f"{label_name}: {confidence:.2f}"
        text_bbox = draw.textbbox((box[0], box[1]), text, font=font)
        draw.rectangle(text_bbox, fill=color)
        draw.text((box[0], box[1]), text, fill='white', font=font)

        detections.append({
            "Label": label_name,
            "Confidence": f"{confidence * 100:.1f}%",
            "Box": f"({int(box[0])}, {int(box[1])}) - ({int(box[2])}, {int(box[3])})"
        })

    return annotated, _format_detections(detections, threshold)

# Create Gradio interface: input image + button on the left, annotated image
# and Markdown summary on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🚁 RGTB Aerial View Detection
        Upload an aerial image to detect objects using the trained model.
        """
    )
    
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                type="pil", 
                label="Upload Aerial Image",
                height=400
            )
            # The label previously held a mis-decoded emoji; restored to U+1F50D.
            predict_btn = gr.Button("🔍 Run Detection", variant="primary", size="lg")
            
        with gr.Column():
            output_image = gr.Image(
                type="pil", 
                label="Detection Results",
                height=400
            )
            output_text = gr.Markdown(label="Detected Objects")
    
    gr.Examples(
        examples=[],  # Add example images here if you have any
        inputs=input_image,
    )
    
    # predict returns (annotated image, markdown text), matching the two outputs.
    predict_btn.click(
        fn=predict,
        inputs=input_image,
        outputs=[output_image, output_text]
    )
    
    gr.Markdown(
        """
        ### How to use:
        1. Upload or drag & drop an aerial image
        2. Click "Run Detection" button
        3. View the detected objects with bounding boxes and confidence scores
        
        **Model:** `Meenu047/RGTB_Aerial_view_detection`
        """
    )

# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()