openfree commited on
Commit
aa445bb
·
verified ·
1 Parent(s): d1ded04

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +407 -0
app.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import yaml
4
+ import torch
5
+ import random
6
+ import gradio as gr
7
+ import numpy as np
8
+ import kagglehub
9
+ from PIL import Image
10
+ from glob import glob
11
+ import matplotlib.pyplot as plt
12
+ from matplotlib import patches
13
+ from torchvision import transforms as T
14
+ from ultralytics import YOLO
15
+ import shutil
16
+ import tempfile
17
+ from pathlib import Path
18
# --- Kaggle API credentials --------------------------------------------------
# KDATA_API may hold either a bare API key or a full kaggle.json payload
# ({"username": ..., "key": ...}).  When it looks like JSON, export the
# username/key pair for kagglehub.  Parsing failures must not crash the app
# at import time, so malformed values are tolerated and left for kagglehub
# to report as an authentication error later.
if os.getenv("KDATA_API"):
    kaggle_key = os.getenv("KDATA_API")
    # Parse the key if it's in JSON format
    if "{" in kaggle_key:
        import json
        try:
            key_data = json.loads(kaggle_key)
            os.environ["KAGGLE_USERNAME"] = key_data.get("username", "")
            os.environ["KAGGLE_KEY"] = key_data.get("key", "")
        except (json.JSONDecodeError, AttributeError):
            # Malformed credentials: leave the environment untouched.
            pass

# Global state shared across the Gradio callbacks below.
model = None                  # YOLO instance set by train_model / load_pretrained_model
dataset_path = None           # root of the kagglehub-downloaded dataset
training_in_progress = False  # simple re-entrancy guard for train_model
class Visualization:
    """Indexes a YOLO-format dataset and renders sample/statistics plots.

    Expects ``root`` to contain ``data.yaml`` plus ``<split>/images`` and
    ``<split>/labels`` directories for each split in ``data_types``.

    ``n_ims``, ``rows`` and ``cmap`` are stored for interface compatibility
    but are not used by the current plotting code.
    """

    def __init__(self, root, data_types, n_ims, rows, cmap=None):
        self.n_ims, self.rows = n_ims, rows
        self.cmap, self.data_types = cmap, data_types
        self.colors = ["firebrick", "darkorange", "blueviolet"]
        self.root = root

        self.get_cls_names()
        self.get_bboxes()

    def get_cls_names(self):
        """Load class names from data.yaml into ``self.class_dict`` (id -> name)."""
        with open(f"{self.root}/data.yaml", 'r') as file:
            data = yaml.safe_load(file)
        class_names = data['names']
        self.class_dict = {index: name for index, name in enumerate(class_names)}

    def get_bboxes(self):
        """Index every labelled image of each requested split.

        Populates, per split:
          * ``self.im_paths``       - image paths that have a label file
          * ``self.vis_datas``      - parallel list of per-image bbox lists,
                                      entries ``[class_name, xc, yc, w, h]``
                                      (normalized YOLO coordinates)
          * ``self.analysis_datas`` - ``{class_name: object count}``

        Fixes vs. the original: images without a label file are now dropped
        from ``im_paths`` too (previously the path and bbox lists could
        misalign), and label files are closed deterministically.
        """
        self.vis_datas, self.analysis_datas, self.im_paths = {}, {}, {}
        for data_type in self.data_types:
            all_bboxes, all_analysis_datas, kept_paths = [], {}, []

            for im_path in glob(f"{self.root}/{data_type}/images/*"):
                im_ext = os.path.splitext(im_path)[-1]
                lbl_path = im_path.replace(im_ext, ".txt")
                lbl_path = lbl_path.replace(f"{data_type}/images", f"{data_type}/labels")
                if not os.path.isfile(lbl_path):
                    continue  # unlabeled image: skip it entirely

                bboxes = []
                with open(lbl_path) as lbl_file:
                    meta_data = lbl_file.readlines()
                for data in meta_data:
                    # YOLO label line: "<class_id> <xc> <yc> <w> <h>"
                    parts = data.strip().split()[:5]
                    cls_name = self.class_dict[int(parts[0])]
                    bboxes.append([cls_name] + [float(x) for x in parts[1:]])
                    all_analysis_datas[cls_name] = all_analysis_datas.get(cls_name, 0) + 1

                all_bboxes.append(bboxes)
                kept_paths.append(im_path)

            self.vis_datas[data_type] = all_bboxes
            self.analysis_datas[data_type] = all_analysis_datas
            self.im_paths[data_type] = kept_paths

    def plot_single(self, im_path, bboxes):
        """Return a matplotlib Figure of one image with its boxes drawn."""
        fig, ax = plt.subplots(figsize=(8, 8))
        or_im = np.array(Image.open(im_path).convert("RGB"))
        height, width, _ = or_im.shape

        for bbox in bboxes:
            # Entries are [class_name, xc, yc, w, h] with normalized coords.
            _cls_name, x_center, y_center, w, h = bbox

            x_min = int((x_center - w / 2) * width)
            y_min = int((y_center - h / 2) * height)
            x_max = int((x_center + w / 2) * width)
            y_max = int((y_center + h / 2) * height)

            # Random RGB color per box (the original divided by 255 and then
            # re-multiplied; the net effect is the same random 0-255 triple).
            color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
            cv2.rectangle(img=or_im, pt1=(x_min, y_min), pt2=(x_max, y_max),
                          color=color, thickness=3)

        ax.imshow(or_im)
        ax.axis("off")
        ax.set_title(f"Number of objects: {len(bboxes)}")

        return fig

    def vis_samples(self, data_type, n_samples=4):
        """Return up to ``n_samples`` figures of distinct random images.

        Returns ``None`` when the split was not indexed.
        """
        if data_type not in self.vis_datas:
            return None

        n_avail = len(self.vis_datas[data_type])
        # Sample without replacement so the gallery never repeats an image
        # (the original used randint and could draw duplicates, and crashed
        # on an empty split).
        indices = random.sample(range(n_avail), min(int(n_samples), n_avail))

        figs = []
        for idx in indices:
            im_path = self.im_paths[data_type][idx]
            bboxes = self.vis_datas[data_type][idx]
            figs.append(self.plot_single(im_path, bboxes))

        return figs

    def data_analysis(self, data_type):
        """Return a bar-chart Figure of object counts per class for a split."""
        if data_type not in self.analysis_datas:
            return None

        plt.style.use('default')
        fig, ax = plt.subplots(figsize=(12, 6))

        cls_names = list(self.analysis_datas[data_type].keys())
        counts = list(self.analysis_datas[data_type].values())

        # Split-specific bar color, falling back to a neutral default.
        color_map = {"train": "firebrick", "valid": "darkorange", "test": "blueviolet"}
        color = color_map.get(data_type, "steelblue")

        indices = np.arange(len(counts))
        bars = ax.bar(indices, counts, 0.7, color=color)

        ax.set_xlabel("Class Names", fontsize=12)
        ax.set_xticks(indices)
        ax.set_xticklabels(cls_names, rotation=45, ha='right')
        ax.set_ylabel("Data Counts", fontsize=12)
        ax.set_title(f"{data_type.upper()} Dataset Class Distribution", fontsize=14)

        # Annotate each bar with its exact count.
        for bar, v in zip(bars, counts):
            ax.text(bar.get_x() + bar.get_width()/2, v + 1, str(v),
                    ha='center', va='bottom', fontsize=10, color='navy')

        plt.tight_layout()
        return fig
