CircleStar commited on
Commit
733869b
·
verified ·
1 Parent(s): ca9c54c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -439
app.py CHANGED
@@ -1,394 +1,14 @@
1
- import os
2
  import json
3
- import time
4
- import random
5
- from datetime import datetime
6
- from typing import List, Tuple
7
 
8
  import spaces
9
  import gradio as gr
10
- import torch
11
- import torch.nn as nn
12
- import torch.optim as optim
13
- from torch.utils.data import DataLoader, random_split
14
- from torchvision import datasets, transforms
15
- from PIL import Image
16
-
17
-
18
- # ============================================================
19
- # Paths / basic config
20
- # ============================================================
21
- BASE_DIR = os.path.dirname(os.path.abspath(__file__)) if "__file__" in globals() else os.getcwd()
22
- DATA_DIR = os.path.join(BASE_DIR, "data")
23
- MODEL_DIR = os.path.join(BASE_DIR, "saved_models")
24
- META_DIR = os.path.join(BASE_DIR, "saved_models_meta")
25
-
26
- os.makedirs(DATA_DIR, exist_ok=True)
27
- os.makedirs(MODEL_DIR, exist_ok=True)
28
- os.makedirs(META_DIR, exist_ok=True)
29
-
30
- CLASS_NAMES = [str(i) for i in range(10)]
31
-
32
-
33
- # ============================================================
34
- # Model
35
- # ============================================================
36
- class SimpleCNN(nn.Module):
37
- def __init__(
38
- self,
39
- conv1_channels: int = 16,
40
- conv2_channels: int = 32,
41
- kernel_size: int = 3,
42
- dropout: float = 0.2,
43
- fc_dim: int = 128,
44
- ):
45
- super().__init__()
46
- padding = kernel_size // 2
47
-
48
- self.features = nn.Sequential(
49
- nn.Conv2d(1, conv1_channels, kernel_size=kernel_size, padding=padding),
50
- nn.ReLU(),
51
- nn.MaxPool2d(2),
52
-
53
- nn.Conv2d(conv1_channels, conv2_channels, kernel_size=kernel_size, padding=padding),
54
- nn.ReLU(),
55
- nn.MaxPool2d(2),
56
- )
57
-
58
- flattened_dim = conv2_channels * 7 * 7 # 28x28 -> 14x14 -> 7x7
59
-
60
- self.classifier = nn.Sequential(
61
- nn.Flatten(),
62
- nn.Linear(flattened_dim, fc_dim),
63
- nn.ReLU(),
64
- nn.Dropout(dropout),
65
- nn.Linear(fc_dim, 10),
66
- )
67
-
68
- def forward(self, x):
69
- x = self.features(x)
70
- x = self.classifier(x)
71
- return x
72
-
73
-
74
- # ============================================================
75
- # Dataset helpers
76
- # ============================================================
77
- def get_datasets(dataset_name: str):
78
- transform = transforms.Compose(
79
- [
80
- transforms.ToTensor(),
81
- transforms.Normalize((0.5,), (0.5,))
82
- ]
83
- )
84
-
85
- if dataset_name == "MNIST":
86
- train_dataset = datasets.MNIST(DATA_DIR, train=True, download=True, transform=transform)
87
- test_dataset = datasets.MNIST(DATA_DIR, train=False, download=True, transform=transform)
88
- elif dataset_name == "FashionMNIST":
89
- train_dataset = datasets.FashionMNIST(DATA_DIR, train=True, download=True, transform=transform)
90
- test_dataset = datasets.FashionMNIST(DATA_DIR, train=False, download=True, transform=transform)
91
- else:
92
- raise ValueError(f"Unsupported dataset: {dataset_name}")
93
-
94
- return train_dataset, test_dataset
95
-
96
-
97
- def make_loaders(dataset_name: str, batch_size: int, val_ratio: float = 0.1):
98
- train_dataset, test_dataset = get_datasets(dataset_name)
99
-
100
- val_size = int(len(train_dataset) * val_ratio)
101
- train_size = len(train_dataset) - val_size
102
-
103
- train_subset, val_subset = random_split(train_dataset, [train_size, val_size])
104
-
105
- train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
106
- val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
107
- test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
108
-
109
- return train_loader, val_loader, test_loader
110
-
111
-
112
- # ============================================================
113
- # Model save/load helpers
114
- # ============================================================
115
- def model_weight_path(model_name: str) -> str:
116
- return os.path.join(MODEL_DIR, f"{model_name}.pt")
117
-
118
-
119
- def model_meta_path(model_name: str) -> str:
120
- return os.path.join(META_DIR, f"{model_name}.json")
121
-
122
-
123
- def list_saved_models() -> List[str]:
124
- names = []
125
- for fn in os.listdir(META_DIR):
126
- if fn.endswith(".json"):
127
- names.append(fn[:-5])
128
- names.sort(reverse=True)
129
- return names
130
-
131
 
