ESMIEU Nathan OBS/OBF commited on
Commit
4362883
·
1 Parent(s): af0d173

update app.py

Browse files
Files changed (2) hide show
  1. app.py +95 -44
  2. dataset.zip +3 -0
app.py CHANGED
@@ -17,7 +17,7 @@ from transformers import (
17
 
18
 
19
  # =====================================================================
20
- # CONFIGURATION GLOBALE
21
  # =====================================================================
22
 
23
  DEFAULT_MODEL_NAME = "facebook/convnextv2-tiny-1k-224"
@@ -31,11 +31,9 @@ MODEL_UPLOAD_DIR = "uploaded_model"
31
  # =====================================================================
32
 
33
  def extract_zip(zip_path, dest_dir):
34
- """Extrait un fichier ZIP dans dest_dir (remplace si existe)."""
35
-
36
  if os.path.isdir(dest_dir):
37
  shutil.rmtree(dest_dir)
38
-
39
  os.makedirs(dest_dir, exist_ok=True)
40
 
41
  try:
@@ -43,59 +41,120 @@ def extract_zip(zip_path, dest_dir):
43
  zf.extractall(dest_dir)
44
  return True, None
45
  except Exception as e:
46
- return False, f"Erreur lors de l'extraction : {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
 
49
  # =====================================================================
50
- # PAGE 1 : ENTRAÎNEMENT DU MODÈLE
51
  # =====================================================================
52
 
53
  def train_model(zip_dataset_path, model_name, epochs, batch_size, lr):
54
- # 1) Vérifier le dataset
55
  if zip_dataset_path is None:
56
- return "Erreur : aucun ZIP de dataset fourni."
57
 
 
58
  success, err = extract_zip(zip_dataset_path, DATASET_EXTRACT_DIR)
59
  if not success:
60
  return err
61
 
62
- # Structure attendue :
63
- # DATASET_EXTRACT_DIR/
64
- # Bonne/
65
- # Mauvaise/
66
 
 
67
  try:
68
- dataset = load_dataset("imagefolder", data_dir=DATASET_EXTRACT_DIR)
69
  except Exception as e:
70
- return f"Erreur : impossible de charger le dataset au format imagefolder : {e}"
71
 
 
 
 
 
 
 
72
  if "label" not in dataset["train"].column_names:
73
- return f"Erreur : le dataset ne contient pas de colonne 'label'. Structure attendue : ZIP → Bonne/ et Mauvaise/."
74
 
75
  label_names = dataset["train"].features["label"].names
76
  num_labels = len(label_names)
77
 
78
- # 2) Préprocessing
79
  processor = AutoImageProcessor.from_pretrained(model_name)
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  def transform(batch):
82
- images = [img.convert("RGB") for img in batch["image"]]
83
- inputs = processor(images, return_tensors="pt")
 
 
 
 
 
 
 
 
84
  inputs["labels"] = batch["label"]
85
  return inputs
86
 
87
  dataset = dataset.with_transform(transform)
88
 
89
- # 3) Charger modèle
90
  model = AutoModelForImageClassification.from_pretrained(
91
  model_name,
92
  num_labels=num_labels,
93
  id2label={i: n for i, n in enumerate(label_names)},
94
  label2id={n: i for i, n in enumerate(label_names)},
95
- ignore_mismatched_sizes=True, # IMPORTANT POUR ADAPTER 1000 → 2 CLASSES
96
  )
97
 
98
- # 4) Métriques
99
  metric = evaluate.load("accuracy")
100
 
101
  def compute_metrics(eval_pred):
@@ -103,7 +162,7 @@ def train_model(zip_dataset_path, model_name, epochs, batch_size, lr):
103
  preds = np.argmax(logits, axis=-1)
104
  return metric.compute(predictions=preds, references=labels)
105
 
106
- # 5) TrainingArguments
107
  args = TrainingArguments(
108
  output_dir=TRAIN_OUTPUT_DIR,
109
  num_train_epochs=int(epochs),
@@ -117,7 +176,7 @@ def train_model(zip_dataset_path, model_name, epochs, batch_size, lr):
117
  report_to=[],
118
  )
119
 
120
- # 6) Trainer
121
  trainer = Trainer(
122
  model=model,
123
  args=args,
@@ -128,11 +187,11 @@ def train_model(zip_dataset_path, model_name, epochs, batch_size, lr):
128
 
129
  trainer.train()
130
 
131
- # 7) Sauvegarde finale
132
  model.save_pretrained(TRAIN_OUTPUT_DIR)
133
  processor.save_pretrained(TRAIN_OUTPUT_DIR)
134
 
135
- return f"Entraînement terminé. Modèle sauvegardé dans : {TRAIN_OUTPUT_DIR}"
136
 
137
 
138
  # =====================================================================
@@ -140,10 +199,8 @@ def train_model(zip_dataset_path, model_name, epochs, batch_size, lr):
140
  # =====================================================================
141
 
142
  def extract_model(zip_model_path):
143
- """Dézippe un modèle téléchargé côté utilisateur."""
144
  if os.path.isdir(MODEL_UPLOAD_DIR):
145
  shutil.rmtree(MODEL_UPLOAD_DIR)
146
-
147
  os.makedirs(MODEL_UPLOAD_DIR, exist_ok=True)
148
 
149
  try:
@@ -156,18 +213,17 @@ def extract_model(zip_model_path):
156
 
157
  def predict(model_zip_path, image):
158
  if model_zip_path is None:
159
- return "Erreur : aucun modèle importé."
160
 
161
  success, err = extract_model(model_zip_path)
162
  if not success:
163
  return err
164
 
165
- # Charger modèle
166
  try:
167
  model = AutoModelForImageClassification.from_pretrained(MODEL_UPLOAD_DIR)
168
  processor = AutoImageProcessor.from_pretrained(MODEL_UPLOAD_DIR)
169
  except Exception as e:
170
- return f"Erreur de chargement du modèle : {e}"
171
 
172
  if not isinstance(image, Image.Image):
173
  image = Image.fromarray(image)
@@ -190,20 +246,18 @@ with gr.Blocks() as demo:
190
 
191
  gr.Markdown("# Classification de Soudures — Entraînement & Inférence")
192
 
193
- # -----------------------------------------------------------------
194
  # ONGLET 1 : ENTRAÎNEMENT
195
- # -----------------------------------------------------------------
196
  with gr.Tab("1 • Entraîner un modèle"):
197
- gr.Markdown("### Importer un dataset ZIP au format : Bonne/ et Mauvaise/")
198
-
199
- dataset_zip = gr.File(label="Dataset .zip", type="filepath")
200
  model_name = gr.Textbox(label="Modèle de départ", value=DEFAULT_MODEL_NAME)
201
- epochs = gr.Slider(label="Nombre d'époques", minimum=1, maximum=50, value=5)
202
  batch = gr.Slider(label="Batch size", minimum=2, maximum=64, value=8)
203
  lr = gr.Number(label="Learning rate", value=5e-5)
204
 
205
  train_btn = gr.Button("Lancer l'entraînement")
206
- train_log = gr.Textbox(label="Logs", lines=8)
207
 
208
  train_btn.click(
209
  train_model,
@@ -211,16 +265,14 @@ with gr.Blocks() as demo:
211
  outputs=train_log
212
  )
213
 
214
- # -----------------------------------------------------------------
215
  # ONGLET 2 : INFÉRENCE
216
- # -----------------------------------------------------------------
217
  with gr.Tab("2 • Tester un modèle"):
218
- gr.Markdown("### Importer un modèle ZIP (export HF) + une image de soudure")
219
-
220
- model_zip = gr.File(label="Modèle .zip", type="filepath")
221
  input_image = gr.Image(label="Image de soudure")
222
  predict_btn = gr.Button("Prédire")
223
- result = gr.Label(label="Probabilités")
224
 
225
  predict_btn.click(
226
  predict,
@@ -229,6 +281,5 @@ with gr.Blocks() as demo:
229
  )
230
 
231
 
232
- # Lancer l’app dans un Space HF
233
  if __name__ == "__main__":
234
  demo.launch()
 
17
 
18
 
19
  # =====================================================================
20
+ # CONFIG GLOBALE
21
  # =====================================================================
22
 
23
  DEFAULT_MODEL_NAME = "facebook/convnextv2-tiny-1k-224"
 
31
  # =====================================================================
32
 
33
  def extract_zip(zip_path, dest_dir):
34
+ """Extrait un ZIP et remplace le dossier existant."""
 
35
  if os.path.isdir(dest_dir):
36
  shutil.rmtree(dest_dir)
 
37
  os.makedirs(dest_dir, exist_ok=True)
38
 
39
  try:
 
41
  zf.extractall(dest_dir)
42
  return True, None
43
  except Exception as e:
44
+ return False, f"Erreur extraction ZIP : {e}"
45
+
46
+
47
+ def find_true_dataset_root(root):
48
+ """
49
+ Trouve automatiquement le vrai dossier contenant les classes :
50
+ - Bonne/
51
+ - Mauvaise/
52
+
53
+ Même si le ZIP contient une couche inutile :
54
+ dataset.zip
55
+ images de soudures/
56
+ bonne/
57
+ mauvaise/
58
+
59
+ Cette fonction retourne le dossier qui contient réellement les classes.
60
+ """
61
+ subdirs = [d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))]
62
+ # Cas idéal : les classes sont directement présentes
63
+ if any(name.lower() in ["bonne", "mauvaise"] for name in subdirs):
64
+ return root
65
+
66
+ # Sinon, entrer dans le premier sous-dossier
67
+ if len(subdirs) == 1:
68
+ sub = os.path.join(root, subdirs[0])
69
+ deeper = [d for d in os.listdir(sub) if os.path.isdir(os.path.join(sub, d))]
70
+
71
+ if any(name.lower() in ["bonne", "mauvaise"] for name in deeper):
72
+ return sub
73
+
74
+ return root # fallback
75
 
