# Hugging Face Space — welding-image classification: training & inference (Gradio app)
| import gradio as gr | |
| import os | |
| import shutil | |
| import zipfile | |
| from PIL import Image | |
| import torch | |
| import evaluate | |
| import numpy as np | |
| from datasets import load_dataset | |
| from transformers import ( | |
| AutoImageProcessor, | |
| AutoModelForImageClassification, | |
| TrainingArguments, | |
| Trainer, | |
| ) | |
# =====================================================================
# GLOBAL CONFIG
# =====================================================================
DEFAULT_MODEL_NAME = "facebook/convnextv2-tiny-1k-224"  # pretrained checkpoint used by default
TRAIN_OUTPUT_DIR = "trained_model"  # where the fine-tuned model + processor are saved
DATASET_EXTRACT_DIR = "dataset_extracted"  # scratch dir for the uploaded dataset ZIP
MODEL_UPLOAD_DIR = "uploaded_model"  # scratch dir for the uploaded model ZIP
# =====================================================================
# TOOLS: ZIP EXTRACTION
# =====================================================================
def extract_zip(zip_path, dest_dir):
    """Extract a ZIP archive into *dest_dir*, replacing any existing directory.

    Args:
        zip_path: Path to the ZIP file (a user upload from Gradio).
        dest_dir: Directory to extract into; wiped first if it exists.

    Returns:
        tuple[bool, str | None]: ``(True, None)`` on success, or
        ``(False, error message)`` on failure (errors are returned rather
        than raised so the caller can surface them in the UI).
    """
    if os.path.isdir(dest_dir):
        shutil.rmtree(dest_dir)
    os.makedirs(dest_dir, exist_ok=True)
    try:
        with zipfile.ZipFile(zip_path, 'r') as zf:
            # Security fix: the archive is untrusted user input. Reject any
            # member whose resolved path escapes dest_dir ("zip-slip",
            # e.g. "../../etc/passwd") before extracting anything.
            root = os.path.realpath(dest_dir)
            for member in zf.namelist():
                target = os.path.realpath(os.path.join(dest_dir, member))
                if target != root and not target.startswith(root + os.sep):
                    return False, f"Erreur extraction ZIP : entrée suspecte '{member}'"
            zf.extractall(dest_dir)
        return True, None
    except Exception as e:
        return False, f"Erreur extraction ZIP : {e}"
def find_true_dataset_root(root):
    """Locate the directory that actually contains the class folders.

    The uploaded ZIP is expected to hold the class directories
    (``Bonne/`` / ``Mauvaise``, any casing) either directly at the top
    level or one level down inside a single wrapper folder::

        dataset.zip
            images de soudures/
                bonne/
                mauvaise/

    Fix: junk directories added by archivers (``__MACOSX`` from macOS
    zips, hidden dot-folders) are now ignored — previously their presence
    next to the wrapper folder broke the single-subdirectory detection
    and the function fell back to the wrong root.

    Args:
        root: Directory the ZIP was extracted into.

    Returns:
        The directory containing the class folders, or *root* as a
        fallback when no class folders are found.
    """
    def _real_subdirs(path):
        # Skip macOS resource forks ("__MACOSX") and hidden folders.
        return [
            d for d in os.listdir(path)
            if os.path.isdir(os.path.join(path, d))
            and not d.startswith(("__", "."))
        ]

    def _has_classes(names):
        return any(n.lower() in ("bonne", "mauvaise") for n in names)

    subdirs = _real_subdirs(root)
    # Ideal case: class folders sit directly under root.
    if _has_classes(subdirs):
        return root
    # Otherwise descend into a single wrapper directory, if any.
    if len(subdirs) == 1:
        sub = os.path.join(root, subdirs[0])
        if _has_classes(_real_subdirs(sub)):
            return sub
    return root  # fallback: let the caller try loading from root anyway