def download_dataset():
    """Fetch the X-ray baggage dataset via kagglehub and remember its path.

    Stores the download location in the module-level ``dataset_path`` and
    returns a human-readable status string for the UI.
    """
    global dataset_path
    try:
        dataset_path = kagglehub.dataset_download("orvile/x-ray-baggage-anomaly-detection")
    except Exception as exc:
        return f"Error downloading dataset: {str(exc)}"
    return f"Dataset downloaded successfully to: {dataset_path}"
def visualize_data(data_type, num_samples):
    """Render sample images (with boxes) from the chosen dataset split.

    Returns ``(figures, status_message)``; ``figures`` is ``None`` when the
    dataset has not been downloaded yet or visualization fails.

    NOTE(review): the figures are fed into a ``gr.Gallery`` — confirm the
    installed Gradio version accepts matplotlib figures there.
    """
    if dataset_path is None:
        return None, "Please download the dataset first!"

    try:
        visualizer = Visualization(
            root=dataset_path,
            data_types=[data_type],
            n_ims=num_samples,
            rows=2,
            cmap="rgb",
        )
        sample_figs = visualizer.vis_samples(data_type, num_samples)
        status = f"Showing {len(sample_figs)} samples from {data_type} dataset"
        return sample_figs, status
    except Exception as exc:
        return None, f"Error visualizing data: {str(exc)}"
def analyze_class_distribution(data_type):
    """Plot per-class object counts for one dataset split.

    Returns ``(figure, status_message)``; ``figure`` is ``None`` when the
    dataset is missing or the analysis fails.
    """
    if dataset_path is None:
        return None, "Please download the dataset first!"

    try:
        visualizer = Visualization(
            root=dataset_path,
            data_types=[data_type],
            n_ims=20,
            rows=5,
            cmap="rgb",
        )
        distribution_fig = visualizer.data_analysis(data_type)
        return distribution_fig, f"Class distribution for {data_type} dataset"
    except Exception as exc:
        return None, f"Error analyzing data: {str(exc)}"
