Spaces:
Sleeping
Sleeping
| # train.py | |
| from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer | |
| from transformers import DefaultDataCollator | |
| from datasets import load_dataset, Image | |
| import torch | |
# 1. Load the dataset, collect its label set, then split into train/test.
# NOTE(review): confirm on the dataset card that this repository really
# exposes a "styles" configuration -- TODO verify.
dataset = load_dataset("ashraq/fashion-product-images-small", name="styles", split="train")

# The label set is taken from the FULL dataset, before the random split.
# Deriving it from the training split alone would make `label2id[...]` raise
# KeyError for any articleType whose rows all land in the test split.
labels = dataset.unique("articleType")

dataset = dataset.train_test_split(test_size=0.2)
train_ds = dataset["train"]
test_ds = dataset["test"]

# 2. Lookup tables between label names and integer class ids.
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for i, label in enumerate(labels)}
# 3. Load the image processor and the base model.
# The checkpoint is a ViT pre-trained on ImageNet (not a model fine-tuned on
# the "beans" dataset).
model_ckpt = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_ckpt)
model = ViTForImageClassification.from_pretrained(
    model_ckpt,
    label2id=label2id,
    id2label=id2label,
    num_labels=len(labels),
    # Needed because the number of classes differs from the checkpoint's.
    ignore_mismatched_sizes=True,
)
# 4. Preprocessing function that turns a batch of images into model inputs
def transform(example_batch):
    """Convert a batch of examples into processor outputs plus label ids.

    Args:
        example_batch: Mapping with an "image_path" column (images) and an
            "articleType" column (label names present in ``label2id``).

    Returns:
        The processor output for the RGB-converted images, with an extra
        "labels" entry holding the integer class id of each example.

    Raises:
        KeyError: If an "articleType" value is missing from ``label2id``.
    """
    # The "image_path" column is cast to `datasets.Image` before this
    # transform is installed, so entries are expected to arrive as decoded
    # images. `datasets.Image` is a feature type and has no `open()`, so the
    # images are converted directly instead of being re-opened.
    images = [img.convert("RGB") for img in example_batch["image_path"]]
    inputs = processor(images, return_tensors="pt")
    inputs["labels"] = [label2id[label] for label in example_batch["articleType"]]
    return inputs
# Apply the preprocessing to both splits: declare the image column as an
# Image feature, then install `transform` on the split.
# NOTE(review): verify that the dataset really has a column named
# "image_path" -- TODO confirm against the dataset card.
train_ds = train_ds.cast_column("image_path", Image())
train_ds.set_transform(transform)

test_ds = test_ds.cast_column("image_path", Image())
test_ds.set_transform(transform)
# 5. Define the training arguments
training_args = TrainingArguments(
    output_dir="./vit-fashion-classifier",
    per_device_train_batch_size=16,
    # NOTE(review): newer transformers releases name this argument
    # `eval_strategy`; check it against the installed version.
    evaluation_strategy="steps",
    num_train_epochs=4,
    # Mixed precision is only enabled when a CUDA device is present;
    # hard-coding True makes the script fail on CPU-only machines.
    fp16=torch.cuda.is_available(),
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    # Keep the raw columns: `transform` reads "image_path" and "articleType".
    remove_unused_columns=False,
    push_to_hub=True,  # Upload the model to the Hugging Face Hub
    hub_model_id="MODLI/vit-fashion-classifier",  # Replace with your own repo
)
# 6. Build the trainer, run training, then push the result to the Hub
trainer = Trainer(
    args=training_args,
    model=model,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    data_collator=DefaultDataCollator(),
    # NOTE(review): the image processor is passed through the `tokenizer`
    # argument; confirm this is still accepted by the installed
    # transformers version.
    tokenizer=processor,
)

trainer.train()
trainer.push_to_hub()