ChilliT commited on
Commit
3b2af7b
·
verified ·
1 Parent(s): c293c27

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -0
app.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ from torchvision import models, transforms
5
+
6
+ # Load the trained object detection model (make sure to change this to your model path)
7
+ model = torch.load("path/to/your/best_model.pt")
8
+ model.eval()
9
+
10
+ # Define a function to run inference on an uploaded image
11
+ def detect_objects(image):
12
+ # Transform the input image to match the model input size and format
13
+ transform = transforms.Compose([transforms.ToTensor()])
14
+ image_tensor = transform(image).unsqueeze(0) # Add batch dimension
15
+
16
+ with torch.no_grad():
17
+ predictions = model(image_tensor) # Run inference
18
+
19
+ # Process and return the bounding boxes and labels (you may need to adjust based on your model)
20
+ return predictions
21
+
22
+ # Define the Gradio interface
23
+ interface = gr.Interface(
24
+ fn=detect_objects,
25
+ inputs=gr.inputs.Image(type="pil"),
26
+ outputs="json",
27
+ live=True,
28
+ description="Object Detection for Lesions vs Non-Lesions"
29
+ )
30
+
31
+ # Launch the interface in the browser
32
+ interface.launch()