def train_model(epochs, batch_size, img_size, device_selection):
    """Train a YOLOv11 model on the downloaded dataset.

    Parameters arrive straight from Gradio sliders (which may deliver
    floats), so the numeric hyper-parameters are coerced to ``int`` before
    being handed to Ultralytics.

    Returns ``(result_images | None, status_message)``.
    """
    global model, training_in_progress

    if dataset_path is None:
        return None, "Please download the dataset first!"

    if training_in_progress:
        return None, "Training already in progress!"

    training_in_progress = True

    try:
        # Resolve the compute device from the UI selection.
        if device_selection == "Auto":
            device = 0 if torch.cuda.is_available() else "cpu"
        elif device_selection == "CPU":
            device = "cpu"
        else:
            device = 0

        # Start from the pretrained nano checkpoint.
        model = YOLO("yolo11n.pt")

        model.train(
            data=f"{dataset_path}/data.yaml",
            epochs=int(epochs),
            imgsz=int(img_size),
            batch=int(batch_size),
            device=device,
            project="xray_detection",
            name="train",
            exist_ok=True,
            verbose=True
        )

        # Collect the standard Ultralytics result plots, if present.
        results_path = "xray_detection/train"
        plots = []
        for plot_file in ["results.png", "confusion_matrix.png", "val_batch0_pred.jpg"]:
            plot_path = os.path.join(results_path, plot_file)
            if os.path.exists(plot_path):
                plots.append(Image.open(plot_path))

        return plots, f"Training completed! Model saved to {results_path}"

    except Exception as e:
        return None, f"Error during training: {str(e)}"
    finally:
        # Always release the re-entrancy guard, even on unexpected errors
        # (the original reset it on both paths by hand; finally is safer).
        training_in_progress = False
def run_inference(input_image, conf_threshold):
    """Run single-image detection with the currently loaded model.

    Returns ``(annotated PIL image | None, detections text)``.
    """
    global model

    if model is None:
        return None, "Please train the model first or load a pre-trained model!"

    temp_path = None
    try:
        # Write the upload to a unique temp file: the original used a fixed
        # name ("temp_inference.jpg"), which races between concurrent
        # requests and leaked the file when inference raised.
        fd, temp_path = tempfile.mkstemp(suffix=".jpg")
        os.close(fd)
        input_image.save(temp_path)

        results = model(temp_path, conf=conf_threshold, verbose=False)

        # Ultralytics Results.plot() returns a BGR numpy array; convert to
        # RGB so PIL does not swap the red/blue channels.
        annotated_image = cv2.cvtColor(results[0].plot(), cv2.COLOR_BGR2RGB)

        # Summarize each detection as "class_name: confidence".
        detections = []
        for r in results:
            for box in r.boxes:
                cls = int(box.cls)
                conf = float(box.conf)
                detections.append(f"{model.names[cls]}: {conf:.2f}")

        detection_text = "\n".join(detections) if detections else "No objects detected"
        return Image.fromarray(annotated_image), f"Detections:\n{detection_text}"

    except Exception as e:
        return None, f"Error during inference: {str(e)}"
    finally:
        # Remove the temp file even when inference fails.
        if temp_path and os.path.exists(temp_path):
            os.remove(temp_path)
def batch_inference(data_type, num_images):
    """Run detection over the first ``num_images`` images of a split.

    ``num_images`` comes from a Gradio slider and may be a float, so it is
    coerced to ``int`` before slicing.

    Returns ``(list of annotated PIL images | None, status message)``.
    """
    global model

    if model is None:
        return None, "Please train the model first!"

    if dataset_path is None:
        return None, "Please download the dataset first!"

    try:
        image_dir = f"{dataset_path}/{data_type}/images"
        image_files = glob(f"{image_dir}/*")[:int(num_images)]

        results_images = []
        for img_path in image_files:
            results = model(img_path, verbose=False)
            # Results.plot() yields BGR; convert so PIL shows true colors.
            annotated = cv2.cvtColor(results[0].plot(), cv2.COLOR_BGR2RGB)
            results_images.append(Image.fromarray(annotated))

        return results_images, f"Processed {len(results_images)} images from {data_type} dataset"

    except Exception as e:
        return None, f"Error during batch inference: {str(e)}"
def load_pretrained_model(model_path):
    """Replace the global ``model`` with weights loaded from ``model_path``.

    Returns a status string describing success or the failure reason.
    """
    global model
    try:
        model = YOLO(model_path)
    except Exception as exc:
        return f"Error loading model: {str(exc)}"
    return f"Model loaded successfully from {model_path}"