76
 
77
  # =====================================================================
78
+ # PAGE 1 : ENTRAÎNEMENT
79
  # =====================================================================
80
 
81
  def train_model(zip_dataset_path, model_name, epochs, batch_size, lr):
82
+
83
  if zip_dataset_path is None:
84
+ return "Erreur : aucun dataset ZIP fourni."
85
 
86
+ # 1) Extraire le ZIP
87
  success, err = extract_zip(zip_dataset_path, DATASET_EXTRACT_DIR)
88
  if not success:
89
  return err
90
 
91
+ # 2) Trouver le vrai dossier racine du dataset
92
+ true_root = find_true_dataset_root(DATASET_EXTRACT_DIR)
 
 
93
 
94
+ # 3) Charger dataset HF
95
  try:
96
+ dataset = load_dataset("imagefolder", data_dir=true_root)
97
  except Exception as e:
98
+ return f"Erreur lors du chargement du dataset imagefolder : {e}"
99
 
100
+ # Afficher colonnes détectées
101
+ column_info = f"Colonnes détectées : {dataset['train'].column_names}\n"
102
+ feature_info = f"Features : {dataset['train'].features}\n"
103
+ debug_log = column_info + feature_info
104
+
105
+ # Vérifier que la colonne label existe
106
  if "label" not in dataset["train"].column_names:
