recettes-generator / train.py
razmanitra's picture
Upload train.py with huggingface_hub
27418d4 verified
from transformers import Trainer, TrainingArguments
from datasets import load_dataset
import torch
import torch.nn as nn
import os
import sys
import logging
# Configure logging: INFO level via basicConfig, plus a module-level logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Importer les modules nécessaires
from modeling_recipe_generator import RecipeGeneratorForHF, RecipeGeneratorConfig
def _compute_metrics(eval_pred):
    """Compute the evaluation perplexity from the Trainer's predictions.

    Args:
        eval_pred: The ``(predictions, labels)`` pair handed over by
            ``Trainer``. ``predictions`` is treated as logits whose last
            axis is the class/vocabulary dimension -- TODO confirm against
            the outputs of ``RecipeGeneratorForHF``.

    Returns:
        A dict with a single ``"perplexity"`` entry (a Python float).
    """
    predictions, labels = eval_pred

    # A model that returns several outputs yields a tuple here; the logits
    # are taken to be the first element (the usual Hugging Face convention).
    if isinstance(predictions, (tuple, list)):
        predictions = predictions[0]

    # as_tensor avoids copying the (potentially large) NumPy arrays.
    logits = torch.as_tensor(predictions)
    # cross_entropy requires int64 class indices.
    targets = torch.as_tensor(labels, dtype=torch.long)

    # Flatten every position into one row; positions labelled -100 are
    # skipped by cross_entropy's default ignore_index.
    loss = torch.nn.functional.cross_entropy(
        logits.reshape(-1, logits.shape[-1]).float(),
        targets.reshape(-1),
    )
    return {"perplexity": torch.exp(loss).item()}


def main():
    """Train the recipe generator and publish it to the Hugging Face Hub.

    The function loads the ``razmanitra/recettes-dataset`` dataset, reads the
    model configuration from ``razmanitra/recettes-generator``, instantiates
    ``RecipeGeneratorForHF`` from that configuration, trains it with
    ``Trainer``, saves the result to ``./final_model`` and pushes it to the
    Hub.
    """
    logger.info("Chargement du dataset depuis Hugging Face...")
    dataset = load_dataset("razmanitra/recettes-dataset")

    logger.info("Chargement de la configuration du modèle...")
    config = RecipeGeneratorConfig.from_pretrained("razmanitra/recettes-generator")

    logger.info("Création du modèle...")
    # The model is built from the configuration, not via from_pretrained.
    model = RecipeGeneratorForHF(config)

    logger.info("Configuration des arguments d'entraînement...")
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=3,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=100,
        # NOTE(review): newer transformers releases rename this argument to
        # `eval_strategy` -- confirm against the installed version.
        evaluation_strategy="steps",
        eval_steps=500,
        save_steps=1000,
        save_total_limit=2,
        learning_rate=5e-5,
        # Mixed precision needs a CUDA device; enabling it unconditionally
        # makes TrainingArguments fail on CPU-only hosts.
        fp16=torch.cuda.is_available(),
        gradient_accumulation_steps=4,
        push_to_hub=True,
        hub_model_id="razmanitra/recettes-generator",
        hub_strategy="every_save",
    )

    logger.info("Initialisation du Trainer...")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        compute_metrics=_compute_metrics,
    )

    logger.info("Démarrage de l'entraînement...")
    trainer.train()

    logger.info("Entraînement terminé. Sauvegarde du modèle final...")
    trainer.save_model("./final_model")

    logger.info("Envoi du modèle final vers Hugging Face Hub...")
    trainer.push_to_hub()

    logger.info("Processus d'entraînement terminé avec succès!")
if __name__ == "__main__":
    # Run the training pipeline only when this file is executed as a script.
    main()