| | import os |
| | import json |
| | import sys |
| | from datetime import datetime |
| | import traceback |
| | from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer |
| | from datasets import load_dataset |
| | import torch |
| | import pandas as pd |
| | from huggingface_hub import login |
| | from connect_huggingface import setup_huggingface |
| | import gradio as gr |
| |
|
class TrainingCallback:
    """Collects the Trainer's log dictionaries for display in the UI.

    Instances are handed to ``Trainer(callbacks=[...])``. transformers'
    callback handler dispatches MANY events (``on_init_end``,
    ``on_train_begin``, ``on_step_end``, ...) via ``getattr``; the original
    class only defined ``on_log`` and therefore raised ``AttributeError``
    on the first other event. ``__getattr__`` below supplies no-op handlers
    for every unimplemented ``on_*`` hook (returning ``None`` means "no
    control change", matching ``TrainerCallback`` semantics).
    NOTE(review): subclassing ``transformers.TrainerCallback`` is the
    canonical fix — kept import-free here on purpose.
    """

    def __init__(self):
        # Raw log dicts in the order the Trainer emitted them.
        self.logs = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Record one logging event; ``logs`` may be None and is skipped."""
        if logs:
            self.logs.append(logs)

    def get_logs(self):
        """Return all collected log dicts as one newline-joined string."""
        return "\n".join([str(log) for log in self.logs])

    def __getattr__(self, name):
        # Only reached for attributes NOT found normally. Provide a no-op
        # for any other Trainer event hook; everything else stays an error.
        if name.startswith("on_"):
            return lambda *args, **kwargs: None
        raise AttributeError(name)
|
def start_training():
    """Run one fine-tuning job driven entirely by ``config.json``.

    Returns:
        tuple[str, str]: ``(status, logs)`` for the Gradio interface — a
        short status string and a markdown-formatted log transcript.
        All user-facing strings are in French.

    Any exception is caught at this boundary and reported through the
    returned values (plus stdout) instead of propagating into Gradio.
    """
    try:
        # Hub credentials/config must be in place before any model or
        # dataset download is attempted.
        if not setup_huggingface():
            return "Erreur : Impossible de configurer Hugging Face", "### Logs d'entraînement\nErreur de configuration Hugging Face"

        status = "Configuration de l'environnement..."
        logs = f"### Logs d'entraînement\nDémarrage à {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"

        # All hyper-parameters and model/dataset names live in config.json.
        # encoding is pinned so accented French text parses on any locale.
        with open('config.json', 'r', encoding='utf-8') as f:
            config = json.load(f)

        # Load tokenizer and model. Weights are bf16 only when training
        # itself runs in bf16; device_map="auto" lets accelerate place the
        # model across available devices.
        logs += "- Chargement du modèle et du tokenizer...\n"
        tokenizer = AutoTokenizer.from_pretrained(config['model']['tokenizer'])
        # NOTE(review): `tokenizer` is loaded but never passed to Trainer —
        # presumably it should be forwarded (tokenizer=tokenizer); verify.
        model = AutoModelForCausalLM.from_pretrained(
            config['model']['name'],
            torch_dtype=torch.bfloat16 if config['training']['bf16'] else torch.float32,
            device_map="auto"
        )

        # Intermediate `status` values are overwritten before returning;
        # they only matter if this function is later made to stream updates.
        status = "Chargement du dataset..."
        logs += f"- Chargement de {config['dataset']['name']}...\n"
        dataset = load_dataset(config['dataset']['name'])

        status = "Configuration de l'entraînement..."
        logs += "- Configuration des paramètres d'entraînement...\n"

        # NOTE(review): `evaluation_strategy` was renamed `eval_strategy`
        # in recent transformers releases — confirm the pinned version.
        training_args = TrainingArguments(
            output_dir="./results",
            num_train_epochs=config['training']['epochs'],
            per_device_train_batch_size=config['training']['batch_size'],
            learning_rate=config['training']['learning_rate'],
            warmup_ratio=config['training']['warmup_ratio'],
            evaluation_strategy=config['training']['evaluation_strategy'],
            eval_steps=config['training']['eval_steps'],
            save_strategy=config['training']['save_strategy'],
            save_steps=config['training']['save_steps'],
            save_total_limit=config['training']['save_total_limit'],
            load_best_model_at_end=config['training']['load_best_model_at_end'],
            metric_for_best_model=config['training']['metric_for_best_model'],
            greater_is_better=config['training']['greater_is_better'],
            gradient_accumulation_steps=config['training']['gradient_accumulation_steps'],
            logging_steps=config['training']['logging_steps'],
            fp16=config['training']['fp16'],
            bf16=config['training']['bf16']
        )

        # Captures the Trainer's log dicts so they can be shown in the UI.
        callback = TrainingCallback()

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=dataset[config['dataset']['train_split']],
            eval_dataset=dataset[config['dataset']['eval_split']],
            callbacks=[callback]
        )

        status = "Entraînement en cours..."
        logs += "- Début de l'entraînement...\n"

        trainer.train()

        # Append everything the callback captured during training.
        logs += "\n### Logs détaillés\n"
        logs += callback.get_logs()

        status = "Entraînement terminé avec succès!"
        logs += "\n\nEntraînement terminé avec succès!"

        return status, logs

    except Exception as e:
        # Boundary handler: surface the failure (with traceback) to both
        # stdout and the UI rather than letting it vanish inside Gradio.
        error_msg = f"Erreur pendant l'entraînement : {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return "Erreur pendant l'entraînement", f"### Logs d'entraînement\n❌ {error_msg}"
| |
|
| | |
# Minimal Gradio UI: one parameterless action that launches training and
# returns (status, logs) — mapped to a textbox and a markdown panel.
demo = gr.Interface(
    fn=start_training,
    inputs=[],
    outputs=[gr.Textbox(label="Statut de l'entraînement"), gr.Markdown(label="Logs de l'entraînement")],
    title="AUTO Training Space",
    description="Cliquez sur le bouton pour lancer l'entraînement du modèle."
)
| |
|
# Start the Gradio server only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()
| |
|