Spaces:
Running
Running
| import gradio as gr | |
| import cv2 | |
| import easyocr | |
| import numpy as np | |
| from rfdetr import RFDETRSmall | |
| import tempfile | |
| import supervision as sv | |
| import pandas as pd | |
# Heavy models are built once, when this module is imported, and then shared by
# every request instead of being re-created per call.
# RF-DETR Small detector loaded from a local checkpoint. The path is relative,
# so it is resolved against the process's working directory.
model = RFDETRSmall(
    pretrain_weights="rfdetr_small_best.pth",
)
# EasyOCR reader for English and Spanish text, with GPU inference requested.
reader = easyocr.Reader(
    ["en", "es"],
    gpu=True,
)
# Detector class index -> display label. The indices presumably follow the
# class order the RF-DETR checkpoint was trained with — TODO confirm.
CLASSES = dict(
    enumerate(
        (
            "Diana Product",
            "Gallo Product",
            "Raptor bottel",
            "Tortrix Product",
            "cocacola pepsi",
            "laky ice cream",
        )
    )
)

# Lower-case substrings searched for in OCR text, each mapped to the CLASSES
# label it identifies. Values are taken from CLASSES so the two stay in sync.
# Insertion order matters: process_frame uses the first keyword that matches.
PRODUCT_KEYWORDS = {
    "diana": CLASSES[0],
    "gallo": CLASSES[1],
    "raptor": CLASSES[2],
    "tortrix": CLASSES[3],
    "coca": CLASSES[4],
    "pepsi": CLASSES[4],
    "laky": CLASSES[5],
}
# Supervision annotators reused by every process_frame call: the first draws
# bounding boxes, the second draws the text labels on top of them.
box_annotator, label_annotator = sv.BoxAnnotator(), sv.LabelAnnotator()
def calculate_iou(boxA, boxB):
    """Return the intersection-over-union of two axis-aligned boxes.

    Args:
        boxA: Box as ``(x1, y1, x2, y2)`` with ``x1 <= x2`` and ``y1 <= y2``.
        boxB: Second box in the same format.

    Returns:
        IoU as a float in ``[0, 1]``. Returns ``0.0`` when the union area is
        zero (e.g. both boxes are degenerate points/lines). The previous
        version divided by zero in that case.
    """
    # Width/height of the overlap rectangle, clamped at 0 when boxes are disjoint.
    inter_w = max(0, min(boxA[2], boxB[2]) - max(boxA[0], boxB[0]))
    inter_h = max(0, min(boxA[3], boxB[3]) - max(boxA[1], boxB[1]))
    inter_area = inter_w * inter_h

    area_a = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    area_b = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
    union = float(area_a + area_b - inter_area)
    if union <= 0:
        return 0.0
    return float(inter_area) / union
def _union_crop(frame, box_a, box_b):
    """Return the region of *frame* inside the union of two ``(x1, y1, x2, y2)`` boxes, clipped to the frame."""
    height, width = frame.shape[:2]
    x1 = max(0, int(min(box_a[0], box_b[0])))
    y1 = max(0, int(min(box_a[1], box_b[1])))
    x2 = min(width, int(max(box_a[2], box_b[2])))
    y2 = min(height, int(max(box_a[3], box_b[3])))
    return frame[y1:y2, x1:x2]


def _ocr_class_id(crop, name_to_id):
    """OCR *crop* and return the class id for the first PRODUCT_KEYWORDS keyword found.

    Returns None when the crop is empty, no keyword occurs in the recognised
    text, or the matched label has no entry in *name_to_id*.
    """
    if crop.size == 0:
        return None
    # Element 1 of each readtext result is the recognised text string.
    ocr_text = " ".join(result[1].lower() for result in reader.readtext(crop))
    for keyword, product in PRODUCT_KEYWORDS.items():
        if keyword in ocr_text:
            return name_to_id.get(product)
    return None


