# image_detector / app.py
# Last change by jeyasee: "Update app.py" (commit f16d728, verified)
import io
import tempfile

import gradio as gr
import pandas as pd
import torch
from PIL import Image, ImageDraw, ImageFont
from torchvision.models.detection import fasterrcnn_resnet50_fpn
# FastRCNNPredictor is not re-exported by the detection package; it has to be
# imported from the faster_rcnn submodule.
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.transforms import ToTensor
# COCO-style class names, looked up by the integer label id the detector
# returns. Position 0 is the background class and 'N/A' marks positions that
# carry no category name; the list has 91 entries in total.
COCO_CLASSES = [
    # 0-9
    "__background__", "person", "bicycle", "car", "motorcycle",
    "airplane", "bus", "train", "truck", "boat",
    # 10-19
    "traffic light", "fire hydrant", "N/A", "stop sign", "parking meter",
    "bench", "bird", "cat", "dog", "horse",
    # 20-29
    "sheep", "cow", "elephant", "bear", "zebra",
    "giraffe", "N/A", "backpack", "umbrella", "N/A",
    # 30-39
    "N/A", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat",
    # 40-49
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
    "N/A", "wine glass", "cup", "fork", "knife",
    # 50-59
    "spoon", "bowl", "banana", "apple", "sandwich",
    "orange", "broccoli", "carrot", "hot dog", "pizza",
    # 60-69
    "donut", "cake", "chair", "couch", "potted plant",
    "bed", "N/A", "dining table", "N/A", "N/A",
    # 70-79
    "toilet", "N/A", "tv", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "microwave", "oven",
    # 80-89
    "toaster", "sink", "refrigerator", "N/A", "book",
    "clock", "vase", "scissors", "teddy bear", "hair drier",
    # 90
    "toothbrush",
]
# Load model
def load_model(checkpoint_path=None, num_classes=91):
    """Build a Faster R-CNN (ResNet-50 FPN) detector in evaluation mode.

    Args:
        checkpoint_path: Optional path to a saved ``state_dict``. When it is
            falsy, the pretrained torchvision model is returned as-is.
        num_classes: Output size of the replacement box-predictor head. Only
            used when ``checkpoint_path`` is given.

    Returns:
        The model after ``eval()`` has been called on it.
    """
    detector = fasterrcnn_resnet50_fpn(pretrained=True)

    if not checkpoint_path:
        detector.eval()
        return detector

    # Swap in a box predictor with ``num_classes`` outputs before loading the
    # checkpoint (presumably so the head matches the saved weights).
    predictor_inputs = detector.roi_heads.box_predictor.cls_score.in_features
    detector.roi_heads.box_predictor = FastRCNNPredictor(predictor_inputs, num_classes)

    # map_location="cpu" places the loaded tensors on the CPU.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    saved_weights = torch.load(checkpoint_path, map_location="cpu")
    detector.load_state_dict(saved_weights)

    detector.eval()
    return detector
# Build the detector once at import time; detect_objects() reads this global
# on every call. Pass None instead of the path to use the default COCO model.
model = load_model("frcnn_model.pth")
# Run inference
def detect_objects(image):
    """Run the detector on ``image`` and package the results for Gradio.

    Args:
        image: PIL image from the upload component, or ``None`` when the
            request was submitted without an image.

    Returns:
        A tuple ``(annotated, download_path, table)``:

        * ``annotated`` - the RGB-converted image with a red box and a yellow
          ``"name (score)"`` caption for every detection scoring >= 0.5;
        * ``download_path`` - path of a temporary PNG file holding
          ``annotated``;
        * ``table`` - DataFrame with ``Object`` and ``Score`` columns, one row
          per drawn detection.

        For a ``None`` input the first two items are ``None`` and the table
        is empty.
    """
    if image is None:
        # Nothing was uploaded: return empty outputs instead of crashing.
        return None, None, pd.DataFrame(columns=["Object", "Score"])

    image = image.convert("RGB")
    image_tensor = ToTensor()(image).unsqueeze(0)
    with torch.no_grad():
        prediction = model(image_tensor)[0]

    # Filter by confidence with one mask, then convert to plain Python values.
    threshold = 0.5
    keep = prediction['scores'] >= threshold
    boxes = prediction['boxes'][keep].tolist()
    labels = prediction['labels'][keep].tolist()
    scores = prediction['scores'][keep].tolist()

    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    log_data = []
    for box, label, score in zip(boxes, labels, scores):
        # A custom checkpoint may emit ids outside the COCO table; fall back
        # to a generic name rather than raising IndexError.
        if 0 <= label < len(COCO_CLASSES):
            label_name = COCO_CLASSES[label]
        else:
            label_name = f"class_{label}"
        draw.rectangle(box, outline="red", width=2)
        draw.text((box[0], box[1]), f"{label_name} ({score:.2f})", fill="yellow", font=font)
        log_data.append({"Object": label_name, "Score": round(score, 2)})

    # gr.File expects a filesystem path, not an in-memory buffer, so the PNG
    # is written to a temporary file. delete=False keeps it on disk after the
    # handle closes so Gradio can still read it once this function returns.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as png_file:
        image.save(png_file, format="PNG")
        download_path = png_file.name

    # Passing columns= keeps the headers even when nothing was detected.
    table = pd.DataFrame(log_data, columns=["Object", "Score"])
    return image, download_path, table
# Gradio interface: one image upload in, three outputs back.
app = gr.Interface(
    title="🧠 Object Detection App (Faster R-CNN)",
    description=(
        "Upload an image to detect objects using a pretrained or custom "
        "Faster R-CNN model. View logs and download the result."
    ),
    fn=detect_objects,
    inputs=gr.Image(type="pil", label="Upload Image"),
    # One output component per element of the tuple detect_objects returns.
    outputs=[
        gr.Image(type="pil", label="Detected Image"),
        gr.File(label="Download Image"),
        gr.Dataframe(headers=["Object", "Score"], label="Detection Log"),
    ],
)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    app.launch()