# Create Gradio interface
# Layout: three tabs — Dataset (download / visualize / analyze), Training
# (train a new model or load pretrained weights), Inference (single-image
# and batch detection).  Every button wires a UI component to one of the
# callback functions defined above.
with gr.Blocks(title="X-ray Baggage Anomaly Detection") as demo:
    gr.Markdown("""
    # 🎯 X-ray Baggage Anomaly Detection with YOLOv11

    This application allows you to:
    1. Download and visualize the X-ray baggage dataset
    2. Analyze class distributions
    3. Train a YOLOv11 model for object detection
    4. Run inference on new images
    """)

    with gr.Tab("📊 Dataset"):
        # Dataset download (sets the module-level dataset_path).
        with gr.Row():
            download_btn = gr.Button("Download Dataset", variant="primary")
            download_status = gr.Textbox(label="Status", interactive=False)

        download_btn.click(download_dataset, outputs=download_status)

        gr.Markdown("### Visualize Dataset Samples")
        # Random annotated samples from the chosen split.
        with gr.Row():
            data_type_viz = gr.Dropdown(["train", "valid", "test"], value="train", label="Dataset Type")
            num_samples = gr.Slider(1, 8, 4, step=1, label="Number of Samples")
            viz_btn = gr.Button("Visualize Samples")

        viz_gallery = gr.Gallery(label="Sample Images", columns=2, height="auto")
        viz_status = gr.Textbox(label="Status", interactive=False)

        viz_btn.click(visualize_data, inputs=[data_type_viz, num_samples],
                      outputs=[viz_gallery, viz_status])

        gr.Markdown("### Analyze Class Distribution")
        # Per-class object-count bar chart for the chosen split.
        with gr.Row():
            data_type_analysis = gr.Dropdown(["train", "valid", "test"], value="train", label="Dataset Type")
            analyze_btn = gr.Button("Analyze Distribution")

        distribution_plot = gr.Plot(label="Class Distribution")
        analysis_status = gr.Textbox(label="Status", interactive=False)

        analyze_btn.click(analyze_class_distribution, inputs=data_type_analysis,
                          outputs=[distribution_plot, analysis_status])

    with gr.Tab("🚀 Training"):
        gr.Markdown("### Train YOLOv11 Model")

        # Training hyper-parameters; values are passed to train_model.
        with gr.Row():
            epochs_input = gr.Slider(1, 50, 10, step=1, label="Epochs")
            batch_size_input = gr.Slider(8, 64, 16, step=8, label="Batch Size")
            img_size_input = gr.Slider(320, 640, 480, step=32, label="Image Size")
            device_input = gr.Radio(["Auto", "GPU", "CPU"], value="Auto", label="Device")

        train_btn = gr.Button("Start Training", variant="primary")

        training_gallery = gr.Gallery(label="Training Results", columns=3, height="auto")
        training_status = gr.Textbox(label="Training Status", interactive=False)

        train_btn.click(train_model,
                        inputs=[epochs_input, batch_size_input, img_size_input, device_input],
                        outputs=[training_gallery, training_status])

        gr.Markdown("### Load Pre-trained Model")
        # Alternative to training: point the global model at existing weights.
        with gr.Row():
            model_path_input = gr.Textbox(label="Model Path", value="yolo11n.pt")
            load_model_btn = gr.Button("Load Model")
            load_status = gr.Textbox(label="Status", interactive=False)

        load_model_btn.click(load_pretrained_model, inputs=model_path_input, outputs=load_status)

    with gr.Tab("🔍 Inference"):
        gr.Markdown("### Single Image Inference")

        # One uploaded image + confidence threshold -> annotated result.
        with gr.Row():
            input_image = gr.Image(type="pil", label="Upload Image")
            conf_threshold = gr.Slider(0.1, 0.9, 0.5, step=0.05, label="Confidence Threshold")

        inference_btn = gr.Button("Run Detection", variant="primary")

        with gr.Row():
            output_image = gr.Image(type="pil", label="Detection Result")
            detection_info = gr.Textbox(label="Detection Info", lines=5)

        inference_btn.click(run_inference,
                            inputs=[input_image, conf_threshold],
                            outputs=[output_image, detection_info])

        gr.Markdown("### Batch Inference")

        # Run the model over the first N images of the test/valid split.
        with gr.Row():
            batch_data_type = gr.Dropdown(["test", "valid"], value="test", label="Dataset Type")
            batch_num_images = gr.Slider(1, 10, 5, step=1, label="Number of Images")
            batch_btn = gr.Button("Run Batch Inference")

        batch_gallery = gr.Gallery(label="Batch Results", columns=3, height="auto")
        batch_status = gr.Textbox(label="Status", interactive=False)

        batch_btn.click(batch_inference,
                        inputs=[batch_data_type, batch_num_images],
                        outputs=[batch_gallery, batch_status])
405
+ # Launch the app
406
+ if __name__ == "__main__":
407
+ demo.launch(share=True)