132
- def save_model(model: nn.Module, model_name: str, config: dict, training_summary: dict):
133
- cpu_state_dict = {k: v.detach().cpu() for k, v in model.state_dict().items()}
134
- torch.save(cpu_state_dict, model_weight_path(model_name))
135
-
136
- payload = {
137
- "model_name": model_name,
138
- "config": config,
139
- "training_summary": training_summary,
140
- "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
141
- }
142
- with open(model_meta_path(model_name), "w", encoding="utf-8") as f:
143
- json.dump(payload, f, indent=2, ensure_ascii=False)
144
-
145
-
146
- def load_model(model_name: str, device: torch.device) -> Tuple[nn.Module, dict]:
147
- meta_file = model_meta_path(model_name)
148
- weight_file = model_weight_path(model_name)
149
-
150
- if not os.path.exists(meta_file):
151
- raise FileNotFoundError(f"Metadata not found for model: {model_name}")
152
- if not os.path.exists(weight_file):
153
- raise FileNotFoundError(f"Weights not found for model: {model_name}")
154
-
155
- with open(meta_file, "r", encoding="utf-8") as f:
156
- meta = json.load(f)
157
-
158
- cfg = meta["config"]
159
-
160
- model = SimpleCNN(
161
- conv1_channels=cfg["conv1_channels"],
162
- conv2_channels=cfg["conv2_channels"],
163
- kernel_size=cfg["kernel_size"],
164
- dropout=cfg["dropout"],
165
- fc_dim=cfg["fc_dim"],
166
- )
167
-
168
- state_dict = torch.load(weight_file, map_location="cpu")
169
- model.load_state_dict(state_dict)
170
- model.to(device)
171
- model.eval()
172
- return model, meta
173
-
174
-
175
- # ============================================================
176
- # ZeroGPU helpers
177
- # ============================================================
178
- def get_runtime_device() -> torch.device:
179
- return torch.device("cuda" if torch.cuda.is_available() else "cpu")
180
 
181
 
182
  @spaces.GPU(duration=120)