def process_frame(frame, conf_threshold, *, max_area_ratio=0.30, overlap_iou=0.7):
    """Detect products in one BGR frame; return the annotated frame and per-class counts.

    Pipeline:
      1. RF-DETR detection at ``conf_threshold``.
      2. Drop boxes whose area exceeds ``max_area_ratio`` of the frame.
      3. For each pair of boxes with different classes and IoU > ``overlap_iou``,
         OCR their union crop. If a product keyword is found, both boxes get that
         class; otherwise both take the class of the more confident box.

    Args:
        frame: Image in OpenCV BGR channel order, shape ``(H, W, 3)``.
        conf_threshold: Minimum detection confidence to keep (0..1).
        max_area_ratio: Boxes larger than this fraction of the frame are rejected.
        overlap_iou: IoU above which two differently-labelled boxes are treated
            as the same object.

    Returns:
        ``(annotated_frame, counts)``. ``annotated_frame`` is a BGR copy of *frame*
        with boxes and labels drawn. ``counts`` maps class name to the number of
        detections of that class.
    """
    import itertools

    height, width, _ = frame.shape
    total_area = height * width

    # RF-DETR expects RGB input, but frames here are BGR (OpenCV order), so
    # convert before predicting. The threshold is passed through because
    # predict() otherwise filters at its own default (0.5), hiding lower slider
    # values. Note that very low thresholds yield many boxes and therefore many
    # OCR calls below.
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    detections = model.predict(rgb_frame, threshold=conf_threshold)
    detections = detections[detections.confidence >= conf_threshold]

    # Reject implausibly large boxes (more than max_area_ratio of the frame).
    if len(detections) > 0:
        xyxy = detections.xyxy
        areas = (xyxy[:, 2] - xyxy[:, 0]) * (xyxy[:, 3] - xyxy[:, 1])
        detections = detections[areas <= max_area_ratio * total_area]

    # Resolve conflicting labels on heavily-overlapping boxes via EasyOCR.
    name_to_id = {name: class_id for class_id, name in CLASSES.items()}
    original_ids = detections.class_id
    final_class_ids = original_ids.copy()
    for i, j in itertools.combinations(range(len(detections)), 2):
        # Conflicts are judged on the model's original labels; relabels only go
        # into final_class_ids.
        if original_ids[i] == original_ids[j]:
            continue
        box_a, box_b = detections.xyxy[i], detections.xyxy[j]
        if calculate_iou(box_a, box_b) > overlap_iou:
            ocr_id = _ocr_class_id(_union_crop(frame, box_a, box_b), name_to_id)
            if ocr_id is not None:
                final_class_ids[i] = ocr_id
                final_class_ids[j] = ocr_id
            elif detections.confidence[i] > detections.confidence[j]:
                final_class_ids[j] = final_class_ids[i]
            else:
                final_class_ids[i] = final_class_ids[j]
            # NOTE(review): both boxes are kept after relabelling, so one object
            # may be counted twice in the summary — confirm this is intended.
    detections.class_id = final_class_ids

    # Annotate a copy so the caller's frame is left untouched.
    labels = [
        f"{CLASSES.get(class_id, 'Unknown')} {conf:.2f}"
        for class_id, conf in zip(detections.class_id, detections.confidence)
    ]
    annotated_frame = box_annotator.annotate(scene=frame.copy(), detections=detections)
    annotated_frame = label_annotator.annotate(
        scene=annotated_frame, detections=detections, labels=labels
    )

    counts = {}
    for class_id in detections.class_id:
        name = CLASSES.get(class_id, "Unknown")
        counts[name] = counts.get(name, 0) + 1
    return annotated_frame, counts
def process_image(image, conf_threshold):
    """Gradio handler: detect products in an uploaded RGB image.

    Args:
        image: RGB numpy image from ``gr.Image(type="numpy")``, or None if nothing was uploaded.
        conf_threshold: Confidence threshold forwarded to ``process_frame``.

    Returns:
        ``(annotated_rgb, summary_df)``. With no image, returns ``None`` and an
        empty DataFrame that still has the "Class Name"/"Count" columns.
    """
    summary_columns = ["Class Name", "Count"]
    if image is None:
        return None, pd.DataFrame(columns=summary_columns)

    # process_frame works on BGR frames, so convert in and back out.
    annotated_bgr, counts = process_frame(cv2.cvtColor(image, cv2.COLOR_RGB2BGR), conf_threshold)
    annotated_rgb = cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)

    rows = [{"Class Name": name, "Count": count} for name, count in counts.items()]
    if not rows:
        # Keep the column headers even when nothing was detected.
        return annotated_rgb, pd.DataFrame(columns=summary_columns)
    return annotated_rgb, pd.DataFrame(rows)
