# Page residue captured with the file (author avatar caption, commit title, commit hash):
# Gk-Rohan's picture
# feat: Control threshold
# a89067a
import io
from PIL import Image
import supervision as sv
from rfdetr import RFDETRBase
from google import genai
from google.genai import types
import gradio as gr
import os
# Gemini client used for brand recognition. The API key is read from the
# GEMINI_API_KEY environment variable (os.getenv yields None when it is unset).
client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
# Display names for detector class ids; only id 0 ("logo") is mapped.
CLASSES = {0: "logo"}
# RF-DETR detector, constructed with the weights file
# "checkpoint_best_regular.pth" passed as pretrain_weights.
model = RFDETRBase(pretrain_weights="checkpoint_best_regular.pth")
# Shared supervision annotators reused for every request: one for bounding
# boxes, one for text labels (text_position set to sv.Position.CENTER).
box_annotator = sv.BoxAnnotator()
label_annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER)
def _recognize_brand(crop):
    """Ask Gemini which brand the cropped logo shows and return its name.

    Args:
        crop: PIL image containing a single detected logo.

    Returns:
        The brand name reported by Gemini, "Unknown" when the model gives no
        usable answer, or a "Gemini Error: ..." string when the API call
        fails (best effort: one failed lookup must not abort the request).
    """
    # JPEG cannot store alpha/palette data, so normalise the mode first.
    if crop.mode != "RGB":
        crop = crop.convert("RGB")
    buffer = io.BytesIO()
    crop.save(buffer, format='JPEG')
    try:
        response = client.models.generate_content(
            model='gemini-2.0-flash',
            contents=[
                types.Part.from_bytes(
                    data=buffer.getvalue(),
                    mime_type='image/jpeg',
                ),
                'Recognize the brand name for the logo and return only the name of the logo. If you don’t know the brand name, return "Unknown"',
            ],
        )
        answer = (response.text or "").strip()
    except Exception as exc:  # service boundary: report the failure in the label
        return f"Gemini Error: {exc}"
    # Treat empty answers and variants such as "unknown." as the sentinel.
    if not answer or answer.rstrip(".").casefold() == "unknown":
        return "Unknown"
    return answer


def _annotate_frame(image, detections, labels):
    """Return a copy of ``image`` with labels and boxes drawn for ``detections``."""
    frame = label_annotator.annotate(
        scene=image.copy(),  # never draw on the caller's image
        detections=detections,
        labels=labels,
    )
    return box_annotator.annotate(scene=frame, detections=detections)


def detect_objects_and_recognize_logos(image, threshold):
    """Detect logos with RF-DETR and name each one with Gemini.

    Args:
        image: PIL image to analyse, or None when nothing was uploaded.
        threshold: Confidence threshold forwarded to ``model.predict``.

    Returns:
        A 3-tuple ``(detection_frame, recognition_frame, brands)``:
        the image annotated with class/confidence labels, the image annotated
        with class/confidence/brand labels, and a comma-separated string of
        the recognised brand names ("Unknown" entries are omitted).
        On failure both frames are None (a valid empty value for an image
        output) and the third element carries the "Error: ..." message.
    """
    if image is None:
        return None, None, "Error: no image provided"
    try:
        detections = model.predict(image, threshold=threshold)
        width, height = image.size

        detection_labels = []
        recognition_labels = []
        brand_names = []
        for xyxy, class_id, confidence in zip(
            detections.xyxy, detections.class_id, detections.confidence
        ):
            # Fall back to a generic name instead of failing on unmapped ids.
            class_name = CLASSES.get(int(class_id), f"class {int(class_id)}")
            box = [round(v) for v in xyxy.tolist()]  # [x_min, y_min, x_max, y_max]
            detection_labels.append(f"{class_name} {confidence:.2f}")

            # Clamp to the image so the crop never includes out-of-bounds area.
            left, top = max(0, box[0]), max(0, box[1])
            right, bottom = min(width, box[2]), min(height, box[3])
            if right > left and bottom > top:
                brand_name = _recognize_brand(image.crop((left, top, right, bottom)))
            else:
                # Zero-area box: nothing to send to Gemini.
                brand_name = "Unknown"

            recognition_labels.append(
                f"{class_name} {confidence:.2f} | Brand: {brand_name}"
            )
            brand_names.append(brand_name)
            print(
                f"Detected {class_name} with confidence {round(confidence, 3)} "
                f"at location {box} | Brand: {brand_name}"
            )

        detection_frame = _annotate_frame(image, detections, detection_labels)
        recognition_frame = _annotate_frame(image, detections, recognition_labels)
        known_brands = [name for name in brand_names if name != "Unknown"]
        return detection_frame, recognition_frame, ", ".join(known_brands)
    except Exception as exc:
        # Image outputs cannot display a string, so report the error as text only.
        return None, None, f"Error: {exc}"
# Input widgets: the uploaded picture and the detector confidence slider.
_input_components = [
    gr.Image(type="pil", label="Upload Image"),
    gr.Slider(
        minimum=0.0,
        maximum=1.0,
        step=0.01,
        value=0.2,
        label="Confidence Threshold",
    ),
]

# Output widgets, in the order the handler returns its three values.
_output_components = [
    gr.Image(type="pil", label="Detection Frame (RFDETR)"),
    gr.Image(type="pil", label="Recognition Frame (RFDETR + Gemini)"),
    gr.Textbox(label="Detected Brand Names"),
]

# Wire the handler and widgets into a single Gradio interface.
interface = gr.Interface(
    fn=detect_objects_and_recognize_logos,
    inputs=_input_components,
    outputs=_output_components,
    title="Object Detection and Logo Recognition with RFDETR and Gemini",
    description="Upload an image to detect objects using RFDETR model and recognize logos using Google Gemini. Adjust the confidence threshold to filter detections. Outputs include a detection frame (objects only) and a recognition frame (objects with brand names).",
)

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    interface.launch(share=True)