from datasets import load_dataset from transformers import AutoImageProcessor, AutoModelForImageClassification, TrainingArguments, Trainer from PIL import Image import numpy as np import gradio as gr import torch import os import shutil import zipfile import evaluate # pour les métriques DEFAULT_MODEL_NAME = "facebook/convnextv2-tiny-1k-224" # ou "google/vit-base-patch16-224" DEFAULT_OUTPUT_DIR = "./weld_cls_model_best" EXTRACT_DIR = "./uploaded_dataset" # dossier où l'on extrait l'archive def extract_archive(archive_path, extract_to=EXTRACT_DIR): """ Extrait une archive .zip ou .rar dans extract_to. La structure attendue après extraction est de type imagefolder : extract_to/ bonne/ img1.jpg ... mauvaise/ img2.jpg ... """ if archive_path is None or not os.path.isfile(archive_path): return None, f"Erreur : aucune archive de dataset fournie." # Nettoyer l'ancien dossier, s'il existe if os.path.isdir(extract_to): shutil.rmtree(extract_to) os.makedirs(extract_to, exist_ok=True) archive_lower = archive_path.lower() try: if archive_lower.endswith(".zip"): with zipfile.ZipFile(archive_path, "r") as zf: zf.extractall(extract_to) elif archive_lower.endswith(".rar"): try: import rarfile except ImportError: return None, ( "Erreur : format .rar demandé mais le module 'rarfile' n'est pas installé.\n" "Ajoute 'rarfile' dans requirements.txt, ou utilise une archive .zip." ) with rarfile.RarFile(archive_path) as rf: rf.extractall(extract_to) else: return None, "Erreur : format d'archive non supporté. Utilise .zip ou .rar." except Exception as e: return None, f"Erreur lors de l'extraction de l'archive : {e}" return extract_to, None def train_model(dataset_archive_path, model_name=DEFAULT_MODEL_NAME, num_epochs=10, batch_size=16, lr=5e-5): """ Lance l'entraînement à partir d'une archive uploadée (zip/rar) contenant un dataset de type imagefolder. """ # 0) Extraction de l'archive data_dir, err = extract_archive(dataset_archive_path) if err is not None: return err if data_dir is None or not os.path.isdir(data_dir): return f"Erreur : le dossier de données '{data_dir}' est introuvable après extraction." # 1) Charger le dataset try: dataset = load_dataset("imagefolder", data_dir=data_dir) except Exception as e: return f"Erreur lors du chargement du dataset avec 'imagefolder' : {e}" label_names = dataset["train"].features["label"].names num_labels = len(label_names) # 2) Choisir le modèle HF processor = AutoImageProcessor.from_pretrained(model_name) def transform(example_batch): images = [ x.convert("RGB") if isinstance(x, Image.Image) else x for x in example_batch["image"] ] inputs = processor(images, return_tensors="pt") inputs["labels"] = example_batch["label"] return inputs prepared_ds = dataset.with_transform(transform) # 3) Charger le modèle de classification model = AutoModelForImageClassification.from_pretrained( model_name, num_labels=num_labels, id2label={i: l for i, l in enumerate(label_names)}, label2id={l: i for i, l in enumerate(label_names)}, ) # 4) Définir les métriques metric = evaluate.load("accuracy") def compute_metrics(eval_pred): logits, labels = eval_pred preds = np.argmax(logits, axis=-1) return metric.compute(predictions=preds, references=labels) # 5) Définir le Trainer training_args = TrainingArguments( output_dir="./weld_cls_model", per_device_train_batch_size=int(batch_size), per_device_eval_batch_size=int(batch_size), learning_rate=float(lr), num_train_epochs=int(num_epochs), evaluation_strategy="epoch", save_strategy="epoch", load_best_model_at_end=True, metric_for_best_model="accuracy", logging_steps=50, report_to=[], # désactive Weights & Biases dans un Space si non configuré ) trainer = Trainer( model=model, args=training_args, train_dataset=prepared_ds["train"], eval_dataset=prepared_ds.get("validation", prepared_ds["train"]), compute_metrics=compute_metrics, ) # 6) Entraînement trainer.train() # 7) Sauvegarde trainer.save_model(DEFAULT_OUTPUT_DIR) processor.save_pretrained(DEFAULT_OUTPUT_DIR) return f"Entraînement terminé. Modèle sauvegardé dans : {DEFAULT_OUTPUT_DIR}" def predict(image): """ Inférence : prend une image et renvoie les probabilités par classe. Utilise le modèle sauvegardé dans DEFAULT_OUTPUT_DIR. """ if not os.path.isdir(DEFAULT_OUTPUT_DIR): return "Erreur : aucun modèle entraîné trouvé. Lance d'abord l'entraînement." # Charger modèle + processor model = AutoModelForImageClassification.from_pretrained(DEFAULT_OUTPUT_DIR) processor = AutoImageProcessor.from_pretrained(DEFAULT_OUTPUT_DIR) model.eval() if not isinstance(image, Image.Image): image = Image.fromarray(image) inputs = processor(images=image.convert("RGB"), return_tensors="pt") with torch.no_grad(): outputs = model(**inputs) probs = outputs.logits.softmax(dim=-1)[0].cpu().numpy() id2label = model.config.id2label result = {id2label[i]: float(probs[i]) for i in range(len(probs))} return result # ----------------------- # Interface Gradio # ----------------------- with gr.Blocks() as demo: gr.Markdown("# Classification de soudures – Entraînement + Inférence\n" "Upload d'un dataset (.zip ou .rar), entraînement du modèle, puis test sur des images.") with gr.Tab("Entraînement"): gr.Markdown("## Lancer l'entraînement") dataset_file_input = gr.File( label="Archive du dataset (.zip ou .rar)", type="filepath" ) model_name_input = gr.Textbox( label="Nom du modèle Hugging Face", value=DEFAULT_MODEL_NAME ) epochs_input = gr.Slider( label="Nombre d'époques", minimum=1, maximum=50, value=10, step=1 ) batch_input = gr.Slider( label="Batch size", minimum=4, maximum=64, value=16, step=4 ) lr_input = gr.Number( label="Learning rate", value=5e-5, precision=7 ) train_button = gr.Button("Lancer l'entraînement") train_output = gr.Textbox(label="Journal / Résultat de l'entraînement") train_button.click( fn=train_model, inputs=[dataset_file_input, model_name_input, epochs_input, batch_input, lr_input], outputs=train_output ) with gr.Tab("Inférence"): gr.Markdown("## Tester le modèle entraîné sur une image de soudure") image_input = gr.Image(label="Image de soudure", type="pil") pred_button = gr.Button("Prédire") pred_output = gr.Label(label="Probabilités par classe") pred_button.click( fn=predict, inputs=image_input, outputs=pred_output ) if __name__ == "__main__": demo.launch()