File size: 4,879 Bytes
34b843b
 
 
 
 
 
 
 
 
 
e00c5e0
34b843b
 
 
ad93285
34b843b
 
 
 
 
a89067a
34b843b
a89067a
 
34b843b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a89067a
 
 
 
34b843b
 
 
 
 
 
a89067a
34b843b
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import io
from PIL import Image
import supervision as sv
from rfdetr import RFDETRBase
from google import genai
from google.genai import types
import gradio as gr
import os

# Gemini client; the API key is taken from the GEMINI_API_KEY environment
# variable (None is passed through when the variable is unset).
client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))

# Single-class RF-DETR detector loaded from a local checkpoint file.
CLASSES = dict(enumerate(["logo"]))  # class id -> display name, i.e. {0: "logo"}
model = RFDETRBase(pretrain_weights="checkpoint_best_regular.pth")

# Supervision annotators reused for both output frames; label text is placed
# at the centre of each box.
label_annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER)
box_annotator = sv.BoxAnnotator()

_UNKNOWN_BRAND = "Unknown"
_GEMINI_ERROR_PREFIX = "Gemini Error:"
# Straight and typographic quotes the model may wrap its answer in.
_QUOTE_CHARS = "\"'\u201c\u201d\u2018\u2019"


def _crop_to_jpeg_bytes(image, box):
    """Crop ``image`` to ``box`` and return the crop encoded as JPEG bytes.

    Args:
        image: PIL image the detections were computed on.
        box: ``[x_min, y_min, x_max, y_max]`` in integer pixel coordinates.

    Returns:
        The JPEG-encoded crop, or ``None`` when the box, after clamping to the
        image bounds, has zero width or height.
    """
    width, height = image.size
    # Clamp so boxes that spill past the border don't pad the crop with black.
    left, top = max(0, box[0]), max(0, box[1])
    right, bottom = min(width, box[2]), min(height, box[3])
    if right <= left or bottom <= top:
        return None  # degenerate box: JPEG cannot encode a 0-pixel image
    crop = image.crop((left, top, right, bottom))
    # JPEG has no alpha/palette support; RGBA or P images (e.g. PNG uploads)
    # would otherwise make save() raise.
    if crop.mode != "RGB":
        crop = crop.convert("RGB")
    buffer = io.BytesIO()
    crop.save(buffer, format="JPEG")
    return buffer.getvalue()


def _normalize_brand(text):
    """Clean a raw Gemini reply into a brand label.

    Strips whitespace and surrounding quote characters (the prompt shows
    ``"Unknown"`` in quotes, so the model may echo them). It maps empty or
    "unknown"-like replies (any case, optional trailing period) to
    ``"Unknown"``.
    """
    name = (text or "").strip().strip(_QUOTE_CHARS).strip()
    if not name or name.rstrip(".").casefold() == _UNKNOWN_BRAND.casefold():
        return _UNKNOWN_BRAND
    return name


def _recognize_brand(image_bytes):
    """Ask Gemini which brand the JPEG-encoded logo crop shows.

    Returns:
        The cleaned brand name, ``"Unknown"``, or ``"Gemini Error: <msg>"``
        when the API call fails. Failures are reported per detection rather
        than raised, so one failed call does not abort the whole request.
    """
    try:
        response = client.models.generate_content(
            model='gemini-2.0-flash',
            contents=[
                types.Part.from_bytes(
                    data=image_bytes,
                    mime_type='image/jpeg',
                ),
                'Recognize the brand name for the logo and return only the name of the logo. If you don’t know the brand name, return "Unknown"'
            ])
        return _normalize_brand(response.text)
    except Exception as e:
        return f"{_GEMINI_ERROR_PREFIX} {e}"


def _annotate(image, detections, labels):
    """Return a copy of ``image`` with ``labels`` and bounding boxes drawn."""
    frame = label_annotator.annotate(
        scene=image.copy(), detections=detections, labels=labels
    )
    # ``frame`` is already a fresh image, so no second copy is needed.
    return box_annotator.annotate(scene=frame, detections=detections)


def detect_objects_and_recognize_logos(image, threshold):
    """Detect logos with RF-DETR and name each one with Gemini.

    Args:
        image: PIL image from the Gradio upload widget (``None`` when nothing
            was uploaded).
        threshold: minimum detection confidence passed to ``model.predict``.

    Returns:
        A tuple ``(detection_frame, recognition_frame, brands_text)``:

        - ``detection_frame``: the image annotated with class/confidence labels.
        - ``recognition_frame``: the image annotated with class/confidence/brand
          labels.
        - ``brands_text``: a comma-separated list of recognized brand names,
          excluding "Unknown" results and Gemini errors.

        On failure both frames are ``None``, so the image outputs render empty
        instead of receiving a string, and ``brands_text`` holds the error.
    """
    if image is None:
        return None, None, "Error: please upload an image."

    try:
        detections = model.predict(image, threshold=threshold)

        detection_labels = []
        recognition_labels = []
        brand_names = []

        for box, class_id, confidence in zip(
            detections.xyxy, detections.class_id, detections.confidence
        ):
            # Fall back to the raw id instead of raising KeyError when the
            # checkpoint emits an id that CLASSES does not list.
            class_name = CLASSES.get(int(class_id), str(class_id))
            box = [round(v) for v in box.tolist()]  # [x_min, y_min, x_max, y_max]

            detection_labels.append(f"{class_name} {confidence:.2f}")

            image_bytes = _crop_to_jpeg_bytes(image, box)
            brand_name = (
                _recognize_brand(image_bytes) if image_bytes else _UNKNOWN_BRAND
            )

            recognition_labels.append(
                f"{class_name} {confidence:.2f} | Brand: {brand_name}"
            )
            brand_names.append(brand_name)

            print(
                f"Detected {class_name} with confidence {round(confidence, 3)} "
                f"at location {box} | Brand: {brand_name}"
            )

        detection_frame = _annotate(image, detections, detection_labels)
        recognition_frame = _annotate(image, detections, recognition_labels)

        # Only real brand names belong in the summary; per-detection Gemini
        # errors stay visible on the recognition frame labels.
        recognized = [
            name
            for name in brand_names
            if name != _UNKNOWN_BRAND and not name.startswith(_GEMINI_ERROR_PREFIX)
        ]
        return detection_frame, recognition_frame, ", ".join(recognized)

    except Exception as e:
        print(f"Inference failed: {e}")
        return None, None, f"Error: {e}"

# Gradio UI: an uploaded image and a confidence slider go in; two annotated
# frames and the recognized brand list come out.
_inputs = [
    gr.Image(type="pil", label="Upload Image"),
    gr.Slider(
        minimum=0.0,
        maximum=1.0,
        step=0.01,
        value=0.2,
        label="Confidence Threshold",
    ),
]
_outputs = [
    gr.Image(type="pil", label="Detection Frame (RFDETR)"),
    gr.Image(type="pil", label="Recognition Frame (RFDETR + Gemini)"),
    gr.Textbox(label="Detected Brand Names"),
]

interface = gr.Interface(
    fn=detect_objects_and_recognize_logos,
    inputs=_inputs,
    outputs=_outputs,
    title="Object Detection and Logo Recognition with RFDETR and Gemini",
    description="Upload an image to detect objects using RFDETR model and recognize logos using Google Gemini. Adjust the confidence threshold to filter detections. Outputs include a detection frame (objects only) and a recognition frame (objects with brand names).",
)

if __name__ == "__main__":
    # share=True asks Gradio for a public share link in addition to the local server.
    interface.launch(share=True)