183
- def _train_on_gpu(
184
- dataset_name: str,
185
- conv1_channels: int,
186
- conv2_channels: int,
187
- kernel_size: int,
188
- dropout: float,
189
- fc_dim: int,
190
- learning_rate: float,
191
- batch_size: int,
192
- epochs: int,
193
- model_tag: str,
194
- ):
195
- device = get_runtime_device()
196
-
197
- train_loader, val_loader, test_loader = make_loaders(dataset_name, batch_size)
198
-
199
- model = SimpleCNN(
200
- conv1_channels=conv1_channels,
201
- conv2_channels=conv2_channels,
202
- kernel_size=kernel_size,
203
- dropout=dropout,
204
- fc_dim=fc_dim,
205
- ).to(device)
206
-
207
- criterion = nn.CrossEntropyLoss()
208
- optimizer = optim.Adam(model.parameters(), lr=learning_rate)
209
-
210
- history = []
211
- logs = []
212
- start_time = time.time()
213
-
214
- def evaluate(loader):
215
- model.eval()
216
- total_loss = 0.0
217
- total = 0
218
- correct = 0
219
-
220
- with torch.no_grad():
221
- for images, labels in loader:
222
- images, labels = images.to(device), labels.to(device)
223
- outputs = model(images)
224
- loss = criterion(outputs, labels)
225
-
226
- total_loss += loss.item() * images.size(0)
227
- preds = outputs.argmax(dim=1)
228
- correct += (preds == labels).sum().item()
229
- total += labels.size(0)
230
-
231
- avg_loss = total_loss / total if total else 0.0
232
- acc = correct / total if total else 0.0
233
- return avg_loss, acc
234
-
235
- for epoch in range(1, epochs + 1):
236
- model.train()
237
- running_loss = 0.0
238
- total = 0
239
- correct = 0
240
-
241
- for images, labels in train_loader:
242
- images, labels = images.to(device), labels.to(device)
243
-
244
- optimizer.zero_grad()
245
- outputs = model(images)
246
- loss = criterion(outputs, labels)
247
- loss.backward()
248
- optimizer.step()
249
-
250
- running_loss += loss.item() * images.size(0)
251
- preds = outputs.argmax(dim=1)
252
- correct += (preds == labels).sum().item()
253
- total += labels.size(0)
254
-
255
- train_loss = running_loss / total if total else 0.0
256
- train_acc = correct / total if total else 0.0
257
- val_loss, val_acc = evaluate(val_loader)
258
-
259
- row = {
260
- "epoch": epoch,
261
- "train_loss": round(train_loss, 4),
262
- "train_acc": round(train_acc, 4),
263
- "val_loss": round(val_loss, 4),
264
- "val_acc": round(val_acc, 4),
265
- }
266
- history.append(row)
267
-
268
- logs.append(
269
- f"Epoch {epoch}/{epochs} | "
270
- f"train_loss={train_loss:.4f}, train_acc={train_acc:.4f}, "
271
- f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}"
272
- )
273
-
274
- test_loss, test_acc = evaluate(test_loader)
275
- elapsed = time.time() - start_time
276
-
277
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
278
- safe_tag = model_tag.strip().replace(" ", "_") if model_tag.strip() else dataset_name.lower()
279
- model_name = f"{safe_tag}_{timestamp}"
280
-
281
- config = {
282
- "dataset_name": dataset_name,
283
- "conv1_channels": conv1_channels,
284
- "conv2_channels": conv2_channels,
285
- "kernel_size": kernel_size,
286
- "dropout": dropout,
287
- "fc_dim": fc_dim,
288
- "learning_rate": learning_rate,
289
- "batch_size": batch_size,
290
- "epochs": epochs,
291
- }
292
-
293
- training_summary = {
294
- "final_train_loss": history[-1]["train_loss"] if history else None,
295
- "final_train_acc": history[-1]["train_acc"] if history else None,
296
- "final_val_loss": history[-1]["val_loss"] if history else None,
297
- "final_val_acc": history[-1]["val_acc"] if history else None,
298
- "test_loss": round(test_loss, 4),
299
- "test_acc": round(test_acc, 4),
300
- "elapsed_seconds": round(elapsed, 2),
301
- "device": str(device),
302
- }
303
-
304
- save_model(model, model_name, config, training_summary)
305
-
306
- logs.append("")
307
- logs.append("Training finished.")
308
- logs.append(f"Saved model: {model_name}")
309
- logs.append(f"Device: {device}")
310
- logs.append(f"Test loss: {test_loss:.4f}")
311
- logs.append(f"Test accuracy: {test_acc:.4f}")
312
- logs.append(f"Elapsed time: {elapsed:.1f}s")
313
-
314
- return "\n".join(logs), history, training_summary, model_name
315
-
316
-
317
- @spaces.GPU(duration=60)
318
- def _predict_uploaded_image_gpu(model_name: str, image: Image.Image):
319
- if not model_name:
320
- return "Please select a model.", None
321
-
322
- if image is None:
323
- return "Please upload an image.", None
324
-
325
- device = get_runtime_device()
326
- model, meta = load_model(model_name, device)
327
-
328
- transform = transforms.Compose(
329
- [
330
- transforms.Grayscale(num_output_channels=1),
331
- transforms.Resize((28, 28)),
332
- transforms.ToTensor(),
333
- transforms.Normalize((0.5,), (0.5,))
334
- ]
335
- )
336
-
337
- tensor = transform(image).unsqueeze(0).to(device)
338
-
339
- with torch.no_grad():
340
- logits = model(tensor)
341
- probs = torch.softmax(logits, dim=1).squeeze(0).detach().cpu().tolist()
342
- pred_idx = int(torch.argmax(logits, dim=1).item())
343
-
344
- result_text = (
345
- f"Prediction: {CLASS_NAMES[pred_idx]}\n"
346
- f"Confidence: {max(probs):.4f}\n\n"
347
- f"Model: {model_name}\n"
348
- f"Dataset: {meta['config']['dataset_name']}\n"
349
- f"Runtime device: {device}"
350
- )
351
- prob_dict = {CLASS_NAMES[i]: float(probs[i]) for i in range(10)}
352
- return result_text, prob_dict
353
-
354
-
355
- @spaces.GPU(duration=60)
356
- def _test_random_sample_gpu(model_name: str):
357
- if not model_name:
358
- return None, "Please select a model.", None
359
-
360
- device = get_runtime_device()
361
- model, meta = load_model(model_name, device)
362
- dataset_name = meta["config"]["dataset_name"]
363
-
364
- _, test_dataset = get_datasets(dataset_name)
365
- idx = random.randint(0, len(test_dataset) - 1)
366
- image_tensor, label = test_dataset[idx]
367
-
368
- with torch.no_grad():
369
- logits = model(image_tensor.unsqueeze(0).to(device))
370
- probs = torch.softmax(logits, dim=1).squeeze(0).detach().cpu().tolist()
371
- pred_idx = int(torch.argmax(logits, dim=1).item())
372
-
373
- display_img = image_tensor.squeeze(0).cpu().numpy()
374
-
375
- result_text = (
376
- f"Random test sample\n"
377
- f"Ground truth: {label}\n"
378
- f"Prediction: {pred_idx}\n"
379
- f"Confidence: {max(probs):.4f}\n"
380
- f"Model dataset: {dataset_name}\n"
381
- f"Runtime device: {device}"
382
- )
383
- prob_dict = {CLASS_NAMES[i]: float(probs[i]) for i in range(10)}
384
- return display_img, result_text, prob_dict
385
-
386
-
387
- # ============================================================
388
- # UI callbacks
389
- # ============================================================
390
  def train_callback(
391
- dataset_name,
392
  conv1_channels,
393
  conv2_channels,
394
  kernel_size,
@@ -400,8 +20,7 @@ def train_callback(
400
  model_tag,
401
  ):
402
  try:
403
- logs, history, summary, model_name = _train_on_gpu(
404
- dataset_name,
405
  int(conv1_channels),
406
  int(conv2_channels),
407
  int(kernel_size),
@@ -412,112 +31,136 @@ def train_callback(
412
  int(epochs),
413
  model_tag,
414
  )
 
415
  models = list_saved_models()
416
  selected = model_name if model_name in models else (models[0] if models else None)
 
417
  return logs, history, summary, gr.update(choices=models, value=selected)
 
418
  except Exception as e:
419
- return f"Training failed:\n{str(e)}", None, None, gr.update()
420
 
421
 
 
422
  def predict_uploaded_image_callback(model_name, image):
423
  try:
424
- return _predict_uploaded_image_gpu(model_name, image)
425
  except Exception as e:
426
- return f"Prediction failed:\n{str(e)}", None
427
 
428
 
 
429
  def test_random_sample_callback(model_name):
430
  try:
431
- return _test_random_sample_gpu(model_name)
432
  except Exception as e:
433
- return None, f"Random test failed:\n{str(e)}", None
 
 
 
 
 
434
 
435
 
436
  def get_model_info(model_name: str):
437
  if not model_name:
438
- return {"message": "No model selected."}
439
 
440
  meta_file = model_meta_path(model_name)
441
- if not os.path.exists(meta_file):
442
- return {"message": "Metadata not found."}
443
-
444
- with open(meta_file, "r", encoding="utf-8") as f:
445
- meta = json.load(f)
446
- return meta
447
 
448
-
449
- def refresh_models_dropdown():
450
- models = list_saved_models()
451
- return gr.update(choices=models, value=models[0] if models else None)
 
452
 
453
 
454
- # ============================================================
455
- # UI
456
- # ============================================================
457
  initial_models = list_saved_models()
458
 
459
- with gr.Blocks(title="Image Classification") as demo:
460
- gr.Markdown("# Image Classification")
 
461
  gr.Markdown(
462
- "Train a simple CNN on MNIST or FashionMNIST, then test saved models "
463
- "with an uploaded image or a random sample."
 
464
  )
465
 
466
  with gr.Tabs():
467
- with gr.Tab("Train"):
468
  with gr.Row():
469
  with gr.Column():
470
- dataset_name = gr.Dropdown(
471
- choices=["MNIST", "FashionMNIST"],
472
- value="MNIST",
473
- label="Dataset",
 
 
 
474
  )
475
- conv1_channels = gr.Slider(8, 64, value=16, step=8, label="Conv1 Channels")
476
- conv2_channels = gr.Slider(16, 128, value=32, step=16, label="Conv2 Channels")
477
- kernel_size = gr.Dropdown(choices=[3, 5], value=3, label="Kernel Size")
478
- dropout = gr.Slider(0.0, 0.7, value=0.2, step=0.05, label="Dropout")
479
- fc_dim = gr.Slider(32, 256, value=128, step=32, label="FC Hidden Dimension")
480
- learning_rate = gr.Number(value=0.001, label="Learning Rate")
481
- batch_size = gr.Dropdown(choices=[32, 64, 128, 256], value=64, label="Batch Size")
482
- epochs = gr.Slider(1, 10, value=3, step=1, label="Epochs")
483
- model_tag = gr.Textbox(label="Model Tag", placeholder="e.g. mnist_demo")
484
- train_btn = gr.Button("Start Training", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
485
 
486
  with gr.Column():
487
- train_status = gr.Textbox(label="Training Log", lines=18)
488
- train_history = gr.JSON(label="Training History")
489
- train_summary = gr.JSON(label="Training Summary")
490
 
491
- with gr.Tab("Test"):
492
  with gr.Row():
493
  with gr.Column():
 
 
494
  model_selector = gr.Dropdown(
495
  choices=initial_models,
496
  value=initial_models[0] if initial_models else None,
497
- label="Select Saved Model",
498
  )
499
- refresh_btn = gr.Button("Refresh Model List")
500
- load_info_btn = gr.Button("Show Model Info")
501
- model_info = gr.JSON(label="Model Metadata")
502
 
503
  with gr.Column():
504
- upload_image = gr.Image(type="pil", label="Upload Image")
505
- predict_btn = gr.Button("Predict Uploaded Image", variant="primary")
506
- predict_text = gr.Textbox(label="Prediction Result", lines=7)
507
- predict_probs = gr.Label(label="Class Probabilities")
 
 
508
 
509
  with gr.Row():
510
- random_test_btn = gr.Button("Test Random Sample")
511
 
512
  with gr.Row():
513
- random_sample_image = gr.Image(type="numpy", label="Random Test Image")
514
- random_sample_text = gr.Textbox(label="Random Sample Result", lines=7)
515
- random_sample_probs = gr.Label(label="Random Sample Probabilities")
516
 
517
  train_btn.click(
518
  fn=train_callback,
519
  inputs=[
520
- dataset_name,
521
  conv1_channels,
522
  conv2_channels,
523
  kernel_size,
 
 
1
  import json
 
 
 
 
2
 
3
  import spaces
4
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
+ from train_utils import train_model, list_saved_models, model_meta_path
7
+ from predict_utils import predict_uploaded_image, test_random_sample
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
 
10
  @spaces.GPU(duration=120)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def train_callback(
 
12
  conv1_channels,
13
  conv2_channels,
14
  kernel_size,
 
20
  model_tag,
21
  ):
22
  try:
23
+ logs, history, summary, model_name = train_model(
 
24
  int(conv1_channels),
25
  int(conv2_channels),
26
  int(kernel_size),
 
31
  int(epochs),
32
  model_tag,
33
  )
34
+
35
  models = list_saved_models()
36
  selected = model_name if model_name in models else (models[0] if models else None)
37
+
38
  return logs, history, summary, gr.update(choices=models, value=selected)
39
+
40
  except Exception as e:
41
+ return f"Échec de l’entraînement :\n{str(e)}", None, None, gr.update()
42
 
43
 
44
+ @spaces.GPU(duration=60)
45
  def predict_uploaded_image_callback(model_name, image):
46
  try:
47
+ return predict_uploaded_image(model_name, image)
48
  except Exception as e:
49
+ return f"Échec de la prédiction :\n{str(e)}", None
50
 
51
 
52
+ @spaces.GPU(duration=60)
53
  def test_random_sample_callback(model_name):
54
  try:
55
+ return test_random_sample(model_name)
56
  except Exception as e:
57
+ return None, f"Échec du test aléatoire :\n{str(e)}", None
58
+
59
+
60
+ def refresh_models_dropdown():
61
+ models = list_saved_models()
62
+ return gr.update(choices=models, value=models[0] if models else None)
63
 
64
 
65
  def get_model_info(model_name: str):
66
  if not model_name:
67
+ return {"message": "Aucun modèle sélectionné."}
68
 
69
  meta_file = model_meta_path(model_name)
 
 
 
 
 
 
70
 
71
+ try:
72
+ with open(meta_file, "r", encoding="utf-8") as f:
73
+ return json.load(f)
74
+ except FileNotFoundError:
75
+ return {"message": "Métadonnées introuvables."}
76
 
77
 
 
 
 
78
  initial_models = list_saved_models()
79
 
80
+
81
+ with gr.Blocks(title="Classification d’images microscopiques") as demo:
82
+ gr.Markdown("# Classification d’images microscopiques de charbons de bois")
83
  gr.Markdown(
84
+ "Cette application permet d’entraîner un réseau de neurones convolutif simple "
85
+ "sur un jeu de données privé Hugging Face, puis de tester les modèles sauvegardés "
86
+ "sur une image importée ou sur un échantillon aléatoire."
87
  )
88
 
89
  with gr.Tabs():
90
+ with gr.Tab("Entraîner"):
91
  with gr.Row():
92
  with gr.Column():
93
+ gr.Markdown("### Paramètres d’entraînement")
94
+
95
+ conv1_channels = gr.Slider(
96
+ 8, 64, value=16, step=8, label="Nombre de canaux - couche convolutionnelle 1"
97
+ )
98
+ conv2_channels = gr.Slider(
99
+ 16, 128, value=32, step=16, label="Nombre de canaux - couche convolutionnelle 2"
100
  )
101
+ kernel_size = gr.Dropdown(
102
+ choices=[3, 5], value=3, label="Taille du noyau"
103
+ )
104
+ dropout = gr.Slider(
105
+ 0.0, 0.7, value=0.2, step=0.05, label="Dropout"
106
+ )
107
+ fc_dim = gr.Slider(
108
+ 32, 256, value=128, step=32, label="Dimension de la couche cachée fully-connected"
109
+ )
110
+ learning_rate = gr.Number(
111
+ value=0.001, label="Taux d’apprentissage"
112
+ )
113
+ batch_size = gr.Dropdown(
114
+ choices=[16, 32, 64, 128], value=32, label="Taille du batch"
115
+ )
116
+ epochs = gr.Slider(
117
+ 1, 20, value=5, step=1, label="Nombre d’époques"
118
+ )
119
+ model_tag = gr.Textbox(
120
+ label="Nom court du modèle",
121
+ placeholder="ex. charbon_cnn_test"
122
+ )
123
+
124
+ train_btn = gr.Button("Lancer l’entraînement", variant="primary")
125
 
126
  with gr.Column():
127
+ train_status = gr.Textbox(label="Journal d’entraînement", lines=18)
128
+ train_history = gr.JSON(label="Historique d’entraînement")
129
+ train_summary = gr.JSON(label="Résumé d’entraînement")
130
 
131
+ with gr.Tab("Tester"):
132
  with gr.Row():
133
  with gr.Column():
134
+ gr.Markdown("### Modèle sauvegardé")
135
+
136
  model_selector = gr.Dropdown(
137
  choices=initial_models,
138
  value=initial_models[0] if initial_models else None,
139
+ label="Sélectionner un modèle",
140
  )
141
+ refresh_btn = gr.Button("Actualiser la liste des modèles")
142
+ load_info_btn = gr.Button("Afficher les informations du modèle")
143
+ model_info = gr.JSON(label="Métadonnées du modèle")
144
 
145
  with gr.Column():
146
+ gr.Markdown("### Prédiction sur une image importée")
147
+
148
+ upload_image = gr.Image(type="pil", label="Importer une image")
149
+ predict_btn = gr.Button("Prédire la classe", variant="primary")
150
+ predict_text = gr.Textbox(label="Résultat de la prédiction", lines=7)
151
+ predict_probs = gr.Label(label="Probabilités par classe")
152
 
153
  with gr.Row():
154
+ random_test_btn = gr.Button("Tester un échantillon aléatoire")
155
 
156
  with gr.Row():
157
+ random_sample_image = gr.Image(type="pil", label="Image test aléatoire")
158
+ random_sample_text = gr.Textbox(label="Résultat sur l’échantillon", lines=7)
159
+ random_sample_probs = gr.Label(label="Probabilités par classe")
160
 
161
  train_btn.click(
162
  fn=train_callback,
163
  inputs=[
 
164
  conv1_channels,
165
  conv2_channels,
166
  kernel_size,