107
+ return debug_log + "\nErreur : aucune colonne label détectée."
108
 
109
  label_names = dataset["train"].features["label"].names
110
  num_labels = len(label_names)
111
 
112
+ # 4) Préprocesseur
113
  processor = AutoImageProcessor.from_pretrained(model_name)
114
 
115
+ # Détecter la colonne image réellement présente
116
+ def detect_image_key(keys):
117
+ if "image" in keys:
118
+ return "image"
119
+ if "file" in keys:
120
+ return "file"
121
+ if "path" in keys:
122
+ return "path"
123
+ # fallback: première colonne non-label
124
+ for k in keys:
125
+ if k != "label":
126
+ return k
127
+ raise KeyError(f"Aucune colonne image trouvée dans {keys}")
128
+
129
+ image_key = detect_image_key(dataset["train"].column_names)
130
+
131
+ # 5) Transformation robuste
132
  def transform(batch):
133
+ raw_imgs = batch[image_key]
134
+ pil_images = []
135
+
136
+ for elem in raw_imgs:
137
+ if isinstance(elem, Image.Image):
138
+ pil_images.append(elem.convert("RGB"))
139
+ else:
140
+ pil_images.append(Image.open(elem).convert("RGB"))
141
+
142
+ inputs = processor(pil_images, return_tensors="pt")
143
  inputs["labels"] = batch["label"]
144
  return inputs
145
 
146
  dataset = dataset.with_transform(transform)
147
 
