ESMIEU Nathan OBS/OBF commited on
Commit
a869dd5
·
1 Parent(s): 464a4c3

Update app.py and training logic

Browse files
Files changed (1) hide show
  1. app.py +68 -20
app.py CHANGED
@@ -1,32 +1,81 @@
1
- from datasets import load_dataset, load_metric
2
  from transformers import AutoImageProcessor, AutoModelForImageClassification, TrainingArguments, Trainer
3
  from PIL import Image
4
  import numpy as np
5
  import gradio as gr
6
  import torch
7
  import os
 
 
 
8
 
9
  DEFAULT_MODEL_NAME = "facebook/convnextv2-tiny-1k-224" # ou "google/vit-base-patch16-224"
10
  DEFAULT_OUTPUT_DIR = "./weld_cls_model_best"
 
11
 
12
 
13
- def train_model(data_dir, model_name=DEFAULT_MODEL_NAME, num_epochs=10, batch_size=16, lr=5e-5):
14
  """
15
- Lance l'entraînement sur un dataset d'images de type imagefolder :
16
- data_dir/
17
- bonne/
18
- img1.jpg
19
- ...
20
- mauvaise/
21
- img2.jpg
22
- ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  """
24
 
25
- if not os.path.isdir(data_dir):
26
- return f"Erreur : le dossier {data_dir} n'existe pas."
 
 
 
 
27
 
28
  # 1) Charger le dataset
29
- dataset = load_dataset("imagefolder", data_dir=data_dir)
 
 
 
30
 
31
  label_names = dataset["train"].features["label"].names
32
  num_labels = len(label_names)
@@ -54,7 +103,7 @@ def train_model(data_dir, model_name=DEFAULT_MODEL_NAME, num_epochs=10, batch_si
54
  )
55
 
56
  # 4) Définir les métriques
57
- metric = load_metric("accuracy")
58
 
59
  def compute_metrics(eval_pred):
60
  logits, labels = eval_pred
@@ -128,15 +177,14 @@ def predict(image):
128
  # -----------------------
129
  with gr.Blocks() as demo:
130
  gr.Markdown("# Classification de soudures – Entraînement + Inférence\n"
131
- "Interface Gradio pour entraîner un modèle Hugging Face et tester des images.")
132
 
133
  with gr.Tab("Entraînement"):
134
  gr.Markdown("## Lancer l'entraînement")
135
 
136
- data_dir_input = gr.Textbox(
137
- label="Dossier du dataset (format imagefolder)",
138
- value="path/vers/tes_images",
139
- placeholder="Ex : /data/soudures"
140
  )
141
  model_name_input = gr.Textbox(
142
  label="Nom du modèle Hugging Face",
@@ -167,7 +215,7 @@ with gr.Blocks() as demo:
167
 
168
  train_button.click(
169
  fn=train_model,
170
- inputs=[data_dir_input, model_name_input, epochs_input, batch_input, lr_input],
171
  outputs=train_output
172
  )
173
 
 
1
+ from datasets import load_dataset
2
  from transformers import AutoImageProcessor, AutoModelForImageClassification, TrainingArguments, Trainer
3
  from PIL import Image
4
  import numpy as np
5
  import gradio as gr
6
  import torch
7
  import os
8
+ import shutil
9
+ import zipfile
10
+ import evaluate # pour les métriques
11
 
12
  DEFAULT_MODEL_NAME = "facebook/convnextv2-tiny-1k-224" # ou "google/vit-base-patch16-224"
13
  DEFAULT_OUTPUT_DIR = "./weld_cls_model_best"
14
+ EXTRACT_DIR = "./uploaded_dataset" # dossier où l'on extrait l'archive
15
 
16
 
17
+ def extract_archive(archive_path, extract_to=EXTRACT_DIR):
18
  """
19
+ Extrait une archive .zip ou .rar dans extract_to.
20
+ La structure attendue après extraction est de type imagefolder :
21
+ extract_to/
22
+ bonne/
23
+ img1.jpg
24
+ ...
25
+ mauvaise/
26
+ img2.jpg
27
+ ...
28
+ """
29
+ if archive_path is None or not os.path.isfile(archive_path):
30
+ return None, f"Erreur : aucune archive de dataset fournie."
31
+
32
+ # Nettoyer l'ancien dossier, s'il existe
33
+ if os.path.isdir(extract_to):
34
+ shutil.rmtree(extract_to)
35
+ os.makedirs(extract_to, exist_ok=True)
36
+
37
+ archive_lower = archive_path.lower()
38
+
39
+ try:
40
+ if archive_lower.endswith(".zip"):
41
+ with zipfile.ZipFile(archive_path, "r") as zf:
42
+ zf.extractall(extract_to)
43
+ elif archive_lower.endswith(".rar"):
44
+ try:
45
+ import rarfile
46
+ except ImportError:
47
+ return None, (
48
+ "Erreur : format .rar demandé mais le module 'rarfile' n'est pas installé.\n"
49
+ "Ajoute 'rarfile' dans requirements.txt, ou utilise une archive .zip."
50
+ )
51
+ with rarfile.RarFile(archive_path) as rf:
52
+ rf.extractall(extract_to)
53
+ else:
54
+ return None, "Erreur : format d'archive non supporté. Utilise .zip ou .rar."
55
+ except Exception as e:
56
+ return None, f"Erreur lors de l'extraction de l'archive : {e}"
57
+
58
+ return extract_to, None
59
+
60
+
61
+ def train_model(dataset_archive_path, model_name=DEFAULT_MODEL_NAME, num_epochs=10, batch_size=16, lr=5e-5):
62
+ """
63
+ Lance l'entraînement à partir d'une archive uploadée (zip/rar) contenant
64
+ un dataset de type imagefolder.
65
  """
66
 
67
+ # 0) Extraction de l'archive
68
+ data_dir, err = extract_archive(dataset_archive_path)
69
+ if err is not None:
70
+ return err
71
+ if data_dir is None or not os.path.isdir(data_dir):
72
+ return f"Erreur : le dossier de données '{data_dir}' est introuvable après extraction."
73
 
74
  # 1) Charger le dataset
75
+ try:
76
+ dataset = load_dataset("imagefolder", data_dir=data_dir)
77
+ except Exception as e:
78
+ return f"Erreur lors du chargement du dataset avec 'imagefolder' : {e}"
79
 
80
  label_names = dataset["train"].features["label"].names
81
  num_labels = len(label_names)
 
103
  )
104
 
105
  # 4) Définir les métriques
106
+ metric = evaluate.load("accuracy")
107
 
108
  def compute_metrics(eval_pred):
109
  logits, labels = eval_pred
 
177
  # -----------------------
178
  with gr.Blocks() as demo:
179
  gr.Markdown("# Classification de soudures – Entraînement + Inférence\n"
180
+ "Upload d'un dataset (.zip ou .rar), entraînement du modèle, puis test sur des images.")
181
 
182
  with gr.Tab("Entraînement"):
183
  gr.Markdown("## Lancer l'entraînement")
184
 
185
+ dataset_file_input = gr.File(
186
+ label="Archive du dataset (.zip ou .rar)",
187
+ type="filepath"
 
188
  )
189
  model_name_input = gr.Textbox(
190
  label="Nom du modèle Hugging Face",
 
215
 
216
  train_button.click(
217
  fn=train_model,
218
+ inputs=[dataset_file_input, model_name_input, epochs_input, batch_input, lr_input],
219
  outputs=train_output
220
  )
221