Train-aricatev4 / app.py
Clemylia's picture
Create app.py
b88d755 verified
# app.py
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import collections
from datasets import load_dataset
from huggingface_hub import PyTorchModelHubMixin, HfApi, login
import os
import time
import json
import heapq
from safetensors.torch import save_file as save_safetensors_file
import gradio as gr
import sys # Pour la redirection de la sortie
# ==============================================================================
# ARCHITECTURE ARICATE V4 (Intégrée)
# ==============================================================================
# --- A. WordTokenizer ---
class WordTokenizer:
"""Tokenizer simple pour l'architecture Aricate."""
def __init__(self, texts):
all_words = []
for text in texts:
# S'assurer que 'text' est une chaîne de caractères avant de l'opérer
if isinstance(text, str):
words = text.lower().split()
all_words.extend(words)
word_counts = collections.Counter(all_words)
sorted_words = [word for word, count in word_counts.most_common()]
self.special_tokens = {
'<pad>': 0,
'<unk>': 1,
'<eos>': 2,
'<sep>': 3,
}
self.word_to_id = self.special_tokens.copy()
next_id = len(self.special_tokens)
for word in sorted_words:
if word not in self.word_to_id:
self.word_to_id[word] = next_id
next_id += 1
self.id_to_word = {id: word for word, id in self.word_to_id.items()}
self.vocab_size = len(self.word_to_id)
print(f"Tokenisation effectuée. Taille du vocabulaire : {self.vocab_size}")
def encode(self, text, add_eos=False):
words = text.lower().split()
if add_eos:
words.append('<eos>')
ids = [self.word_to_id.get(word, self.word_to_id['<unk>']) for word in words]
return ids
def decode(self, ids):
words = [self.id_to_word.get(id, '<unk>') for id in ids]
return " ".join(word for word in words if word not in ['<pad>', '<unk>', '<eos>', '<sep>'])
# --- B. AricateAttentionLayer ---
class AricateAttentionLayer(nn.Module):
"""Couche d'Attention Additive (Bahdanau)."""
def __init__(self, hidden_dim):
super(AricateAttentionLayer, self).__init__()
self.W = nn.Linear(hidden_dim, hidden_dim)
self.U = nn.Linear(hidden_dim, hidden_dim)
self.V = nn.Linear(hidden_dim, 1, bias=False)
def forward(self, rnn_outputs, last_hidden):
last_hidden_expanded = last_hidden.unsqueeze(1)
energy = torch.tanh(self.W(rnn_outputs) + self.U(last_hidden_expanded))
attention_weights_raw = self.V(energy).squeeze(2)
attention_weights = F.softmax(attention_weights_raw, dim=1)
context_vector = torch.sum(rnn_outputs * attention_weights.unsqueeze(2), dim=1)
return context_vector
# --- C. AricateModel V4 ---
class AricateModel(nn.Module, PyTorchModelHubMixin):
"""Architecture Aricate V4. Hérite de PyTorchModelHubMixin pour la sauvegarde et la publication."""
def __init__(self, vocab_size: int, embedding_dim: int, hidden_dim: int, num_layers: int = 1, config: dict = None):
super(AricateModel, self).__init__()
if config is not None:
vocab_size = config.get("vocab_size", vocab_size)
embedding_dim = config.get("embedding_dim", embedding_dim)
hidden_dim = config.get("hidden_dim", hidden_dim)
num_layers = config.get("num_layers", num_layers)
self.vocab_size = vocab_size
self.embedding_dim = embedding_dim
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.word_embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=0)
self.rnn = nn.GRU(input_size=embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True)
self.attention = AricateAttentionLayer(hidden_dim)
self.hidden_to_vocab = nn.Linear(hidden_dim * 2, vocab_size)
def forward(self, input_words):
embeds = self.word_embeddings(input_words)
rnn_out, hn = self.rnn(embeds)
last_hidden = hn[-1]
context_vector = self.attention(rnn_out, last_hidden)
combined_features = torch.cat((context_vector, last_hidden), dim=1)
logits = self.hidden_to_vocab(combined_features)
return logits
# --- D. Fonction de Génération (Simplifiée pour l'espace) ---
# NOTE: J'ai retiré la fonction de génération pour ne pas alourdir l'application Gradio principale et me concentrer sur l'entraînement/publication.
# Dans un Space, il est préférable d'avoir une démo séparée après l'entraînement.
# Je garde le Dataset car c'est nécessaire.
# --- Nouvelle Classe PyTorch Dataset ---
class AricateDataset(Dataset):
"""Dataset personnalisé pour PyTorch."""
def __init__(self, X_data, Y_data):
self.X = X_data
self.Y = Y_data
def __len__(self):
return len(self.X)
def __getitem__(self, idx):
return self.X[idx], self.Y[idx]
# ==============================================================================
# FONCTION D'ENTRAÎNEMENT ADAPTÉE POUR GRADIO
# ==============================================================================
def train_aricate_model(
hf_token: str,
hf_user: str,
dataset_name: str,
question_col: str,
response_col: str,
model_name: str,
num_epochs: int
):
"""
Fonction principale d'entraînement adaptée pour Gradio.
Elle prend les entrées de l'utilisateur, configure Aricate v4,
lance l'entraînement et publie le modèle sur Hugging Face.
"""
# Rediriger la sortie standard vers la console Gradio
sys.stdout.flush()
print(f"\n{'='*50}\n>>> DÉBUT DU PROCESSUS D'ENTRAÎNEMENT Aricat v4 <<<\n{'='*50}")
try:
# --- 0. Configuration & Connexion Hugging Face ---
# Paramètres fixes (peuvent être ajustés si nécessaire)
EMBEDDING_DIM = 64
HIDDEN_DIM = 128
NUM_LAYERS = 2
BATCH_SIZE = 128
LEARNING_RATE = 0.005
# Configuration de l'appareil
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Appareil d'entraînement sélectionné: {device}")
# Connexion Hugging Face via le token
login(token=hf_token, add_to_git_credential=False)
REPO_ID = f"{hf_user}/{model_name}"
print(f"Connexion Hugging Face établie. Le modèle sera publié sous le dépôt: {REPO_ID}")
print(f"--- Lancement de l'Entraînement du SLM '{model_name}' (Aricate) ---")
# 1. Préparation des données
DATASET_SPLIT = 'train'
print(f"Chargement de la dataset '{dataset_name}' (split '{DATASET_SPLIT}')...")
try:
dataset = load_dataset(dataset_name, split=DATASET_SPLIT)
except Exception as e:
raise ValueError(f"Erreur lors du chargement de la dataset '{dataset_name}'. Vérifiez le nom du dépôt. Erreur: {e}")
# Construction du corpus en utilisant les colonnes spécifiées par l'utilisateur
try:
corpus_raw = [f"{ex[question_col]} <sep> {ex[response_col]}" for ex in dataset]
except KeyError as e:
raise KeyError(f"Colonne introuvable dans la dataset. Vérifiez les noms de colonnes : {e}. Les colonnes de votre dataset sont : {dataset.column_names}")
tokenizer = WordTokenizer(corpus_raw)
train_data_X = []
train_data_Y = []
for item in dataset:
q = item[question_col]
r = item[response_col]
full_seq_ids = tokenizer.encode(f"{q} <sep> {r}", add_eos=True)
for i in range(1, len(full_seq_ids)):
X = full_seq_ids[:i]
Y = full_seq_ids[i]
train_data_X.append(X)
train_data_Y.append(Y)
max_len = max(len(x) for x in train_data_X)
padded_X = []
for x in train_data_X:
padding_needed = max_len - len(x)
# Ajout du padding au DÉBUT de la séquence (convention de certains modèles pour l'alignement)
padded_X.append([tokenizer.special_tokens['<pad>']] * padding_needed + x)
X_train_tensor = torch.tensor(padded_X)
Y_train_tensor = torch.tensor(train_data_Y)
VOCAB_SIZE = tokenizer.vocab_size
print(f"Dataset chargée. Nombre de paires d'entraînement: {len(Y_train_tensor)}")
print(f"Taille du vocabulaire total: {VOCAB_SIZE}")
print(f"Longueur maximale d'entrée (max_len): {max_len}")
aricate_dataset = AricateDataset(X_train_tensor, Y_train_tensor)
train_loader = DataLoader(
dataset=aricate_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=0 # Mis à 0 pour éviter des problèmes de multi-processus sur certains environnements HF Space
)
print(f"Nombre de batches par époque : {len(train_loader)}")
# 2. Initialisation du Modèle
model_config = {
"vocab_size": VOCAB_SIZE,
"embedding_dim": EMBEDDING_DIM,
"hidden_dim": HIDDEN_DIM,
"num_layers": NUM_LAYERS
}
model = AricateModel(**model_config).to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# 3. Entraînement
print(f"\nDébut de l'entraînement pour {num_epochs} époques avec un BATCH_SIZE de {BATCH_SIZE}...")
start_time = time.time()
for epoch in range(num_epochs):
model.train()
total_loss = 0.0
for batch_X, batch_Y in train_loader:
batch_X, batch_Y = batch_X.to(device), batch_Y.to(device)
optimizer.zero_grad()
logits = model(batch_X)
loss = loss_function(logits, batch_Y)
loss.backward()
optimizer.step()
total_loss += loss.item() * batch_X.size(0)
avg_loss = total_loss / len(aricate_dataset)
# Mise à jour immédiate du statut
yield f"Entraînement en cours... Époque [{epoch+1}/{num_epochs}], Perte Moyenne: {avg_loss:.4f}"
end_time = time.time()
yield f"Entraînement terminé ! Durée: {(end_time - start_time):.2f}s. Début de la publication..."
print(f"\nEntraînement terminé ! Durée: {(end_time - start_time):.2f}s 🎉")
# 4. Sauvegarde et Publication sur Hugging Face
model.to("cpu")
print("\n" + "="*50)
print(">>> SAUVEGARDE ET PUBLICATION SUR HUGGING FACE <<<")
print("="*50)
save_directory = f"./{model_name}_local_save"
os.makedirs(save_directory, exist_ok=True)
model.save_pretrained(save_directory)
print(f"Modèle sauvegardé localement dans: {save_directory}")
tokenizer_path = os.path.join(save_directory, "aricate_tokenizer.txt")
with open(tokenizer_path, 'w', encoding='utf-8') as f:
json.dump(tokenizer.word_to_id, f, ensure_ascii=False)
print(f"Tokenizer (vocabulaire) sauvegardé dans: {tokenizer_path}")
# Publication
model.push_to_hub(
repo_id=REPO_ID,
commit_message=f"Modèle entraîné via Aricate v4 Space. Époques: {num_epochs}",
config=model_config
)
HfApi().upload_file(
path_or_fileobj=tokenizer_path,
path_in_repo="aricate_tokenizer.txt",
repo_id=REPO_ID,
repo_type="model",
commit_message="Update Aricate custom tokenizer vocabulary."
)
final_message = f"\n✅ Publication réussie ! Votre modèle '{model_name}' est disponible sur : https://huggingface.co/{REPO_ID}"
print(final_message)
yield final_message # Message final pour l'interface Gradio
except Exception as e:
error_message = f"\n❌ ERREUR CRITIQUE. L'entraînement ou la publication a échoué. Détail: {e}"
print(error_message)
yield error_message # Message d'erreur pour l'interface Gradio
# ==============================================================================
# INTERFACE GRADIO
# ==============================================================================
# Description détaillée pour l'utilisateur
description = """
# 🧠 Entraînez votre propre SLM avec Aricate v4 (Clemylia)
Bienvenue sur l'interface d'entraînement d'Aricate v4 ! Suivez les étapes ci-dessous pour créer et publier votre propre Small Language Model (SLM) basé sur votre dataset personnalisée.
**Étapes à suivre :**
1. **Authentification :** Entrez votre Token et Nom d'utilisateur Hugging Face. **Le token doit avoir la permission "Write" (Écriture).**
2. **Dataset :** Fournissez le nom du dépôt Hugging Face contenant votre dataset.
3. **Colonnes :** Indiquez les noms exacts des colonnes pour les questions et les réponses (par défaut : `question` et `reponse`).
4. **Nom du Modèle :** Choisissez le nom de votre futur modèle (il sera publié sous `votre_nom_utilisateur/nom_du_modèle`).
5. **Hyperparamètres :** Définissez le nombre d'époques.
6. **Lancement :** Appuyez sur le bouton et attendez la fin de l'entraînement et de la publication !
"""
# Création des blocs d'interface
with gr.Blocks(title="Aricate v4 Trainer") as demo:
gr.Markdown(description)
# --- Section d'Authentification et de Publication ---
with gr.Row():
hf_token_input = gr.Textbox(
label="1. Token d'Accès Hugging Face (avec permission 'Write')",
type="password",
placeholder="hf_xxxxxxxxxxxxxxxxxxxxxxxxxx",
info="Token pour l'authentification et la publication (NE PAS PARTAGER !)"
)
hf_user_input = gr.Textbox(
label="2. Votre Nom d'Utilisateur Hugging Face",
placeholder="Clemylia",
info="Le modèle sera publié sur ce compte."
)
# --- Section Dataset ---
gr.Markdown("### 🔍 Configuration de la Dataset")
with gr.Row():
dataset_name_input = gr.Textbox(
label="3. Nom du Dépôt Dataset (ex: Clemylia/Melta-revive)",
placeholder="le_nom_de_votre_dataset",
info="Dépôt public Hugging Face (il doit avoir un split 'train')."
)
question_col_input = gr.Textbox(
label="4. Nom de la Colonne 'Question'",
value="question",
placeholder="question",
info="Nom exact de la colonne contenant les questions."
)
response_col_input = gr.Textbox(
label="5. Nom de la Colonne 'Réponse'",
value="reponse",
placeholder="reponse",
info="Nom exact de la colonne contenant les réponses."
)
# --- Section Modèle et Hyperparamètres ---
gr.Markdown("### ⚙️ Configuration du Modèle et Entraînement")
with gr.Row():
model_name_input = gr.Textbox(
label="6. Nom Final du Modèle (sur Hugging Face)",
placeholder="mon-super-slm-aricate",
info="Sera publié comme 'utilisateur/nom-final'."
)
num_epochs_input = gr.Slider(
label="7. Nombre d'Époques d'Entraînement",
minimum=1,
maximum=50,
step=1,
value=10,
info="Plus d'époques = plus long, mais peut donner de meilleurs résultats (attention à l'overfitting)."
)
# --- Bouton et Sortie ---
train_button = gr.Button("🚀 Entraîner mon propre SLM avec Aricate v4", variant="primary")
# Zone de sortie pour afficher la progression et les messages
output_log = gr.Textbox(
label="Console d'Entraînement et Log de Publication",
lines=15,
autoscroll=True,
interactive=False
)
# Lien entre le bouton et la fonction Python
train_button.click(
fn=train_aricate_model,
inputs=[
hf_token_input,
hf_user_input,
dataset_name_input,
question_col_input,
response_col_input,
model_name_input,
num_epochs_input
],
outputs=output_log
)
# Lancement de l'application Gradio
if __name__ == "__main__":
demo.launch()