openfree commited on
Commit
00869ae
·
verified ·
1 Parent(s): 9162e76

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -10
app.py CHANGED
@@ -8,6 +8,8 @@ import numpy as np
8
  import kagglehub
9
  from PIL import Image
10
  from glob import glob
 
 
11
  import matplotlib.pyplot as plt
12
  from matplotlib import patches
13
  from torchvision import transforms as T
@@ -16,6 +18,7 @@ import shutil
16
  import tempfile
17
  from pathlib import Path
18
  import json
 
19
 
20
  # Try to import spaces for Hugging Face Spaces GPU support
21
  try:
@@ -90,7 +93,6 @@ class Visualization:
90
  self.im_paths[data_type] = im_paths
91
 
92
  def plot_single(self, im_path, bboxes):
93
- fig, ax = plt.subplots(figsize=(8, 8))
94
  or_im = np.array(Image.open(im_path).convert("RGB"))
95
  height, width, _ = or_im.shape
96
 
@@ -102,16 +104,19 @@ class Visualization:
102
  x_max = int((x_center + w / 2) * width)
103
  y_max = int((y_center + h / 2) * height)
104
 
105
- color = (random.randint(0, 255)/255, random.randint(0, 255)/255, random.randint(0, 255)/255)
106
  cv2.rectangle(img=or_im, pt1=(x_min, y_min), pt2=(x_max, y_max),
107
- color=(int(color[0]*255), int(color[1]*255), int(color[2]*255)), thickness=3)
108
 
109
- ax.imshow(or_im)
110
- ax.axis("off")
111
- ax.set_title(f"Number of objects: {len(bboxes)}")
112
- plt.tight_layout()
 
 
 
113
 
114
- return fig
115
 
116
  def vis_samples(self, data_type, n_samples=4):
117
  if data_type not in self.vis_datas:
@@ -156,7 +161,14 @@ class Visualization:
156
  ha='center', va='bottom', fontsize=10, color='navy')
157
 
158
  plt.tight_layout()
159
- return fig
 
 
 
 
 
 
 
160
 
161
  def download_dataset():
162
  """Download the dataset using kagglehub"""
@@ -429,7 +441,7 @@ with gr.Blocks(title="X-ray Baggage Anomaly Detection", theme=gr.themes.Soft())
429
  data_type_analysis = gr.Dropdown(["train", "valid", "test"], value="train", label="Dataset Type")
430
  analyze_btn = gr.Button("Analyze Distribution")
431
 
432
- distribution_plot = gr.Plot(label="Class Distribution")
433
  analysis_status = gr.Textbox(label="Status", interactive=False)
434
 
435
  analyze_btn.click(analyze_class_distribution, inputs=data_type_analysis,
 
8
  import kagglehub
9
  from PIL import Image
10
  from glob import glob
11
+ import matplotlib
12
+ matplotlib.use('Agg') # Use non-interactive backend
13
  import matplotlib.pyplot as plt
14
  from matplotlib import patches
15
  from torchvision import transforms as T
 
18
  import tempfile
19
  from pathlib import Path
20
  import json
21
+ from io import BytesIO
22
 
23
  # Try to import spaces for Hugging Face Spaces GPU support
24
  try:
 
93
  self.im_paths[data_type] = im_paths
94
 
95
  def plot_single(self, im_path, bboxes):
 
96
  or_im = np.array(Image.open(im_path).convert("RGB"))
97
  height, width, _ = or_im.shape
98
 
 
104
  x_max = int((x_center + w / 2) * width)
105
  y_max = int((y_center + h / 2) * height)
106
 
107
+ color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
108
  cv2.rectangle(img=or_im, pt1=(x_min, y_min), pt2=(x_max, y_max),
109
+ color=color, thickness=3)
110
 
111
+ # Add text overlay
112
+ cv2.putText(or_im, f"Objects: {len(bboxes)}", (10, 30),
113
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
114
+
115
+ # Convert BGR to RGB if needed
116
+ if len(or_im.shape) == 3 and or_im.shape[2] == 3:
117
+ or_im = cv2.cvtColor(or_im, cv2.COLOR_BGR2RGB)
118
 
119
+ return Image.fromarray(or_im)
120
 
121
  def vis_samples(self, data_type, n_samples=4):
122
  if data_type not in self.vis_datas:
 
161
  ha='center', va='bottom', fontsize=10, color='navy')
162
 
163
  plt.tight_layout()
164
+
165
+ # Convert matplotlib figure to PIL Image
166
+ fig.canvas.draw()
167
+ img_array = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
168
+ img_array = img_array.reshape(fig.canvas.get_width_height()[::-1] + (3,))
169
+ plt.close(fig)
170
+
171
+ return Image.fromarray(img_array)
172
 
173
  def download_dataset():
174
  """Download the dataset using kagglehub"""
 
441
  data_type_analysis = gr.Dropdown(["train", "valid", "test"], value="train", label="Dataset Type")
442
  analyze_btn = gr.Button("Analyze Distribution")
443
 
444
+ distribution_plot = gr.Image(label="Class Distribution", type="pil")
445
  analysis_status = gr.Textbox(label="Status", interactive=False)
446
 
447
  analyze_btn.click(analyze_class_distribution, inputs=data_type_analysis,