148
+ # 6) Charger modèle pré-entraîné
149
  model = AutoModelForImageClassification.from_pretrained(
150
  model_name,
151
  num_labels=num_labels,
152
  id2label={i: n for i, n in enumerate(label_names)},
153
  label2id={n: i for i, n in enumerate(label_names)},
154
+ ignore_mismatched_sizes=True, # indispensable pour adapter 1000 → 2 classes
155
  )
156
 
157
+ # 7) Métrique
158
  metric = evaluate.load("accuracy")
159
 
160
  def compute_metrics(eval_pred):
 
162
  preds = np.argmax(logits, axis=-1)
163
  return metric.compute(predictions=preds, references=labels)
164
 
165
+ # 8) TrainingArguments
166
  args = TrainingArguments(
167
  output_dir=TRAIN_OUTPUT_DIR,
168
  num_train_epochs=int(epochs),
 
176
  report_to=[],
177
  )
178
 
179
+ # 9) Trainer
180
  trainer = Trainer(
181
  model=model,
182
  args=args,
 
187
 
188
  trainer.train()
189
 
190
+ # 10) Sauvegarde finale
191
  model.save_pretrained(TRAIN_OUTPUT_DIR)
192
  processor.save_pretrained(TRAIN_OUTPUT_DIR)
193
 
194
+ return debug_log + f"\nEntraînement terminé. Modèle sauvegardé dans : {TRAIN_OUTPUT_DIR}"
195
 
196
 
197
  # =====================================================================
 
199
  # =====================================================================
200
 
201
  def extract_model(zip_model_path):
 
202
  if os.path.isdir(MODEL_UPLOAD_DIR):
203
  shutil.rmtree(MODEL_UPLOAD_DIR)
 
204
  os.makedirs(MODEL_UPLOAD_DIR, exist_ok=True)
205
 
206
  try:
 
213
 
214
  def predict(model_zip_path, image):
215
  if model_zip_path is None:
216
+ return "Erreur : aucun modèle ZIP fourni."
217
 
218
  success, err = extract_model(model_zip_path)
219
  if not success:
220
  return err
221
 
 
222
  try:
223
  model = AutoModelForImageClassification.from_pretrained(MODEL_UPLOAD_DIR)
224
  processor = AutoImageProcessor.from_pretrained(MODEL_UPLOAD_DIR)
225
  except Exception as e:
226
+ return f"Erreur lors du chargement du modèle : {e}"
227
 
228
  if not isinstance(image, Image.Image):
229
  image = Image.fromarray(image)
 
246
 
247
  gr.Markdown("# Classification de Soudures — Entraînement & Inférence")
248
 
249
+ # -------------------------------------------------------------
250
  # ONGLET 1 : ENTRAÎNEMENT
251
+ # -------------------------------------------------------------
252
  with gr.Tab("1 • Entraîner un modèle"):
253
+ dataset_zip = gr.File(label="Dataset ZIP (Bonne/ et Mauvaise/)", type="filepath")
 
 
254
  model_name = gr.Textbox(label="Modèle de départ", value=DEFAULT_MODEL_NAME)
255
+ epochs = gr.Slider(label="Époques", minimum=1, maximum=50, value=5)
256
  batch = gr.Slider(label="Batch size", minimum=2, maximum=64, value=8)
257
  lr = gr.Number(label="Learning rate", value=5e-5)
258
 
259
  train_btn = gr.Button("Lancer l'entraînement")
260
+ train_log = gr.Textbox(label="Logs", lines=10)
261
 
262
  train_btn.click(
263
  train_model,
 
265
  outputs=train_log
266
  )
267
 
268
+ # -------------------------------------------------------------
269
  # ONGLET 2 : INFÉRENCE
270
+ # -------------------------------------------------------------
271
  with gr.Tab("2 • Tester un modèle"):
272
+ model_zip = gr.File(label="Modèle ZIP", type="filepath")
 
 
273
  input_image = gr.Image(label="Image de soudure")
274
  predict_btn = gr.Button("Prédire")
275
+ result = gr.Label(label="Résultat")
276
 
277
  predict_btn.click(
278
  predict,
 
281
  )
282
 
283
 
 
284
  if __name__ == "__main__":
285
  demo.launch()
dataset.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7367892765884ad898bb791e05d4900fd9c76f38dcf56bc69542287853b9414c
3
+ size 488258521