Gk-Rohan committed on
Commit
a89067a
·
1 Parent(s): d112f2c

feat: Control threshold

Browse files
Files changed (1) hide show
  1. app.py +8 -5
app.py CHANGED
@@ -18,10 +18,10 @@ model = RFDETRBase(pretrain_weights="checkpoint_best_regular.pth")
18
  box_annotator = sv.BoxAnnotator()
19
  label_annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER)
20
 
21
- def detect_objects_and_recognize_logos(image):
22
  try:
23
- # Run inference with RFDETR
24
- detections = model.predict(image, threshold=0.2)
25
 
26
  # Initialize labels for detection and recognition frames
27
  detection_labels = []
@@ -101,14 +101,17 @@ def detect_objects_and_recognize_logos(image):
101
  # Create Gradio interface
102
  interface = gr.Interface(
103
  fn=detect_objects_and_recognize_logos,
104
- inputs=gr.Image(type="pil", label="Upload Image"),
 
 
 
105
  outputs=[
106
  gr.Image(type="pil", label="Detection Frame (RFDETR)"),
107
  gr.Image(type="pil", label="Recognition Frame (RFDETR + Gemini)"),
108
  gr.Textbox(label="Detected Brand Names")
109
  ],
110
  title="Object Detection and Logo Recognition with RFDETR and Gemini",
111
- description="Upload an image to detect objects using RFDETR model and recognize logos using Google Gemini. Outputs include a detection frame (objects only) and a recognition frame (objects with brand names)."
112
  )
113
 
114
  # Launch the interface
 
18
  box_annotator = sv.BoxAnnotator()
19
  label_annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER)
20
 
21
+ def detect_objects_and_recognize_logos(image, threshold):
22
  try:
23
+ # Run inference with RFDETR using the provided threshold
24
+ detections = model.predict(image, threshold=threshold)
25
 
26
  # Initialize labels for detection and recognition frames
27
  detection_labels = []
 
101
  # Create Gradio interface
102
  interface = gr.Interface(
103
  fn=detect_objects_and_recognize_logos,
104
+ inputs=[
105
+ gr.Image(type="pil", label="Upload Image"),
106
+ gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.2, label="Confidence Threshold")
107
+ ],
108
  outputs=[
109
  gr.Image(type="pil", label="Detection Frame (RFDETR)"),
110
  gr.Image(type="pil", label="Recognition Frame (RFDETR + Gemini)"),
111
  gr.Textbox(label="Detected Brand Names")
112
  ],
113
  title="Object Detection and Logo Recognition with RFDETR and Gemini",
114
+ description="Upload an image to detect objects using RFDETR model and recognize logos using Google Gemini. Adjust the confidence threshold to filter detections. Outputs include a detection frame (objects only) and a recognition frame (objects with brand names)."
115
  )
116
 
117
  # Launch the interface