Update app.py
Browse files
app.py
CHANGED
|
@@ -8,6 +8,8 @@ import numpy as np
|
|
| 8 |
import kagglehub
|
| 9 |
from PIL import Image
|
| 10 |
from glob import glob
|
|
|
|
|
|
|
| 11 |
import matplotlib.pyplot as plt
|
| 12 |
from matplotlib import patches
|
| 13 |
from torchvision import transforms as T
|
|
@@ -16,6 +18,7 @@ import shutil
|
|
| 16 |
import tempfile
|
| 17 |
from pathlib import Path
|
| 18 |
import json
|
|
|
|
| 19 |
|
| 20 |
# Try to import spaces for Hugging Face Spaces GPU support
|
| 21 |
try:
|
|
@@ -90,7 +93,6 @@ class Visualization:
|
|
| 90 |
self.im_paths[data_type] = im_paths
|
| 91 |
|
| 92 |
def plot_single(self, im_path, bboxes):
|
| 93 |
-
fig, ax = plt.subplots(figsize=(8, 8))
|
| 94 |
or_im = np.array(Image.open(im_path).convert("RGB"))
|
| 95 |
height, width, _ = or_im.shape
|
| 96 |
|
|
@@ -102,16 +104,19 @@ class Visualization:
|
|
| 102 |
x_max = int((x_center + w / 2) * width)
|
| 103 |
y_max = int((y_center + h / 2) * height)
|
| 104 |
|
| 105 |
-
color = (random.randint(0, 255)
|
| 106 |
cv2.rectangle(img=or_im, pt1=(x_min, y_min), pt2=(x_max, y_max),
|
| 107 |
-
color=
|
| 108 |
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
-
return
|
| 115 |
|
| 116 |
def vis_samples(self, data_type, n_samples=4):
|
| 117 |
if data_type not in self.vis_datas:
|
|
@@ -156,7 +161,14 @@ class Visualization:
|
|
| 156 |
ha='center', va='bottom', fontsize=10, color='navy')
|
| 157 |
|
| 158 |
plt.tight_layout()
|
| 159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
|
| 161 |
def download_dataset():
|
| 162 |
"""Download the dataset using kagglehub"""
|
|
@@ -429,7 +441,7 @@ with gr.Blocks(title="X-ray Baggage Anomaly Detection", theme=gr.themes.Soft())
|
|
| 429 |
data_type_analysis = gr.Dropdown(["train", "valid", "test"], value="train", label="Dataset Type")
|
| 430 |
analyze_btn = gr.Button("Analyze Distribution")
|
| 431 |
|
| 432 |
-
distribution_plot = gr.
|
| 433 |
analysis_status = gr.Textbox(label="Status", interactive=False)
|
| 434 |
|
| 435 |
analyze_btn.click(analyze_class_distribution, inputs=data_type_analysis,
|
|
|
|
| 8 |
import kagglehub
|
| 9 |
from PIL import Image
|
| 10 |
from glob import glob
|
| 11 |
+
import matplotlib
|
| 12 |
+
matplotlib.use('Agg') # Use non-interactive backend
|
| 13 |
import matplotlib.pyplot as plt
|
| 14 |
from matplotlib import patches
|
| 15 |
from torchvision import transforms as T
|
|
|
|
| 18 |
import tempfile
|
| 19 |
from pathlib import Path
|
| 20 |
import json
|
| 21 |
+
from io import BytesIO
|
| 22 |
|
| 23 |
# Try to import spaces for Hugging Face Spaces GPU support
|
| 24 |
try:
|
|
|
|
| 93 |
self.im_paths[data_type] = im_paths
|
| 94 |
|
| 95 |
def plot_single(self, im_path, bboxes):
|
|
|
|
| 96 |
or_im = np.array(Image.open(im_path).convert("RGB"))
|
| 97 |
height, width, _ = or_im.shape
|
| 98 |
|
|
|
|
| 104 |
x_max = int((x_center + w / 2) * width)
|
| 105 |
y_max = int((y_center + h / 2) * height)
|
| 106 |
|
| 107 |
+
color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
|
| 108 |
cv2.rectangle(img=or_im, pt1=(x_min, y_min), pt2=(x_max, y_max),
|
| 109 |
+
color=color, thickness=3)
|
| 110 |
|
| 111 |
+
# Add text overlay
|
| 112 |
+
cv2.putText(or_im, f"Objects: {len(bboxes)}", (10, 30),
|
| 113 |
+
cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
|
| 114 |
+
|
| 115 |
+
# Convert BGR to RGB if needed
|
| 116 |
+
if len(or_im.shape) == 3 and or_im.shape[2] == 3:
|
| 117 |
+
or_im = cv2.cvtColor(or_im, cv2.COLOR_BGR2RGB)
|
| 118 |
|
| 119 |
+
return Image.fromarray(or_im)
|
| 120 |
|
| 121 |
def vis_samples(self, data_type, n_samples=4):
|
| 122 |
if data_type not in self.vis_datas:
|
|
|
|
| 161 |
ha='center', va='bottom', fontsize=10, color='navy')
|
| 162 |
|
| 163 |
plt.tight_layout()
|
| 164 |
+
|
| 165 |
+
# Convert matplotlib figure to PIL Image
|
| 166 |
+
fig.canvas.draw()
|
| 167 |
+
img_array = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|
| 168 |
+
img_array = img_array.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
| 169 |
+
plt.close(fig)
|
| 170 |
+
|
| 171 |
+
return Image.fromarray(img_array)
|
| 172 |
|
| 173 |
def download_dataset():
|
| 174 |
"""Download the dataset using kagglehub"""
|
|
|
|
| 441 |
data_type_analysis = gr.Dropdown(["train", "valid", "test"], value="train", label="Dataset Type")
|
| 442 |
analyze_btn = gr.Button("Analyze Distribution")
|
| 443 |
|
| 444 |
+
distribution_plot = gr.Image(label="Class Distribution", type="pil")
|
| 445 |
analysis_status = gr.Textbox(label="Status", interactive=False)
|
| 446 |
|
| 447 |
analyze_btn.click(analyze_class_distribution, inputs=data_type_analysis,
|