razmanitra commited on
Commit
27418d4
·
verified ·
1 Parent(s): 18c6d1d

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +93 -0
train.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from transformers import Trainer, TrainingArguments
3
+ from datasets import load_dataset
4
+ import torch
5
+ import torch.nn as nn
6
+ import os
7
+ import sys
8
+ import logging
9
+
10
+ # Configuration du logging
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
+
14
+ # Importer les modules n�cessaires
15
+ from modeling_recipe_generator import RecipeGeneratorForHF, RecipeGeneratorConfig
16
+
17
+ def main():
18
+ logger.info("Chargement du dataset depuis Hugging Face...")
19
+ # Charger le dataset depuis Hugging Face
20
+ dataset = load_dataset("razmanitra/recettes-dataset")
21
+
22
+ logger.info("Chargement de la configuration du mod�le...")
23
+ # Charger la configuration
24
+ config = RecipeGeneratorConfig.from_pretrained("razmanitra/recettes-generator")
25
+
26
+ logger.info("Cr�ation du mod�le...")
27
+ # Cr�er le mod�le
28
+ model = RecipeGeneratorForHF(config)
29
+
30
+ logger.info("Configuration des arguments d'entra�nement...")
31
+ # D�finir les arguments d'entra�nement
32
+ training_args = TrainingArguments(
33
+ output_dir="./results",
34
+ num_train_epochs=3,
35
+ per_device_train_batch_size=8,
36
+ per_device_eval_batch_size=8,
37
+ warmup_steps=500,
38
+ weight_decay=0.01,
39
+ logging_dir="./logs",
40
+ logging_steps=100,
41
+ evaluation_strategy="steps",
42
+ eval_steps=500,
43
+ save_steps=1000,
44
+ save_total_limit=2,
45
+ learning_rate=5e-5,
46
+ fp16=True, # Utiliser la pr�cision mixte
47
+ gradient_accumulation_steps=4,
48
+ push_to_hub=True,
49
+ hub_model_id="razmanitra/recettes-generator",
50
+ hub_strategy="every_save"
51
+ )
52
+
53
+ # D�finir la fonction de calcul des m�triques
54
+ def compute_metrics(eval_pred):
55
+ predictions, labels = eval_pred
56
+
57
+ # Calculer la perplexit�
58
+ loss = torch.nn.functional.cross_entropy(
59
+ torch.tensor(predictions).view(-1, predictions.shape[-1]),
60
+ torch.tensor(labels).view(-1)
61
+ )
62
+ perplexity = torch.exp(loss)
63
+
64
+ return {
65
+ "perplexity": perplexity.item()
66
+ }
67
+
68
+ logger.info("Initialisation du Trainer...")
69
+ # Initialiser le trainer
70
+ trainer = Trainer(
71
+ model=model,
72
+ args=training_args,
73
+ train_dataset=dataset["train"],
74
+ eval_dataset=dataset["test"],
75
+ compute_metrics=compute_metrics
76
+ )
77
+
78
+ logger.info("D�marrage de l'entra�nement...")
79
+ # Lancer l'entra�nement
80
+ trainer.train()
81
+
82
+ logger.info("Entra�nement termin�. Sauvegarde du mod�le final...")
83
+ # Sauvegarder le mod�le final
84
+ trainer.save_model("./final_model")
85
+
86
+ logger.info("Envoi du mod�le final vers Hugging Face Hub...")
87
+ # Pousser le mod�le final vers Hugging Face
88
+ trainer.push_to_hub()
89
+
90
+ logger.info("Processus d'entra�nement termin� avec succ�s!")
91
+
92
+ if __name__ == "__main__":
93
+ main()