|
|
|
|
|
from transformers import Trainer, TrainingArguments
|
|
|
from datasets import load_dataset
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import os
|
|
|
import sys
|
|
|
import logging
|
|
|
|
|
|
|
|
|
# Configure the root logger at INFO so the progress messages below are shown
# (basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level="INFO")

# Module-scoped logger for this script's progress messages.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
|
from modeling_recipe_generator import RecipeGeneratorForHF, RecipeGeneratorConfig
|
|
|
|
|
|
def main():
|
|
|
logger.info("Chargement du dataset depuis Hugging Face...")
|
|
|
|
|
|
dataset = load_dataset("razmanitra/recettes-dataset")
|
|
|
|
|
|
logger.info("Chargement de la configuration du mod�le...")
|
|
|
|
|
|
config = RecipeGeneratorConfig.from_pretrained("razmanitra/recettes-generator")
|
|
|
|
|
|
logger.info("Cr�ation du mod�le...")
|
|
|
|
|
|
model = RecipeGeneratorForHF(config)
|
|
|
|
|
|
logger.info("Configuration des arguments d'entra�nement...")
|
|
|
|
|
|
training_args = TrainingArguments(
|
|
|
output_dir="./results",
|
|
|
num_train_epochs=3,
|
|
|
per_device_train_batch_size=8,
|
|
|
per_device_eval_batch_size=8,
|
|
|
warmup_steps=500,
|
|
|
weight_decay=0.01,
|
|
|
logging_dir="./logs",
|
|
|
logging_steps=100,
|
|
|
evaluation_strategy="steps",
|
|
|
eval_steps=500,
|
|
|
save_steps=1000,
|
|
|
save_total_limit=2,
|
|
|
learning_rate=5e-5,
|
|
|
fp16=True,
|
|
|
gradient_accumulation_steps=4,
|
|
|
push_to_hub=True,
|
|
|
hub_model_id="razmanitra/recettes-generator",
|
|
|
hub_strategy="every_save"
|
|
|
)
|
|
|
|
|
|
|
|
|
def compute_metrics(eval_pred):
|
|
|
predictions, labels = eval_pred
|
|
|
|
|
|
|
|
|
loss = torch.nn.functional.cross_entropy(
|
|
|
torch.tensor(predictions).view(-1, predictions.shape[-1]),
|
|
|
torch.tensor(labels).view(-1)
|
|
|
)
|
|
|
perplexity = torch.exp(loss)
|
|
|
|
|
|
return {
|
|
|
"perplexity": perplexity.item()
|
|
|
}
|
|
|
|
|
|
logger.info("Initialisation du Trainer...")
|
|
|
|
|
|
trainer = Trainer(
|
|
|
model=model,
|
|
|
args=training_args,
|
|
|
train_dataset=dataset["train"],
|
|
|
eval_dataset=dataset["test"],
|
|
|
compute_metrics=compute_metrics
|
|
|
)
|
|
|
|
|
|
logger.info("D�marrage de l'entra�nement...")
|
|
|
|
|
|
trainer.train()
|
|
|
|
|
|
logger.info("Entra�nement termin�. Sauvegarde du mod�le final...")
|
|
|
|
|
|
trainer.save_model("./final_model")
|
|
|
|
|
|
logger.info("Envoi du mod�le final vers Hugging Face Hub...")
|
|
|
|
|
|
trainer.push_to_hub()
|
|
|
|
|
|
logger.info("Processus d'entra�nement termin� avec succ�s!")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the training pipeline only when executed as a script, not on import.
    main()
|
|
|
|