Spaces:
Runtime error
Runtime error
import io
import os
import tempfile
import warnings

import gradio as gr
import pandas as pd
import torch
from PIL import Image, ImageDraw, ImageFont
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.transforms import ToTensor
# Class names looked up by the integer label id the detector returns.
# Index 0 is the background class; "N/A" entries are placeholders that keep
# every later name at its expected index.
COCO_CLASSES = [
    "__background__", "person", "bicycle", "car", "motorcycle",  # 0-4
    "airplane", "bus", "train", "truck", "boat",  # 5-9
    "traffic light", "fire hydrant", "N/A", "stop sign", "parking meter",  # 10-14
    "bench", "bird", "cat", "dog", "horse",  # 15-19
    "sheep", "cow", "elephant", "bear", "zebra",  # 20-24
    "giraffe", "N/A", "backpack", "umbrella", "N/A",  # 25-29
    "N/A", "handbag", "tie", "suitcase", "frisbee",  # 30-34
    "skis", "snowboard", "sports ball", "kite", "baseball bat",  # 35-39
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",  # 40-44
    "N/A", "wine glass", "cup", "fork", "knife",  # 45-49
    "spoon", "bowl", "banana", "apple", "sandwich",  # 50-54
    "orange", "broccoli", "carrot", "hot dog", "pizza",  # 55-59
    "donut", "cake", "chair", "couch", "potted plant",  # 60-64
    "bed", "N/A", "dining table", "N/A", "N/A",  # 65-69
    "toilet", "N/A", "tv", "laptop", "mouse",  # 70-74
    "remote", "keyboard", "cell phone", "microwave", "oven",  # 75-79
    "toaster", "sink", "refrigerator", "N/A", "book",  # 80-84
    "clock", "vase", "scissors", "teddy bear", "hair drier",  # 85-89
    "toothbrush",  # 90
]
# Load model
def load_model(checkpoint_path=None, num_classes=91):
    """Build a Faster R-CNN (ResNet-50 FPN) detector and switch it to eval mode.

    Args:
        checkpoint_path: Optional path to a ``state_dict`` for a model whose
            box predictor has ``num_classes`` outputs. When falsy, the
            COCO-pretrained torchvision weights are used unchanged.
        num_classes: Number of classes (background included) of the custom
            box predictor. Only used when ``checkpoint_path`` is given.

    Returns:
        The detector, in evaluation mode.
    """
    # weights="DEFAULT" is the non-deprecated equivalent of pretrained=True
    # (the COCO_V1 weights) in torchvision >= 0.13.
    model = fasterrcnn_resnet50_fpn(weights="DEFAULT")
    if checkpoint_path:
        # Replace the head with one sized for the custom classes so the
        # checkpoint's predictor tensors match in shape when loaded.
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
        # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(state_dict)
    model.eval()
    return model
# Path to a fine-tuned state_dict; set to None to use the default COCO model.
CHECKPOINT_PATH = "frcnn_model.pth"

_checkpoint = CHECKPOINT_PATH
if _checkpoint and not os.path.isfile(_checkpoint):
    # A missing checkpoint used to raise at import time and take the whole app
    # down. Fall back to the stock COCO model, but make the substitution visible.
    warnings.warn(
        f"Checkpoint {_checkpoint!r} not found; falling back to the "
        "COCO-pretrained Faster R-CNN weights."
    )
    _checkpoint = None
model = load_model(_checkpoint)
# Run inference
def detect_objects(image, threshold=0.5):
    """Run the detector on one image and annotate every confident detection.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the user
            submits without uploading anything.
        threshold: Minimum score a detection needs in order to be drawn and
            logged.

    Returns:
        ``(annotated_image, png_path, table)``, matching the three Gradio
        outputs:
            - the annotated RGB copy of the input;
            - the path of a PNG file with the same content (``gr.File``
              needs a file path, not an in-memory buffer);
            - a DataFrame with ``Object`` / ``Score`` columns.
        For a ``None`` input, returns ``(None, None, <empty table>)``.
    """
    empty_table = pd.DataFrame(columns=["Object", "Score"])
    if image is None:
        return None, None, empty_table

    # convert() returns a new image, so the caller's input is never drawn on.
    image = image.convert("RGB")
    image_tensor = ToTensor()(image).unsqueeze(0)  # add a batch dimension
    with torch.no_grad():
        output = model(image_tensor)[0]

    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    log_data = []
    for box, label, score in zip(output["boxes"], output["labels"], output["scores"]):
        score = score.item()
        if score < threshold:
            continue
        box = box.tolist()
        label_id = label.item()
        # A custom checkpoint may emit ids beyond the COCO name list.
        label_name = (
            COCO_CLASSES[label_id]
            if label_id < len(COCO_CLASSES)
            else f"class_{label_id}"
        )
        draw.rectangle(box, outline="red", width=2)
        draw.text((box[0], box[1]), f"{label_name} ({score:.2f})", fill="yellow", font=font)
        log_data.append({"Object": label_name, "Score": round(score, 2)})

    # Write the annotated image to disk so gr.File can offer it for download.
    # The temp file is not removed here.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        image.save(tmp, format="PNG")
        png_path = tmp.name

    table = pd.DataFrame(log_data) if log_data else empty_table
    return image, png_path, table
# Gradio interface
def _build_app():
    """Assemble the UI: one uploaded image in; annotated image, file, and log table out."""
    upload = gr.Image(type="pil", label="Upload Image")
    result_components = [
        gr.Image(type="pil", label="Detected Image"),
        gr.File(label="Download Image"),
        gr.Dataframe(headers=["Object", "Score"], label="Detection Log"),
    ]
    return gr.Interface(
        fn=detect_objects,
        inputs=upload,
        outputs=result_components,
        title="🧠 Object Detection App (Faster R-CNN)",
        description="Upload an image to detect objects using a pretrained or custom Faster R-CNN model. View logs and download the result.",
    )


app = _build_app()

if __name__ == "__main__":
    app.launch()