import gradio as gr import os import shutil import zipfile from PIL import Image import torch import evaluate import numpy as np from datasets import load_dataset from transformers import ( AutoImageProcessor, AutoModelForImageClassification, TrainingArguments, Trainer, ) # ===================================================================== # CONFIG GLOBALE # ===================================================================== DEFAULT_MODEL_NAME = "facebook/convnextv2-tiny-1k-224" TRAIN_OUTPUT_DIR = "trained_model" DATASET_EXTRACT_DIR = "dataset_extracted" MODEL_UPLOAD_DIR = "uploaded_model" # ===================================================================== # OUTILS : EXTRACTION ZIP # ===================================================================== def extract_zip(zip_path, dest_dir): """Extrait un ZIP et remplace le dossier existant.""" if os.path.isdir(dest_dir): shutil.rmtree(dest_dir) os.makedirs(dest_dir, exist_ok=True) try: with zipfile.ZipFile(zip_path, 'r') as zf: zf.extractall(dest_dir) return True, None except Exception as e: return False, f"Erreur extraction ZIP : {e}" def find_true_dataset_root(root): """ Trouve automatiquement le vrai dossier contenant les classes : - Bonne/ - Mauvaise/ Même si le ZIP contient une couche inutile : dataset.zip images de soudures/ bonne/ mauvaise/ Cette fonction retourne le dossier qui contient réellement les classes. """ subdirs = [d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))] # Cas idéal : les classes sont directement présentes if any(name.lower() in ["bonne", "mauvaise"] for name in subdirs): return root # Sinon, entrer dans le premier sous-dossier if len(subdirs) == 1: sub = os.path.join(root, subdirs[0]) deeper = [d for d in os.listdir(sub) if os.path.isdir(os.path.join(sub, d))] if any(name.lower() in ["bonne", "mauvaise"] for name in deeper): return sub return root # fallback # ===================================================================== # PAGE 1 : ENTRAÎNEMENT # ===================================================================== def train_model(zip_dataset_path, model_name, epochs, batch_size, lr): if zip_dataset_path is None: return "Erreur : aucun dataset ZIP fourni." # 1) Extraire le ZIP success, err = extract_zip(zip_dataset_path, DATASET_EXTRACT_DIR) if not success: return err # 2) Trouver le vrai dossier racine du dataset true_root = find_true_dataset_root(DATASET_EXTRACT_DIR) # 3) Charger dataset HF try: dataset = load_dataset("imagefolder", data_dir=true_root) except Exception as e: return f"Erreur lors du chargement du dataset imagefolder : {e}" # Afficher colonnes détectées column_info = f"Colonnes détectées : {dataset['train'].column_names}\n" feature_info = f"Features : {dataset['train'].features}\n" debug_log = column_info + feature_info # Vérifier que la colonne label existe if "label" not in dataset["train"].column_names: return debug_log + "\nErreur : aucune colonne label détectée." label_names = dataset["train"].features["label"].names num_labels = len(label_names) # 4) Préprocesseur processor = AutoImageProcessor.from_pretrained(model_name) # Détecter la colonne image réellement présente def detect_image_key(keys): if "image" in keys: return "image" if "file" in keys: return "file" if "path" in keys: return "path" # fallback: première colonne non-label for k in keys: if k != "label": return k raise KeyError(f"Aucune colonne image trouvée dans {keys}") image_key = detect_image_key(dataset["train"].column_names) # 5) Transformation robuste def transform(batch): raw_imgs = batch[image_key] pil_images = [] for elem in raw_imgs: if isinstance(elem, Image.Image): pil_images.append(elem.convert("RGB")) else: pil_images.append(Image.open(elem).convert("RGB")) inputs = processor(pil_images, return_tensors="pt") inputs["labels"] = batch["label"] return inputs dataset = dataset.with_transform(transform) # 6) Charger modèle pré-entraîné model = AutoModelForImageClassification.from_pretrained( model_name, num_labels=num_labels, id2label={i: n for i, n in enumerate(label_names)}, label2id={n: i for i, n in enumerate(label_names)}, ignore_mismatched_sizes=True, # indispensable pour adapter 1000 → 2 classes ) # 7) Métrique metric = evaluate.load("accuracy") def compute_metrics(eval_pred): logits, labels = eval_pred preds = np.argmax(logits, axis=-1) return metric.compute(predictions=preds, references=labels) # 8) TrainingArguments args = TrainingArguments( output_dir=TRAIN_OUTPUT_DIR, num_train_epochs=int(epochs), per_device_train_batch_size=int(batch_size), per_device_eval_batch_size=int(batch_size), learning_rate=float(lr), evaluation_strategy="epoch", save_strategy="epoch", load_best_model_at_end=True, metric_for_best_model="accuracy", report_to=[], ) # 9) Trainer trainer = Trainer( model=model, args=args, train_dataset=dataset["train"], eval_dataset=dataset.get("validation", dataset["train"]), compute_metrics=compute_metrics, ) trainer.train() # 10) Sauvegarde finale model.save_pretrained(TRAIN_OUTPUT_DIR) processor.save_pretrained(TRAIN_OUTPUT_DIR) return debug_log + f"\nEntraînement terminé. Modèle sauvegardé dans : {TRAIN_OUTPUT_DIR}" # ===================================================================== # PAGE 2 : INFÉRENCE # ===================================================================== def extract_model(zip_model_path): if os.path.isdir(MODEL_UPLOAD_DIR): shutil.rmtree(MODEL_UPLOAD_DIR) os.makedirs(MODEL_UPLOAD_DIR, exist_ok=True) try: with zipfile.ZipFile(zip_model_path, 'r') as zf: zf.extractall(MODEL_UPLOAD_DIR) return True, None except Exception as e: return False, f"Erreur extraction modèle : {e}" def predict(model_zip_path, image): if model_zip_path is None: return "Erreur : aucun modèle ZIP fourni." success, err = extract_model(model_zip_path) if not success: return err try: model = AutoModelForImageClassification.from_pretrained(MODEL_UPLOAD_DIR) processor = AutoImageProcessor.from_pretrained(MODEL_UPLOAD_DIR) except Exception as e: return f"Erreur lors du chargement du modèle : {e}" if not isinstance(image, Image.Image): image = Image.fromarray(image) inputs = processor(images=image.convert("RGB"), return_tensors="pt") with torch.no_grad(): outputs = model(**inputs) probs = torch.softmax(outputs.logits, dim=-1)[0].cpu().numpy() id2label = model.config.id2label return {id2label[i]: float(probs[i]) for i in range(len(probs))} # ===================================================================== # INTERFACE GRADIO # ===================================================================== with gr.Blocks() as demo: gr.Markdown("# Classification de Soudures — Entraînement & Inférence") # ------------------------------------------------------------- # ONGLET 1 : ENTRAÎNEMENT # ------------------------------------------------------------- with gr.Tab("1 • Entraîner un modèle"): dataset_zip = gr.File(label="Dataset ZIP (Bonne/ et Mauvaise/)", type="filepath") model_name = gr.Textbox(label="Modèle de départ", value=DEFAULT_MODEL_NAME) epochs = gr.Slider(label="Époques", minimum=1, maximum=50, value=5) batch = gr.Slider(label="Batch size", minimum=2, maximum=64, value=8) lr = gr.Number(label="Learning rate", value=5e-5) train_btn = gr.Button("Lancer l'entraînement") train_log = gr.Textbox(label="Logs", lines=10) train_btn.click( train_model, inputs=[dataset_zip, model_name, epochs, batch, lr], outputs=train_log ) # ------------------------------------------------------------- # ONGLET 2 : INFÉRENCE # ------------------------------------------------------------- with gr.Tab("2 • Tester un modèle"): model_zip = gr.File(label="Modèle ZIP", type="filepath") input_image = gr.Image(label="Image de soudure") predict_btn = gr.Button("Prédire") result = gr.Label(label="Résultat") predict_btn.click( predict, inputs=[model_zip, input_image], outputs=result ) if __name__ == "__main__": demo.launch()