Update app.py
Browse files
app.py
CHANGED
|
@@ -8,18 +8,18 @@ from torchvision.models.detection import FasterRCNN
|
|
| 8 |
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
| 9 |
|
| 10 |
# Load Models
|
| 11 |
-
def load_model( backbone_name, num_classes):
|
| 12 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 13 |
if backbone_name == "resnet50":
|
| 14 |
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)
|
| 15 |
in_features = model.roi_heads.box_predictor.cls_score.in_features
|
| 16 |
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
|
| 17 |
-
model.load_state_dict(torch.load(
|
| 18 |
elif backbone_name == "mobilenet":
|
| 19 |
model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=False)
|
| 20 |
in_features = model.roi_heads.box_predictor.cls_score.in_features
|
| 21 |
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
|
| 22 |
-
model.load_state_dict(torch.load(
|
| 23 |
model.to(device)
|
| 24 |
model.eval()
|
| 25 |
return model
|
|
@@ -60,7 +60,15 @@ def predict_video(video_path, model):
|
|
| 60 |
cv2.putText(frame, f"{class_names[label]}: {score:.2f}", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
|
| 61 |
frames.append(frame)
|
| 62 |
cap.release()
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
# Gradio Interface for Image and Video Inference
|
| 66 |
|
|
@@ -70,14 +78,14 @@ inputs_image = [gr.Image(type="filepath", label="Upload Image"), model_selection
|
|
| 70 |
outputs_image = gr.Image(type="numpy", label="Detection Output")
|
| 71 |
|
| 72 |
inputs_video = [gr.Video(label="Upload Video"), model_selection]
|
| 73 |
-
outputs_video = gr.
|
| 74 |
|
| 75 |
|
| 76 |
|
| 77 |
with gr.Blocks() as demo:
|
| 78 |
with gr.TabItem("Image"):
|
| 79 |
gr.Interface(
|
| 80 |
-
fn=lambda img, model_name: predict_image(img, load_model( model_name.lower(), num_classes=6)),
|
| 81 |
inputs=inputs_image,
|
| 82 |
outputs=outputs_image,
|
| 83 |
title="Image Inference"
|
|
@@ -85,11 +93,10 @@ with gr.Blocks() as demo:
|
|
| 85 |
|
| 86 |
with gr.TabItem("Video"):
|
| 87 |
gr.Interface(
|
| 88 |
-
fn=lambda vid, model_name: predict_video(vid, load_model(model_name.lower(), num_classes=6)),
|
| 89 |
inputs=inputs_video,
|
| 90 |
outputs=outputs_video,
|
| 91 |
title="Video Inference"
|
| 92 |
)
|
| 93 |
|
| 94 |
demo.launch()
|
| 95 |
-
|
|
|
|
| 8 |
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
| 9 |
|
| 10 |
# Load Models
|
| 11 |
+
def load_model(model_path, backbone_name, num_classes):
    """Build a Faster R-CNN detector and load fine-tuned weights from disk.

    Args:
        model_path: Path to a ``state_dict`` checkpoint saved with ``torch.save``.
        backbone_name: Either ``"resnet50"`` or ``"mobilenet"`` (case-sensitive).
        num_classes: Number of output classes (including background) for the
            replacement box predictor head.

    Returns:
        The model in eval mode, moved to CUDA when available, else CPU.

    Raises:
        ValueError: If ``backbone_name`` is not one of the supported values.
            (Previously an unknown name fell through both branches and crashed
            later with an unrelated ``NameError``.)
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): `pretrained=False` is the legacy torchvision API; newer
    # releases prefer `weights=None`. Kept as-is to match the installed version.
    if backbone_name == "resnet50":
        model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)
    elif backbone_name == "mobilenet":
        model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=False)
    else:
        raise ValueError(f"Unsupported backbone: {backbone_name!r}; expected 'resnet50' or 'mobilenet'")
    # The head replacement and checkpoint load are identical for every backbone,
    # so they are hoisted out of the branches (was duplicated in each branch).
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    # map_location lets a CUDA-trained checkpoint load on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    return model
|
|
|
|
| 60 |
cv2.putText(frame, f"{class_names[label]}: {score:.2f}", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
|
| 61 |
frames.append(frame)
|
| 62 |
cap.release()
|
| 63 |
+
output_path = 'output_video.mp4'
|
| 64 |
+
height, width, _ = frames[0].shape
|
| 65 |
+
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 20, (width, height))
|
| 66 |
+
|
| 67 |
+
for frame in frames:
|
| 68 |
+
out.write(frame)
|
| 69 |
+
|
| 70 |
+
out.release()
|
| 71 |
+
return output_path
|
| 72 |
|
| 73 |
# Gradio Interface for Image and Video Inference

outputs_image = gr.Image(type="numpy", label="Detection Output")

inputs_video = [gr.Video(label="Upload Video"), model_selection]
outputs_video = gr.Video(label="Detection Output")


def _detector_for(model_name):
    """Resolve the checkpoint path for a UI model choice and build the detector.

    Checkpoints are expected next to the app as ``fasterrcnn<Name>.pth``; the
    backbone key passed to ``load_model`` is the lower-cased choice.
    """
    return load_model(f'fasterrcnn{model_name}.pth', model_name.lower(), num_classes=6)


with gr.Blocks() as demo:
    # One tab per media type; both share the model-selection input.
    with gr.TabItem("Image"):
        gr.Interface(
            fn=lambda img, model_name: predict_image(img, _detector_for(model_name)),
            inputs=inputs_image,
            outputs=outputs_image,
            title="Image Inference"
        )
    with gr.TabItem("Video"):
        gr.Interface(
            fn=lambda vid, model_name: predict_video(vid, _detector_for(model_name)),
            inputs=inputs_video,
            outputs=outputs_video,
            title="Video Inference"
        )

demo.launch()
|
|
|