Test_soudure / app.py
ESMIEU Nathan OBS/OBF
Update app.py and training logic
a869dd5
raw
history blame
7.71 kB
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification, TrainingArguments, Trainer
from PIL import Image
import numpy as np
import gradio as gr
import torch
import os
import shutil
import zipfile
import evaluate # pour les métriques
DEFAULT_MODEL_NAME = "facebook/convnextv2-tiny-1k-224" # ou "google/vit-base-patch16-224"
DEFAULT_OUTPUT_DIR = "./weld_cls_model_best"
EXTRACT_DIR = "./uploaded_dataset" # dossier où l'on extrait l'archive
def extract_archive(archive_path, extract_to=EXTRACT_DIR):
"""
Extrait une archive .zip ou .rar dans extract_to.
La structure attendue après extraction est de type imagefolder :
extract_to/
bonne/
img1.jpg
...
mauvaise/
img2.jpg
...
"""
if archive_path is None or not os.path.isfile(archive_path):
return None, f"Erreur : aucune archive de dataset fournie."
# Nettoyer l'ancien dossier, s'il existe
if os.path.isdir(extract_to):
shutil.rmtree(extract_to)
os.makedirs(extract_to, exist_ok=True)
archive_lower = archive_path.lower()
try:
if archive_lower.endswith(".zip"):
with zipfile.ZipFile(archive_path, "r") as zf:
zf.extractall(extract_to)
elif archive_lower.endswith(".rar"):
try:
import rarfile
except ImportError:
return None, (
"Erreur : format .rar demandé mais le module 'rarfile' n'est pas installé.\n"
"Ajoute 'rarfile' dans requirements.txt, ou utilise une archive .zip."
)
with rarfile.RarFile(archive_path) as rf:
rf.extractall(extract_to)
else:
return None, "Erreur : format d'archive non supporté. Utilise .zip ou .rar."
except Exception as e:
return None, f"Erreur lors de l'extraction de l'archive : {e}"
return extract_to, None
def train_model(dataset_archive_path, model_name=DEFAULT_MODEL_NAME, num_epochs=10, batch_size=16, lr=5e-5):
"""
Lance l'entraînement à partir d'une archive uploadée (zip/rar) contenant
un dataset de type imagefolder.
"""
# 0) Extraction de l'archive
data_dir, err = extract_archive(dataset_archive_path)
if err is not None:
return err
if data_dir is None or not os.path.isdir(data_dir):
return f"Erreur : le dossier de données '{data_dir}' est introuvable après extraction."
# 1) Charger le dataset
try:
dataset = load_dataset("imagefolder", data_dir=data_dir)
except Exception as e:
return f"Erreur lors du chargement du dataset avec 'imagefolder' : {e}"
label_names = dataset["train"].features["label"].names
num_labels = len(label_names)
# 2) Choisir le modèle HF
processor = AutoImageProcessor.from_pretrained(model_name)
def transform(example_batch):
images = [
x.convert("RGB") if isinstance(x, Image.Image) else x
for x in example_batch["image"]
]
inputs = processor(images, return_tensors="pt")
inputs["labels"] = example_batch["label"]
return inputs
prepared_ds = dataset.with_transform(transform)
# 3) Charger le modèle de classification
model = AutoModelForImageClassification.from_pretrained(
model_name,
num_labels=num_labels,
id2label={i: l for i, l in enumerate(label_names)},
label2id={l: i for i, l in enumerate(label_names)},
)
# 4) Définir les métriques
metric = evaluate.load("accuracy")
def compute_metrics(eval_pred):
logits, labels = eval_pred
preds = np.argmax(logits, axis=-1)
return metric.compute(predictions=preds, references=labels)
# 5) Définir le Trainer
training_args = TrainingArguments(
output_dir="./weld_cls_model",
per_device_train_batch_size=int(batch_size),
per_device_eval_batch_size=int(batch_size),
learning_rate=float(lr),
num_train_epochs=int(num_epochs),
evaluation_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
metric_for_best_model="accuracy",
logging_steps=50,
report_to=[], # désactive Weights & Biases dans un Space si non configuré
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=prepared_ds["train"],
eval_dataset=prepared_ds.get("validation", prepared_ds["train"]),
compute_metrics=compute_metrics,
)
# 6) Entraînement
trainer.train()
# 7) Sauvegarde
trainer.save_model(DEFAULT_OUTPUT_DIR)
processor.save_pretrained(DEFAULT_OUTPUT_DIR)
return f"Entraînement terminé. Modèle sauvegardé dans : {DEFAULT_OUTPUT_DIR}"
def predict(image):
"""
Inférence : prend une image et renvoie les probabilités par classe.
Utilise le modèle sauvegardé dans DEFAULT_OUTPUT_DIR.
"""
if not os.path.isdir(DEFAULT_OUTPUT_DIR):
return "Erreur : aucun modèle entraîné trouvé. Lance d'abord l'entraînement."
# Charger modèle + processor
model = AutoModelForImageClassification.from_pretrained(DEFAULT_OUTPUT_DIR)
processor = AutoImageProcessor.from_pretrained(DEFAULT_OUTPUT_DIR)
model.eval()
if not isinstance(image, Image.Image):
image = Image.fromarray(image)
inputs = processor(images=image.convert("RGB"), return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
probs = outputs.logits.softmax(dim=-1)[0].cpu().numpy()
id2label = model.config.id2label
result = {id2label[i]: float(probs[i]) for i in range(len(probs))}
return result
# -----------------------
# Interface Gradio
# -----------------------
with gr.Blocks() as demo:
gr.Markdown("# Classification de soudures – Entraînement + Inférence\n"
"Upload d'un dataset (.zip ou .rar), entraînement du modèle, puis test sur des images.")
with gr.Tab("Entraînement"):
gr.Markdown("## Lancer l'entraînement")
dataset_file_input = gr.File(
label="Archive du dataset (.zip ou .rar)",
type="filepath"
)
model_name_input = gr.Textbox(
label="Nom du modèle Hugging Face",
value=DEFAULT_MODEL_NAME
)
epochs_input = gr.Slider(
label="Nombre d'époques",
minimum=1,
maximum=50,
value=10,
step=1
)
batch_input = gr.Slider(
label="Batch size",
minimum=4,
maximum=64,
value=16,
step=4
)
lr_input = gr.Number(
label="Learning rate",
value=5e-5,
precision=7
)
train_button = gr.Button("Lancer l'entraînement")
train_output = gr.Textbox(label="Journal / Résultat de l'entraînement")
train_button.click(
fn=train_model,
inputs=[dataset_file_input, model_name_input, epochs_input, batch_input, lr_input],
outputs=train_output
)
with gr.Tab("Inférence"):
gr.Markdown("## Tester le modèle entraîné sur une image de soudure")
image_input = gr.Image(label="Image de soudure", type="pil")
pred_button = gr.Button("Prédire")
pred_output = gr.Label(label="Probabilités par classe")
pred_button.click(
fn=predict,
inputs=image_input,
outputs=pred_output
)
if __name__ == "__main__":
demo.launch()