def process_video(video_path, conf_threshold, frame_interval=15):
    """Gradio handler: detect products in every ``frame_interval``-th frame of a video.

    Args:
        video_path: Path of the uploaded video; falsy means nothing was uploaded.
        conf_threshold: Confidence threshold forwarded to ``process_frame``.
        frame_interval: Process one frame out of every this many (default 15).

    Returns:
        ``(output_path, summary_df)``.
        ``output_path`` is a temporary ``.mp4`` holding only the annotated sampled
        frames, played back at ``source_fps / frame_interval`` (2 fps when the
        source FPS is unknown). It is ``None`` when no video was given, the video
        could not be opened, or no frame could be read.
        ``summary_df`` lists, per class, the largest number of detections seen in
        any single sampled frame.

    Raises:
        ValueError: if ``frame_interval`` is less than 1.
    """
    columns = ["Class Name", "Max Count (per frame)"]
    if not video_path:
        return None, pd.DataFrame(columns=columns)
    if frame_interval < 1:
        raise ValueError(f"frame_interval must be >= 1, got {frame_interval!r}")

    cap = cv2.VideoCapture(video_path)
    writer = None
    out_path = None
    max_counts = {}
    try:
        if not cap.isOpened():
            return None, pd.DataFrame(columns=columns)
        fps = cap.get(cv2.CAP_PROP_FPS)
        out_fps = fps / frame_interval if fps > 0 else 2.0

        frame_index = 0
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            if frame_index % frame_interval == 0:
                annotated_frame, counts = process_frame(frame, conf_threshold)
                if writer is None:
                    # Size the writer from an actual decoded frame. The
                    # container's CAP_PROP_FRAME_WIDTH/HEIGHT can disagree with
                    # decoded frames (e.g. videos with rotation metadata), and
                    # VideoWriter drops frames whose size does not match.
                    frame_h, frame_w = annotated_frame.shape[:2]
                    temp_out = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
                    out_path = temp_out.name
                    temp_out.close()
                    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
                    writer = cv2.VideoWriter(out_path, fourcc, out_fps, (frame_w, frame_h))
                writer.write(annotated_frame)
                # Track the most items of each class seen simultaneously in one frame.
                for name, count in counts.items():
                    if count > max_counts.get(name, 0):
                        max_counts[name] = count
            frame_index += 1
    finally:
        # Release even if process_frame raises, so handles are not leaked.
        cap.release()
        if writer is not None:
            writer.release()

    rows = [
        {"Class Name": name, "Max Count (per frame)": count}
        for name, count in max_counts.items()
    ]
    return out_path, pd.DataFrame(rows, columns=columns)
# Gradio Interface
# Soft theme (blue primary, slate secondary/neutral) with dark colour overrides
# for the page body, text, blocks and block labels.
theme = gr.themes.Soft(
    neutral_hue="slate",
    secondary_hue="slate",
    primary_hue="blue",
).set(
    **{
        "body_background_fill": "*neutral_950",
        "body_text_color": "*neutral_100",
        "block_background_fill": "*neutral_900",
        "block_label_text_color": "*neutral_200",
    }
)
# Two-tab UI. Each tab takes an upload plus a confidence slider and shows the
# annotated result next to a per-class summary table. Components appear on the
# page in the order they are created inside the `with` blocks, so statement
# order here is the layout.
with gr.Blocks(theme=theme) as app:
    gr.Markdown("# 🛒 Retail Product Detection System — Demo")
    gr.Markdown("### please upload product images/videos")
    # Image tab: image + threshold -> process_image -> annotated image + counts.
    with gr.Tab("Image Detection"):
        with gr.Row():
            # Left column: inputs.
            with gr.Column():
                image_input = gr.Image(type="numpy", label="Upload Product Image")
                img_conf_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.05, label="Confidence Threshold")
                img_submit_btn = gr.Button("Detect Products", variant="primary")
            # Right column: outputs. Headers mirror the columns process_image returns.
            with gr.Column():
                image_output = gr.Image(type="numpy", label="Annotated Output")
                img_summary_table = gr.Dataframe(headers=["Class Name", "Count"], label="Detection Summary")
        # Inputs/outputs are matched positionally to process_image's parameters and return tuple.
        img_submit_btn.click(
            fn=process_image,
            inputs=[image_input, img_conf_slider],
            outputs=[image_output, img_summary_table]
        )
    # Video tab: video + threshold -> process_video -> annotated sampled-frame video + max counts.
    with gr.Tab("Video Detection"):
        with gr.Row():
            # Left column: inputs.
            with gr.Column():
                video_input = gr.Video(label="Upload Counter Video")
                vid_conf_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.05, label="Confidence Threshold")
                vid_submit_btn = gr.Button("Detect Products in Video", variant="primary")
            # Right column: outputs. The label reflects that process_video only
            # annotates every 15th frame; headers mirror its summary columns.
            with gr.Column():
                video_output = gr.Video(label="Annotated Output (15th frame intervals)")
                vid_summary_table = gr.Dataframe(headers=["Class Name", "Max Count (per frame)"], label="Detection Summary")
        # Inputs/outputs are matched positionally to process_video's parameters and return tuple.
        vid_submit_btn.click(
            fn=process_video,
            inputs=[video_input, vid_conf_slider],
            outputs=[video_output, vid_summary_table]
        )
# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    app.launch()