Girishug commited on
Commit
4e2babc
·
verified ·
1 Parent(s): d06b805

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -0
app.py CHANGED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from torchvision import transforms
4
+ import torchvision.models.detection as detection
5
+ import gradio as gr
6
+ from PIL import Image
7
+ import numpy as np
8
+ import cv2
9
+
10
+ # Load the trained model
11
+ model = detection.fasterrcnn_resnet50_fpn(pretrained=False)
12
+ num_classes = 91 # COCO has 80 classes + 1 background
13
+ in_features = model.roi_heads.box_predictor.cls_score.in_features
14
+ model.roi_heads.box_predictor = detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
15
+
16
+ # Load the model weights
17
+ model.load_state_dict(torch.load('final_model.pth', weights_only=True))
18
+ model.eval()
19
+
20
+ # Define transformations
21
+ transform = transforms.Compose([
22
+ transforms.Resize((600, 600)),
23
+ transforms.ToTensor(),
24
+ ])
25
+
26
+ # Prediction function
27
+ def predict(image):
28
+ image = transform(image).unsqueeze(0) # Add batch dimension
29
+ with torch.no_grad():
30
+ predictions = model(image)
31
+
32
+ # Process predictions
33
+ boxes = predictions[0]['boxes'].cpu().numpy()
34
+ scores = predictions[0]['scores'].cpu().numpy()
35
+ labels = predictions[0]['labels'].cpu().numpy()
36
+
37
+ # Filter out low-confidence predictions
38
+ threshold = 0.5
39
+ boxes = boxes[scores > threshold]
40
+ labels = labels[scores > threshold]
41
+
42
+ # Draw boxes on the image
43
+ image_np = np.array(image.squeeze().permute(1, 2, 0).cpu())
44
+ for box, label in zip(boxes, labels):
45
+ x1, y1, x2, y2 = box.astype(int)
46
+ image_np = cv2.rectangle(image_np, (x1, y1), (x2, y2), (255, 0, 0), 2)
47
+ image_np = cv2.putText(image_np, str(label), (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
48
+
49
+ return Image.fromarray(image_np)
50
+
51
+ # Gradio interface
52
+ iface = gr.Interface(fn=predict, inputs=gr.inputs.Image(type="pil"), outputs="image", title="Object Detection with Faster R-CNN")
53
+ iface.launch()