# =====================================================================
# PAGE 1: TRAINING
# =====================================================================
def train_model(zip_dataset_path, model_name, epochs, batch_size, lr) -> str:
    """Fine-tune a pretrained image classifier on an uploaded ZIP dataset.

    Pipeline: extract the ZIP, locate the class-folder root, load it as a
    Hugging Face ``imagefolder`` dataset, preprocess the images, fine-tune
    the model with ``Trainer`` and save model + processor to
    ``TRAIN_OUTPUT_DIR``.

    Args:
        zip_dataset_path: Path to the uploaded dataset ZIP (Gradio
            filepath), or ``None`` when nothing was uploaded.
        model_name: Hub id of the pretrained checkpoint to start from.
        epochs: Number of training epochs (cast to ``int`` below).
        batch_size: Per-device batch size (cast to ``int`` below).
        lr: Learning rate (cast to ``float`` below).

    Returns:
        A human-readable log string; errors are returned instead of
        raised so they appear in the Gradio logs textbox.
    """
    if zip_dataset_path is None:
        return "Erreur : aucun dataset ZIP fourni."
    # 1) Extract the ZIP
    success, err = extract_zip(zip_dataset_path, DATASET_EXTRACT_DIR)
    if not success:
        return err
    # 2) Find the real dataset root (the ZIP may add a wrapper folder)
    true_root = find_true_dataset_root(DATASET_EXTRACT_DIR)
    # 3) Load as a HF "imagefolder" dataset (labels inferred from folders)
    try:
        dataset = load_dataset("imagefolder", data_dir=true_root)
    except Exception as e:
        return f"Erreur lors du chargement du dataset imagefolder : {e}"
    # Record detected columns/features in the log for debugging
    column_info = f"Colonnes détectées : {dataset['train'].column_names}\n"
    feature_info = f"Features : {dataset['train'].features}\n"
    debug_log = column_info + feature_info
    # Bail out if imagefolder failed to infer a label column
    if "label" not in dataset["train"].column_names:
        return debug_log + "\nErreur : aucune colonne label détectée."
    label_names = dataset["train"].features["label"].names
    num_labels = len(label_names)
    # 4) Image preprocessor matching the chosen checkpoint
    processor = AutoImageProcessor.from_pretrained(model_name)
    # Detect which column actually holds the image data
    def detect_image_key(keys):
        # Common column names used by imagefolder-style datasets.
        if "image" in keys:
            return "image"
        if "file" in keys:
            return "file"
        if "path" in keys:
            return "path"
        # Fallback: first non-label column.
        for k in keys:
            if k != "label":
                return k
        raise KeyError(f"Aucune colonne image trouvée dans {keys}")
    image_key = detect_image_key(dataset["train"].column_names)
    # 5) Robust on-the-fly transform: accepts PIL images or file paths
    def transform(batch):
        raw_imgs = batch[image_key]
        pil_images = []
        for elem in raw_imgs:
            if isinstance(elem, Image.Image):
                pil_images.append(elem.convert("RGB"))
            else:
                # Assume a path-like/file object otherwise.
                pil_images.append(Image.open(elem).convert("RGB"))
        inputs = processor(pil_images, return_tensors="pt")
        inputs["labels"] = batch["label"]
        return inputs
    dataset = dataset.with_transform(transform)
    # 6) Load the pretrained model with a fresh classification head
    model = AutoModelForImageClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
        id2label={i: n for i, n in enumerate(label_names)},
        label2id={n: i for i, n in enumerate(label_names)},
        ignore_mismatched_sizes=True,  # required to swap the 1000-class head for num_labels outputs
    )
    # 7) Accuracy metric
    metric = evaluate.load("accuracy")
    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        preds = np.argmax(logits, axis=-1)
        return metric.compute(predictions=preds, references=labels)
    # 8) TrainingArguments
    # NOTE(review): newer transformers versions renamed
    # `evaluation_strategy` to `eval_strategy` — confirm the pinned
    # transformers version still accepts the old keyword.
    args = TrainingArguments(
        output_dir=TRAIN_OUTPUT_DIR,
        num_train_epochs=int(epochs),
        per_device_train_batch_size=int(batch_size),
        per_device_eval_batch_size=int(batch_size),
        learning_rate=float(lr),
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        report_to=[],  # disable wandb/tensorboard reporting
    )
    # 9) Trainer — evaluates on the train split if no validation split exists
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=dataset["train"],
        eval_dataset=dataset.get("validation", dataset["train"]),
        compute_metrics=compute_metrics,
    )
    trainer.train()
    # 10) Final save (model + processor, so the folder is self-contained)
    model.save_pretrained(TRAIN_OUTPUT_DIR)
    processor.save_pretrained(TRAIN_OUTPUT_DIR)
    return debug_log + f"\nEntraînement terminé. Modèle sauvegardé dans : {TRAIN_OUTPUT_DIR}"
