# Source: Hugging Face Space app.py by ma4389
# Commit 6c00a84 ("Update app.py", verified)
import torch
import torch.nn as nn
import gradio as gr
from ultralytics import YOLO
from ultralytics.nn.tasks import DetectionModel
from ultralytics.nn.modules.conv import Conv
from PIL import Image
import numpy as np
# ---- FIX for PyTorch 2.6+ ----
# PyTorch 2.6 made torch.load default to weights_only=True, which refuses to
# unpickle classes that are not allow-listed. Register the classes stored in the
# checkpoint so YOLO("best.pt") can load it. add_safe_globals is missing from
# older PyTorch releases (where weights_only did not default to True), so only
# call it when it exists instead of crashing at import time.
if hasattr(torch.serialization, "add_safe_globals"):
    torch.serialization.add_safe_globals([DetectionModel, nn.Sequential, Conv])

# ---- Device setup (CPU for Spaces) ----
device = "cpu"

# ---- Load model once ----
# Loaded at import time so every request reuses the same weights.
# NOTE(review): "best.pt" is a relative path, presumably resolved against the
# current working directory — confirm the weights ship next to this script.
model = YOLO("best.pt")
model.to(device)
# ---- Prediction function ----
def predict(image, *, conf=0.25):
    """Run YOLO detection on one image and return it with detections drawn on.

    Args:
        image: Input image from the Gradio component (a PIL image, per
            ``gr.Image(type="pil")``), or ``None`` when nothing was uploaded.
        conf: Minimum confidence for a detection to be kept. Defaults to
            0.25, the value this function previously hard-coded.

    Returns:
        A ``PIL.Image.Image`` with the detections drawn, or ``None`` if
        ``image`` is ``None``.
    """
    if image is None:
        return None

    # One input image -> results[0] holds its detections.
    results = model.predict(image, conf=conf, device=device)

    # Results.plot() returns the annotated image in BGR channel order;
    # reverse the channel axis to get RGB for PIL.
    annotated_bgr = results[0].plot()
    # The ::-1 slice is a negative-stride view; copy it into a contiguous
    # buffer so Image.fromarray gets a plain C-ordered array.
    annotated_rgb = np.ascontiguousarray(annotated_bgr[:, :, ::-1])
    return Image.fromarray(annotated_rgb)
# ---- Gradio UI ----
# Single image in, annotated image out; flagging is disabled.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="🚦 Traffic Violation Detection",
    description="Upload an image",
    flagging_mode="never",
)

# Start the web server when run as a script (e.g. `python app.py`).
# The guard keeps a plain import of this module from starting a server.
if __name__ == "__main__":
    iface.launch()