# 234313N / app.py
# ChilliT's picture
# Create app.py
# 3b2af7b verified
import gradio as gr
import torch
from PIL import Image
from torchvision import models, transforms
# Load the trained object detection model (make sure to change this to your model path)
MODEL_PATH = "path/to/your/best_model.pt"

# map_location="cpu" lets a checkpoint that was saved on a GPU load on CPU-only hosts.
# weights_only=False is needed to unpickle a complete model object on PyTorch >= 2.6,
# where the default switched to True (which only accepts plain tensors/state dicts).
# SECURITY: this unpickles arbitrary Python objects -- only load checkpoint files you trust.
model = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)

# Assumes the checkpoint holds a full torch.nn.Module (not just a state dict);
# eval() puts it in inference mode before serving requests.
model.eval()
def _to_jsonable(value):
    """Recursively replace tensors inside *value* with plain Python lists/scalars."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().tolist()
    if isinstance(value, dict):
        return {key: _to_jsonable(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_jsonable(item) for item in value]
    return value


# Define a function to run inference on an uploaded image
def detect_objects(image):
    """Run the loaded model on one uploaded image and return JSON-friendly output.

    Args:
        image: The PIL image supplied by the Gradio input, or ``None`` when the
            input is empty (for example after the user clears the image while
            the interface is in live mode).

    Returns:
        The model's output with every tensor converted to nested Python lists,
        so the "json" output component can serialize it. An empty list is
        returned when no image was provided.
    """
    if image is None:
        # Nothing to analyse; avoid crashing inside the tensor conversion.
        return []

    # Convert the PIL image to a tensor and add the batch dimension.
    image_tensor = transforms.ToTensor()(image).unsqueeze(0)

    with torch.no_grad():
        predictions = model(image_tensor)  # Run inference

    # Raw tensors are not JSON serializable; convert them to plain lists.
    # NOTE(review): the exact structure of `predictions` depends on the model
    # that was loaded -- adjust the post-processing here if needed.
    return _to_jsonable(predictions)
# Define the Gradio interface
interface = gr.Interface(
    fn=detect_objects,
    # gr.inputs.Image was deprecated in Gradio 3 and removed in Gradio 4;
    # the top-level gr.Image component is the supported equivalent.
    inputs=gr.Image(type="pil"),
    outputs="json",
    live=True,  # re-run detection whenever the input image changes
    description="Object Detection for Lesions vs Non-Lesions",
)

# Start the Gradio server for the interface
interface.launch()