# =====================================================================
# PAGE 2: INFERENCE
# =====================================================================
def extract_model(zip_model_path):
    """Extract an uploaded model ZIP into ``MODEL_UPLOAD_DIR``.

    Fix: this was a byte-for-byte duplicate of :func:`extract_zip`; it now
    delegates to that shared helper, only rewriting the error prefix so
    callers still see the model-specific message.

    Returns:
        tuple[bool, str | None]: ``(True, None)`` on success, or
        ``(False, error message)`` on failure.
    """
    success, err = extract_zip(zip_model_path, MODEL_UPLOAD_DIR)
    if err is not None:
        # Preserve this function's original message wording.
        err = err.replace("Erreur extraction ZIP", "Erreur extraction modèle", 1)
    return success, err
def predict(model_zip_path, image):
    """Classify a welding image with an uploaded fine-tuned model.

    Args:
        model_zip_path: Path to a ZIP containing a saved model + processor
            (as produced by the training tab), or ``None``.
        image: PIL image or numpy array from the Gradio image widget, or
            ``None`` when no image was provided.

    Returns:
        dict mapping label name -> probability on success, or an error
        string on failure.
    """
    if model_zip_path is None:
        return "Erreur : aucun modèle ZIP fourni."
    # Fix: Gradio passes None when no image is uploaded; previously this
    # crashed with an AttributeError in Image.fromarray/convert.
    if image is None:
        return "Erreur : aucune image fournie."
    success, err = extract_model(model_zip_path)
    if not success:
        return err
    try:
        model = AutoModelForImageClassification.from_pretrained(MODEL_UPLOAD_DIR)
        processor = AutoImageProcessor.from_pretrained(MODEL_UPLOAD_DIR)
    except Exception as e:
        return f"Erreur lors du chargement du modèle : {e}"
    model.eval()  # disable dropout etc. for deterministic inference
    # The widget may hand us a numpy array; normalize to a PIL RGB image.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    inputs = processor(images=image.convert("RGB"), return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
    id2label = model.config.id2label
    return {id2label[i]: float(probs[i]) for i in range(len(probs))}
# =====================================================================
# GRADIO INTERFACE
# =====================================================================
with gr.Blocks() as demo:
    gr.Markdown("# Classification de Soudures — Entraînement & Inférence")
    # -------------------------------------------------------------
    # TAB 1: TRAINING
    # -------------------------------------------------------------
    with gr.Tab("1 • Entraîner un modèle"):
        # type="filepath" makes gr.File pass a path string to train_model
        dataset_zip = gr.File(label="Dataset ZIP (Bonne/ et Mauvaise/)", type="filepath")
        model_name = gr.Textbox(label="Modèle de départ", value=DEFAULT_MODEL_NAME)
        epochs = gr.Slider(label="Époques", minimum=1, maximum=50, value=5)
        batch = gr.Slider(label="Batch size", minimum=2, maximum=64, value=8)
        lr = gr.Number(label="Learning rate", value=5e-5)
        train_btn = gr.Button("Lancer l'entraînement")
        train_log = gr.Textbox(label="Logs", lines=10)
        train_btn.click(
            train_model,
            inputs=[dataset_zip, model_name, epochs, batch, lr],
            outputs=train_log
        )
    # -------------------------------------------------------------
    # TAB 2: INFERENCE
    # -------------------------------------------------------------
    with gr.Tab("2 • Tester un modèle"):
        model_zip = gr.File(label="Modèle ZIP", type="filepath")
        input_image = gr.Image(label="Image de soudure")
        predict_btn = gr.Button("Prédire")
        result = gr.Label(label="Résultat")  # gr.Label renders the label->prob dict
        predict_btn.click(
            predict,
            inputs=[model_zip, input_image],
            outputs=result
        )
if __name__ == "__main__":
    demo.launch()