Spaces:
Sleeping
Sleeping
| from datasets import load_dataset | |
| from transformers import AutoImageProcessor, AutoModelForImageClassification, TrainingArguments, Trainer | |
| from PIL import Image | |
| import numpy as np | |
| import gradio as gr | |
| import torch | |
| import os | |
| import shutil | |
| import zipfile | |
| import evaluate # pour les métriques | |
| DEFAULT_MODEL_NAME = "facebook/convnextv2-tiny-1k-224" # ou "google/vit-base-patch16-224" | |
| DEFAULT_OUTPUT_DIR = "./weld_cls_model_best" | |
| EXTRACT_DIR = "./uploaded_dataset" # dossier où l'on extrait l'archive | |
| def extract_archive(archive_path, extract_to=EXTRACT_DIR): | |
| """ | |
| Extrait une archive .zip ou .rar dans extract_to. | |
| La structure attendue après extraction est de type imagefolder : | |
| extract_to/ | |
| bonne/ | |
| img1.jpg | |
| ... | |
| mauvaise/ | |
| img2.jpg | |
| ... | |
| """ | |
| if archive_path is None or not os.path.isfile(archive_path): | |
| return None, f"Erreur : aucune archive de dataset fournie." | |
| # Nettoyer l'ancien dossier, s'il existe | |
| if os.path.isdir(extract_to): | |
| shutil.rmtree(extract_to) | |
| os.makedirs(extract_to, exist_ok=True) | |
| archive_lower = archive_path.lower() | |
| try: | |
| if archive_lower.endswith(".zip"): | |
| with zipfile.ZipFile(archive_path, "r") as zf: | |
| zf.extractall(extract_to) | |
| elif archive_lower.endswith(".rar"): | |
| try: | |
| import rarfile | |
| except ImportError: | |
| return None, ( | |
| "Erreur : format .rar demandé mais le module 'rarfile' n'est pas installé.\n" | |
| "Ajoute 'rarfile' dans requirements.txt, ou utilise une archive .zip." | |
| ) | |
| with rarfile.RarFile(archive_path) as rf: | |
| rf.extractall(extract_to) | |
| else: | |
| return None, "Erreur : format d'archive non supporté. Utilise .zip ou .rar." | |
| except Exception as e: | |
| return None, f"Erreur lors de l'extraction de l'archive : {e}" | |
| return extract_to, None | |
| def train_model(dataset_archive_path, model_name=DEFAULT_MODEL_NAME, num_epochs=10, batch_size=16, lr=5e-5): | |
| """ | |
| Lance l'entraînement à partir d'une archive uploadée (zip/rar) contenant | |
| un dataset de type imagefolder. | |
| """ | |
| # 0) Extraction de l'archive | |
| data_dir, err = extract_archive(dataset_archive_path) | |
| if err is not None: | |
| return err | |
| if data_dir is None or not os.path.isdir(data_dir): | |
| return f"Erreur : le dossier de données '{data_dir}' est introuvable après extraction." | |
| # 1) Charger le dataset | |
| try: | |
| dataset = load_dataset("imagefolder", data_dir=data_dir) | |
| except Exception as e: | |
| return f"Erreur lors du chargement du dataset avec 'imagefolder' : {e}" | |
| label_names = dataset["train"].features["label"].names | |
| num_labels = len(label_names) | |
| # 2) Choisir le modèle HF | |
| processor = AutoImageProcessor.from_pretrained(model_name) | |
| def transform(example_batch): | |
| images = [ | |
| x.convert("RGB") if isinstance(x, Image.Image) else x | |
| for x in example_batch["image"] | |
| ] | |
| inputs = processor(images, return_tensors="pt") | |
| inputs["labels"] = example_batch["label"] | |
| return inputs | |
| prepared_ds = dataset.with_transform(transform) | |
| # 3) Charger le modèle de classification | |
| model = AutoModelForImageClassification.from_pretrained( | |
| model_name, | |
| num_labels=num_labels, | |
| id2label={i: l for i, l in enumerate(label_names)}, | |
| label2id={l: i for i, l in enumerate(label_names)}, | |
| ) | |
| # 4) Définir les métriques | |
| metric = evaluate.load("accuracy") | |
| def compute_metrics(eval_pred): | |
| logits, labels = eval_pred | |
| preds = np.argmax(logits, axis=-1) | |
| return metric.compute(predictions=preds, references=labels) | |
| # 5) Définir le Trainer | |
| training_args = TrainingArguments( | |
| output_dir="./weld_cls_model", | |
| per_device_train_batch_size=int(batch_size), | |
| per_device_eval_batch_size=int(batch_size), | |
| learning_rate=float(lr), | |
| num_train_epochs=int(num_epochs), | |
| evaluation_strategy="epoch", | |
| save_strategy="epoch", | |
| load_best_model_at_end=True, | |
| metric_for_best_model="accuracy", | |
| logging_steps=50, | |
| report_to=[], # désactive Weights & Biases dans un Space si non configuré | |
| ) | |
| trainer = Trainer( | |
| model=model, | |
| args=training_args, | |
| train_dataset=prepared_ds["train"], | |
| eval_dataset=prepared_ds.get("validation", prepared_ds["train"]), | |
| compute_metrics=compute_metrics, | |
| ) | |
| # 6) Entraînement | |
| trainer.train() | |
| # 7) Sauvegarde | |
| trainer.save_model(DEFAULT_OUTPUT_DIR) | |
| processor.save_pretrained(DEFAULT_OUTPUT_DIR) | |
| return f"Entraînement terminé. Modèle sauvegardé dans : {DEFAULT_OUTPUT_DIR}" | |
| def predict(image): | |
| """ | |
| Inférence : prend une image et renvoie les probabilités par classe. | |
| Utilise le modèle sauvegardé dans DEFAULT_OUTPUT_DIR. | |
| """ | |
| if not os.path.isdir(DEFAULT_OUTPUT_DIR): | |
| return "Erreur : aucun modèle entraîné trouvé. Lance d'abord l'entraînement." | |
| # Charger modèle + processor | |
| model = AutoModelForImageClassification.from_pretrained(DEFAULT_OUTPUT_DIR) | |
| processor = AutoImageProcessor.from_pretrained(DEFAULT_OUTPUT_DIR) | |
| model.eval() | |
| if not isinstance(image, Image.Image): | |
| image = Image.fromarray(image) | |
| inputs = processor(images=image.convert("RGB"), return_tensors="pt") | |
| with torch.no_grad(): | |
| outputs = model(**inputs) | |
| probs = outputs.logits.softmax(dim=-1)[0].cpu().numpy() | |
| id2label = model.config.id2label | |
| result = {id2label[i]: float(probs[i]) for i in range(len(probs))} | |
| return result | |
| # ----------------------- | |
| # Interface Gradio | |
| # ----------------------- | |
| with gr.Blocks() as demo: | |
| gr.Markdown("# Classification de soudures – Entraînement + Inférence\n" | |
| "Upload d'un dataset (.zip ou .rar), entraînement du modèle, puis test sur des images.") | |
| with gr.Tab("Entraînement"): | |
| gr.Markdown("## Lancer l'entraînement") | |
| dataset_file_input = gr.File( | |
| label="Archive du dataset (.zip ou .rar)", | |
| type="filepath" | |
| ) | |
| model_name_input = gr.Textbox( | |
| label="Nom du modèle Hugging Face", | |
| value=DEFAULT_MODEL_NAME | |
| ) | |
| epochs_input = gr.Slider( | |
| label="Nombre d'époques", | |
| minimum=1, | |
| maximum=50, | |
| value=10, | |
| step=1 | |
| ) | |
| batch_input = gr.Slider( | |
| label="Batch size", | |
| minimum=4, | |
| maximum=64, | |
| value=16, | |
| step=4 | |
| ) | |
| lr_input = gr.Number( | |
| label="Learning rate", | |
| value=5e-5, | |
| precision=7 | |
| ) | |
| train_button = gr.Button("Lancer l'entraînement") | |
| train_output = gr.Textbox(label="Journal / Résultat de l'entraînement") | |
| train_button.click( | |
| fn=train_model, | |
| inputs=[dataset_file_input, model_name_input, epochs_input, batch_input, lr_input], | |
| outputs=train_output | |
| ) | |
| with gr.Tab("Inférence"): | |
| gr.Markdown("## Tester le modèle entraîné sur une image de soudure") | |
| image_input = gr.Image(label="Image de soudure", type="pil") | |
| pred_button = gr.Button("Prédire") | |
| pred_output = gr.Label(label="Probabilités par classe") | |
| pred_button.click( | |
| fn=predict, | |
| inputs=image_input, | |
| outputs=pred_output | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch() | |