Test_soudure / app.py
ESMIEU Nathan OBS/OBF
update app.py
4362883
import gradio as gr
import os
import shutil
import zipfile
from PIL import Image
import torch
import evaluate
import numpy as np
from datasets import load_dataset
from transformers import (
AutoImageProcessor,
AutoModelForImageClassification,
TrainingArguments,
Trainer,
)
# =====================================================================
# CONFIG GLOBALE
# =====================================================================
DEFAULT_MODEL_NAME = "facebook/convnextv2-tiny-1k-224"
TRAIN_OUTPUT_DIR = "trained_model"
DATASET_EXTRACT_DIR = "dataset_extracted"
MODEL_UPLOAD_DIR = "uploaded_model"
# =====================================================================
# OUTILS : EXTRACTION ZIP
# =====================================================================
def extract_zip(zip_path, dest_dir):
"""Extrait un ZIP et remplace le dossier existant."""
if os.path.isdir(dest_dir):
shutil.rmtree(dest_dir)
os.makedirs(dest_dir, exist_ok=True)
try:
with zipfile.ZipFile(zip_path, 'r') as zf:
zf.extractall(dest_dir)
return True, None
except Exception as e:
return False, f"Erreur extraction ZIP : {e}"
def find_true_dataset_root(root):
"""
Trouve automatiquement le vrai dossier contenant les classes :
- Bonne/
- Mauvaise/
Même si le ZIP contient une couche inutile :
dataset.zip
images de soudures/
bonne/
mauvaise/
Cette fonction retourne le dossier qui contient réellement les classes.
"""
subdirs = [d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))]
# Cas idéal : les classes sont directement présentes
if any(name.lower() in ["bonne", "mauvaise"] for name in subdirs):
return root
# Sinon, entrer dans le premier sous-dossier
if len(subdirs) == 1:
sub = os.path.join(root, subdirs[0])
deeper = [d for d in os.listdir(sub) if os.path.isdir(os.path.join(sub, d))]
if any(name.lower() in ["bonne", "mauvaise"] for name in deeper):
return sub
return root # fallback
# =====================================================================
# PAGE 1 : ENTRAÎNEMENT
# =====================================================================
def train_model(zip_dataset_path, model_name, epochs, batch_size, lr):
if zip_dataset_path is None:
return "Erreur : aucun dataset ZIP fourni."
# 1) Extraire le ZIP
success, err = extract_zip(zip_dataset_path, DATASET_EXTRACT_DIR)
if not success:
return err
# 2) Trouver le vrai dossier racine du dataset
true_root = find_true_dataset_root(DATASET_EXTRACT_DIR)
# 3) Charger dataset HF
try:
dataset = load_dataset("imagefolder", data_dir=true_root)
except Exception as e:
return f"Erreur lors du chargement du dataset imagefolder : {e}"
# Afficher colonnes détectées
column_info = f"Colonnes détectées : {dataset['train'].column_names}\n"
feature_info = f"Features : {dataset['train'].features}\n"
debug_log = column_info + feature_info
# Vérifier que la colonne label existe
if "label" not in dataset["train"].column_names:
return debug_log + "\nErreur : aucune colonne label détectée."
label_names = dataset["train"].features["label"].names
num_labels = len(label_names)
# 4) Préprocesseur
processor = AutoImageProcessor.from_pretrained(model_name)
# Détecter la colonne image réellement présente
def detect_image_key(keys):
if "image" in keys:
return "image"
if "file" in keys:
return "file"
if "path" in keys:
return "path"
# fallback: première colonne non-label
for k in keys:
if k != "label":
return k
raise KeyError(f"Aucune colonne image trouvée dans {keys}")
image_key = detect_image_key(dataset["train"].column_names)
# 5) Transformation robuste
def transform(batch):
raw_imgs = batch[image_key]
pil_images = []
for elem in raw_imgs:
if isinstance(elem, Image.Image):
pil_images.append(elem.convert("RGB"))
else:
pil_images.append(Image.open(elem).convert("RGB"))
inputs = processor(pil_images, return_tensors="pt")
inputs["labels"] = batch["label"]
return inputs
dataset = dataset.with_transform(transform)
# 6) Charger modèle pré-entraîné
model = AutoModelForImageClassification.from_pretrained(
model_name,
num_labels=num_labels,
id2label={i: n for i, n in enumerate(label_names)},
label2id={n: i for i, n in enumerate(label_names)},
ignore_mismatched_sizes=True, # indispensable pour adapter 1000 → 2 classes
)
# 7) Métrique
metric = evaluate.load("accuracy")
def compute_metrics(eval_pred):
logits, labels = eval_pred
preds = np.argmax(logits, axis=-1)
return metric.compute(predictions=preds, references=labels)
# 8) TrainingArguments
args = TrainingArguments(
output_dir=TRAIN_OUTPUT_DIR,
num_train_epochs=int(epochs),
per_device_train_batch_size=int(batch_size),
per_device_eval_batch_size=int(batch_size),
learning_rate=float(lr),
evaluation_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
metric_for_best_model="accuracy",
report_to=[],
)
# 9) Trainer
trainer = Trainer(
model=model,
args=args,
train_dataset=dataset["train"],
eval_dataset=dataset.get("validation", dataset["train"]),
compute_metrics=compute_metrics,
)
trainer.train()
# 10) Sauvegarde finale
model.save_pretrained(TRAIN_OUTPUT_DIR)
processor.save_pretrained(TRAIN_OUTPUT_DIR)
return debug_log + f"\nEntraînement terminé. Modèle sauvegardé dans : {TRAIN_OUTPUT_DIR}"
# =====================================================================
# PAGE 2 : INFÉRENCE
# =====================================================================
def extract_model(zip_model_path):
if os.path.isdir(MODEL_UPLOAD_DIR):
shutil.rmtree(MODEL_UPLOAD_DIR)
os.makedirs(MODEL_UPLOAD_DIR, exist_ok=True)
try:
with zipfile.ZipFile(zip_model_path, 'r') as zf:
zf.extractall(MODEL_UPLOAD_DIR)
return True, None
except Exception as e:
return False, f"Erreur extraction modèle : {e}"
def predict(model_zip_path, image):
if model_zip_path is None:
return "Erreur : aucun modèle ZIP fourni."
success, err = extract_model(model_zip_path)
if not success:
return err
try:
model = AutoModelForImageClassification.from_pretrained(MODEL_UPLOAD_DIR)
processor = AutoImageProcessor.from_pretrained(MODEL_UPLOAD_DIR)
except Exception as e:
return f"Erreur lors du chargement du modèle : {e}"
if not isinstance(image, Image.Image):
image = Image.fromarray(image)
inputs = processor(images=image.convert("RGB"), return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
probs = torch.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
id2label = model.config.id2label
return {id2label[i]: float(probs[i]) for i in range(len(probs))}
# =====================================================================
# INTERFACE GRADIO
# =====================================================================
with gr.Blocks() as demo:
gr.Markdown("# Classification de Soudures — Entraînement & Inférence")
# -------------------------------------------------------------
# ONGLET 1 : ENTRAÎNEMENT
# -------------------------------------------------------------
with gr.Tab("1 • Entraîner un modèle"):
dataset_zip = gr.File(label="Dataset ZIP (Bonne/ et Mauvaise/)", type="filepath")
model_name = gr.Textbox(label="Modèle de départ", value=DEFAULT_MODEL_NAME)
epochs = gr.Slider(label="Époques", minimum=1, maximum=50, value=5)
batch = gr.Slider(label="Batch size", minimum=2, maximum=64, value=8)
lr = gr.Number(label="Learning rate", value=5e-5)
train_btn = gr.Button("Lancer l'entraînement")
train_log = gr.Textbox(label="Logs", lines=10)
train_btn.click(
train_model,
inputs=[dataset_zip, model_name, epochs, batch, lr],
outputs=train_log
)
# -------------------------------------------------------------
# ONGLET 2 : INFÉRENCE
# -------------------------------------------------------------
with gr.Tab("2 • Tester un modèle"):
model_zip = gr.File(label="Modèle ZIP", type="filepath")
input_image = gr.Image(label="Image de soudure")
predict_btn = gr.Button("Prédire")
result = gr.Label(label="Résultat")
predict_btn.click(
predict,
inputs=[model_zip, input_image],
outputs=result
)
if __name__ == "__main__":
demo.launch()