Clemylia commited on
Commit
b88d755
·
verified ·
1 Parent(s): d6b3035

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +413 -0
app.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+ import torch.nn.functional as F
7
+ from torch.utils.data import Dataset, DataLoader
8
+ import collections
9
+ from datasets import load_dataset
10
+ from huggingface_hub import PyTorchModelHubMixin, HfApi, login
11
+ import os
12
+ import time
13
+ import json
14
+ import heapq
15
+ from safetensors.torch import save_file as save_safetensors_file
16
+ import gradio as gr
17
+ import sys # Pour la redirection de la sortie
18
+
19
+ # ==============================================================================
20
+ # ARCHITECTURE ARICATE V4 (Intégrée)
21
+ # ==============================================================================
22
+
23
+ # --- A. WordTokenizer ---
24
+ class WordTokenizer:
25
+ """Tokenizer simple pour l'architecture Aricate."""
26
+ def __init__(self, texts):
27
+ all_words = []
28
+ for text in texts:
29
+ # S'assurer que 'text' est une chaîne de caractères avant de l'opérer
30
+ if isinstance(text, str):
31
+ words = text.lower().split()
32
+ all_words.extend(words)
33
+
34
+ word_counts = collections.Counter(all_words)
35
+ sorted_words = [word for word, count in word_counts.most_common()]
36
+
37
+ self.special_tokens = {
38
+ '<pad>': 0,
39
+ '<unk>': 1,
40
+ '<eos>': 2,
41
+ '<sep>': 3,
42
+ }
43
+
44
+ self.word_to_id = self.special_tokens.copy()
45
+ next_id = len(self.special_tokens)
46
+
47
+ for word in sorted_words:
48
+ if word not in self.word_to_id:
49
+ self.word_to_id[word] = next_id
50
+ next_id += 1
51
+
52
+ self.id_to_word = {id: word for word, id in self.word_to_id.items()}
53
+ self.vocab_size = len(self.word_to_id)
54
+ print(f"Tokenisation effectuée. Taille du vocabulaire : {self.vocab_size}")
55
+
56
+ def encode(self, text, add_eos=False):
57
+ words = text.lower().split()
58
+ if add_eos:
59
+ words.append('<eos>')
60
+
61
+ ids = [self.word_to_id.get(word, self.word_to_id['<unk>']) for word in words]
62
+ return ids
63
+
64
+ def decode(self, ids):
65
+ words = [self.id_to_word.get(id, '<unk>') for id in ids]
66
+ return " ".join(word for word in words if word not in ['<pad>', '<unk>', '<eos>', '<sep>'])
67
+
68
+ # --- B. AricateAttentionLayer ---
69
+ class AricateAttentionLayer(nn.Module):
70
+ """Couche d'Attention Additive (Bahdanau)."""
71
+ def __init__(self, hidden_dim):
72
+ super(AricateAttentionLayer, self).__init__()
73
+ self.W = nn.Linear(hidden_dim, hidden_dim)
74
+ self.U = nn.Linear(hidden_dim, hidden_dim)
75
+ self.V = nn.Linear(hidden_dim, 1, bias=False)
76
+ def forward(self, rnn_outputs, last_hidden):
77
+ last_hidden_expanded = last_hidden.unsqueeze(1)
78
+ energy = torch.tanh(self.W(rnn_outputs) + self.U(last_hidden_expanded))
79
+ attention_weights_raw = self.V(energy).squeeze(2)
80
+ attention_weights = F.softmax(attention_weights_raw, dim=1)
81
+ context_vector = torch.sum(rnn_outputs * attention_weights.unsqueeze(2), dim=1)
82
+ return context_vector
83
+
84
+ # --- C. AricateModel V4 ---
85
+ class AricateModel(nn.Module, PyTorchModelHubMixin):
86
+ """Architecture Aricate V4. Hérite de PyTorchModelHubMixin pour la sauvegarde et la publication."""
87
+ def __init__(self, vocab_size: int, embedding_dim: int, hidden_dim: int, num_layers: int = 1, config: dict = None):
88
+ super(AricateModel, self).__init__()
89
+
90
+ if config is not None:
91
+ vocab_size = config.get("vocab_size", vocab_size)
92
+ embedding_dim = config.get("embedding_dim", embedding_dim)
93
+ hidden_dim = config.get("hidden_dim", hidden_dim)
94
+ num_layers = config.get("num_layers", num_layers)
95
+
96
+ self.vocab_size = vocab_size
97
+ self.embedding_dim = embedding_dim
98
+ self.hidden_dim = hidden_dim
99
+ self.num_layers = num_layers
100
+
101
+ self.word_embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=0)
102
+ self.rnn = nn.GRU(input_size=embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True)
103
+ self.attention = AricateAttentionLayer(hidden_dim)
104
+ self.hidden_to_vocab = nn.Linear(hidden_dim * 2, vocab_size)
105
+
106
+ def forward(self, input_words):
107
+ embeds = self.word_embeddings(input_words)
108
+ rnn_out, hn = self.rnn(embeds)
109
+ last_hidden = hn[-1]
110
+ context_vector = self.attention(rnn_out, last_hidden)
111
+ combined_features = torch.cat((context_vector, last_hidden), dim=1)
112
+ logits = self.hidden_to_vocab(combined_features)
113
+ return logits
114
+
115
+ # --- D. Fonction de Génération (Simplifiée pour l'espace) ---
116
+ # NOTE: J'ai retiré la fonction de génération pour ne pas alourdir l'application Gradio principale et me concentrer sur l'entraînement/publication.
117
+ # Dans un Space, il est préférable d'avoir une démo séparée après l'entraînement.
118
+ # Je garde le Dataset car c'est nécessaire.
119
+
120
+ # --- Nouvelle Classe PyTorch Dataset ---
121
+ class AricateDataset(Dataset):
122
+ """Dataset personnalisé pour PyTorch."""
123
+ def __init__(self, X_data, Y_data):
124
+ self.X = X_data
125
+ self.Y = Y_data
126
+
127
+ def __len__(self):
128
+ return len(self.X)
129
+
130
+ def __getitem__(self, idx):
131
+ return self.X[idx], self.Y[idx]
132
+
133
+ # ==============================================================================
134
+ # FONCTION D'ENTRAÎNEMENT ADAPTÉE POUR GRADIO
135
+ # ==============================================================================
136
+
137
+ def train_aricate_model(
138
+ hf_token: str,
139
+ hf_user: str,
140
+ dataset_name: str,
141
+ question_col: str,
142
+ response_col: str,
143
+ model_name: str,
144
+ num_epochs: int
145
+ ):
146
+ """
147
+ Fonction principale d'entraînement adaptée pour Gradio.
148
+
149
+ Elle prend les entrées de l'utilisateur, configure Aricate v4,
150
+ lance l'entraînement et publie le modèle sur Hugging Face.
151
+ """
152
+
153
+ # Rediriger la sortie standard vers la console Gradio
154
+ sys.stdout.flush()
155
+ print(f"\n{'='*50}\n>>> DÉBUT DU PROCESSUS D'ENTRAÎNEMENT Aricat v4 <<<\n{'='*50}")
156
+
157
+ try:
158
+ # --- 0. Configuration & Connexion Hugging Face ---
159
+ # Paramètres fixes (peuvent être ajustés si nécessaire)
160
+ EMBEDDING_DIM = 64
161
+ HIDDEN_DIM = 128
162
+ NUM_LAYERS = 2
163
+ BATCH_SIZE = 128
164
+ LEARNING_RATE = 0.005
165
+
166
+ # Configuration de l'appareil
167
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
168
+ print(f"Appareil d'entraînement sélectionné: {device}")
169
+
170
+ # Connexion Hugging Face via le token
171
+ login(token=hf_token, add_to_git_credential=False)
172
+ REPO_ID = f"{hf_user}/{model_name}"
173
+ print(f"Connexion Hugging Face établie. Le modèle sera publié sous le dépôt: {REPO_ID}")
174
+
175
+ print(f"--- Lancement de l'Entraînement du SLM '{model_name}' (Aricate) ---")
176
+
177
+ # 1. Préparation des données
178
+ DATASET_SPLIT = 'train'
179
+ print(f"Chargement de la dataset '{dataset_name}' (split '{DATASET_SPLIT}')...")
180
+ try:
181
+ dataset = load_dataset(dataset_name, split=DATASET_SPLIT)
182
+ except Exception as e:
183
+ raise ValueError(f"Erreur lors du chargement de la dataset '{dataset_name}'. Vérifiez le nom du dépôt. Erreur: {e}")
184
+
185
+ # Construction du corpus en utilisant les colonnes spécifiées par l'utilisateur
186
+ try:
187
+ corpus_raw = [f"{ex[question_col]} <sep> {ex[response_col]}" for ex in dataset]
188
+ except KeyError as e:
189
+ raise KeyError(f"Colonne introuvable dans la dataset. Vérifiez les noms de colonnes : {e}. Les colonnes de votre dataset sont : {dataset.column_names}")
190
+
191
+ tokenizer = WordTokenizer(corpus_raw)
192
+
193
+ train_data_X = []
194
+ train_data_Y = []
195
+
196
+ for item in dataset:
197
+ q = item[question_col]
198
+ r = item[response_col]
199
+ full_seq_ids = tokenizer.encode(f"{q} <sep> {r}", add_eos=True)
200
+ for i in range(1, len(full_seq_ids)):
201
+ X = full_seq_ids[:i]
202
+ Y = full_seq_ids[i]
203
+ train_data_X.append(X)
204
+ train_data_Y.append(Y)
205
+
206
+ max_len = max(len(x) for x in train_data_X)
207
+ padded_X = []
208
+ for x in train_data_X:
209
+ padding_needed = max_len - len(x)
210
+ # Ajout du padding au DÉBUT de la séquence (convention de certains modèles pour l'alignement)
211
+ padded_X.append([tokenizer.special_tokens['<pad>']] * padding_needed + x)
212
+
213
+ X_train_tensor = torch.tensor(padded_X)
214
+ Y_train_tensor = torch.tensor(train_data_Y)
215
+ VOCAB_SIZE = tokenizer.vocab_size
216
+
217
+ print(f"Dataset chargée. Nombre de paires d'entraînement: {len(Y_train_tensor)}")
218
+ print(f"Taille du vocabulaire total: {VOCAB_SIZE}")
219
+ print(f"Longueur maximale d'entrée (max_len): {max_len}")
220
+
221
+ aricate_dataset = AricateDataset(X_train_tensor, Y_train_tensor)
222
+ train_loader = DataLoader(
223
+ dataset=aricate_dataset,
224
+ batch_size=BATCH_SIZE,
225
+ shuffle=True,
226
+ num_workers=0 # Mis à 0 pour éviter des problèmes de multi-processus sur certains environnements HF Space
227
+ )
228
+ print(f"Nombre de batches par époque : {len(train_loader)}")
229
+
230
+ # 2. Initialisation du Modèle
231
+ model_config = {
232
+ "vocab_size": VOCAB_SIZE,
233
+ "embedding_dim": EMBEDDING_DIM,
234
+ "hidden_dim": HIDDEN_DIM,
235
+ "num_layers": NUM_LAYERS
236
+ }
237
+ model = AricateModel(**model_config).to(device)
238
+ loss_function = nn.CrossEntropyLoss()
239
+ optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
240
+
241
+ # 3. Entraînement
242
+ print(f"\nDébut de l'entraînement pour {num_epochs} époques avec un BATCH_SIZE de {BATCH_SIZE}...")
243
+ start_time = time.time()
244
+
245
+ for epoch in range(num_epochs):
246
+ model.train()
247
+ total_loss = 0.0
248
+
249
+ for batch_X, batch_Y in train_loader:
250
+ batch_X, batch_Y = batch_X.to(device), batch_Y.to(device)
251
+
252
+ optimizer.zero_grad()
253
+ logits = model(batch_X)
254
+ loss = loss_function(logits, batch_Y)
255
+ loss.backward()
256
+ optimizer.step()
257
+ total_loss += loss.item() * batch_X.size(0)
258
+
259
+ avg_loss = total_loss / len(aricate_dataset)
260
+
261
+ # Mise à jour immédiate du statut
262
+ yield f"Entraînement en cours... Époque [{epoch+1}/{num_epochs}], Perte Moyenne: {avg_loss:.4f}"
263
+
264
+ end_time = time.time()
265
+ yield f"Entraînement terminé ! Durée: {(end_time - start_time):.2f}s. Début de la publication..."
266
+ print(f"\nEntraînement terminé ! Durée: {(end_time - start_time):.2f}s 🎉")
267
+
268
+
269
+ # 4. Sauvegarde et Publication sur Hugging Face
270
+ model.to("cpu")
271
+ print("\n" + "="*50)
272
+ print(">>> SAUVEGARDE ET PUBLICATION SUR HUGGING FACE <<<")
273
+ print("="*50)
274
+
275
+ save_directory = f"./{model_name}_local_save"
276
+ os.makedirs(save_directory, exist_ok=True)
277
+
278
+ model.save_pretrained(save_directory)
279
+ print(f"Modèle sauvegardé localement dans: {save_directory}")
280
+
281
+ tokenizer_path = os.path.join(save_directory, "aricate_tokenizer.txt")
282
+ with open(tokenizer_path, 'w', encoding='utf-8') as f:
283
+ json.dump(tokenizer.word_to_id, f, ensure_ascii=False)
284
+ print(f"Tokenizer (vocabulaire) sauvegardé dans: {tokenizer_path}")
285
+
286
+ # Publication
287
+ model.push_to_hub(
288
+ repo_id=REPO_ID,
289
+ commit_message=f"Modèle entraîné via Aricate v4 Space. Époques: {num_epochs}",
290
+ config=model_config
291
+ )
292
+ HfApi().upload_file(
293
+ path_or_fileobj=tokenizer_path,
294
+ path_in_repo="aricate_tokenizer.txt",
295
+ repo_id=REPO_ID,
296
+ repo_type="model",
297
+ commit_message="Update Aricate custom tokenizer vocabulary."
298
+ )
299
+
300
+ final_message = f"\n✅ Publication réussie ! Votre modèle '{model_name}' est disponible sur : https://huggingface.co/{REPO_ID}"
301
+ print(final_message)
302
+ yield final_message # Message final pour l'interface Gradio
303
+
304
+ except Exception as e:
305
+ error_message = f"\n❌ ERREUR CRITIQUE. L'entraînement ou la publication a échoué. Détail: {e}"
306
+ print(error_message)
307
+ yield error_message # Message d'erreur pour l'interface Gradio
308
+
309
+
310
+ # ==============================================================================
311
+ # INTERFACE GRADIO
312
+ # ==============================================================================
313
+
314
+ # Description détaillée pour l'utilisateur
315
+ description = """
316
+ # 🧠 Entraînez votre propre SLM avec Aricate v4 (Clemylia)
317
+ Bienvenue sur l'interface d'entraînement d'Aricate v4 ! Suivez les étapes ci-dessous pour créer et publier votre propre Small Language Model (SLM) basé sur votre dataset personnalisée.
318
+
319
+ **Étapes à suivre :**
320
+
321
+ 1. **Authentification :** Entrez votre Token et Nom d'utilisateur Hugging Face. **Le token doit avoir la permission "Write" (Écriture).**
322
+ 2. **Dataset :** Fournissez le nom du dépôt Hugging Face contenant votre dataset.
323
+ 3. **Colonnes :** Indiquez les noms exacts des colonnes pour les questions et les réponses (par défaut : `question` et `reponse`).
324
+ 4. **Nom du Modèle :** Choisissez le nom de votre futur modèle (il sera publié sous `votre_nom_utilisateur/nom_du_modèle`).
325
+ 5. **Hyperparamètres :** Définissez le nombre d'époques.
326
+ 6. **Lancement :** Appuyez sur le bouton et attendez la fin de l'entraînement et de la publication !
327
+ """
328
+
329
+ # Création des blocs d'interface
330
+ with gr.Blocks(title="Aricate v4 Trainer") as demo:
331
+ gr.Markdown(description)
332
+
333
+ # --- Section d'Authentification et de Publication ---
334
+ with gr.Row():
335
+ hf_token_input = gr.Textbox(
336
+ label="1. Token d'Accès Hugging Face (avec permission 'Write')",
337
+ type="password",
338
+ placeholder="hf_xxxxxxxxxxxxxxxxxxxxxxxxxx",
339
+ info="Token pour l'authentification et la publication (NE PAS PARTAGER !)"
340
+ )
341
+ hf_user_input = gr.Textbox(
342
+ label="2. Votre Nom d'Utilisateur Hugging Face",
343
+ placeholder="Clemylia",
344
+ info="Le modèle sera publié sur ce compte."
345
+ )
346
+
347
+ # --- Section Dataset ---
348
+ gr.Markdown("### 🔍 Configuration de la Dataset")
349
+ with gr.Row():
350
+ dataset_name_input = gr.Textbox(
351
+ label="3. Nom du Dépôt Dataset (ex: Clemylia/Melta-revive)",
352
+ placeholder="le_nom_de_votre_dataset",
353
+ info="Dépôt public Hugging Face (il doit avoir un split 'train')."
354
+ )
355
+ question_col_input = gr.Textbox(
356
+ label="4. Nom de la Colonne 'Question'",
357
+ value="question",
358
+ placeholder="question",
359
+ info="Nom exact de la colonne contenant les questions."
360
+ )
361
+ response_col_input = gr.Textbox(
362
+ label="5. Nom de la Colonne 'Réponse'",
363
+ value="reponse",
364
+ placeholder="reponse",
365
+ info="Nom exact de la colonne contenant les réponses."
366
+ )
367
+
368
+ # --- Section Modèle et Hyperparamètres ---
369
+ gr.Markdown("### ⚙️ Configuration du Modèle et Entraînement")
370
+ with gr.Row():
371
+ model_name_input = gr.Textbox(
372
+ label="6. Nom Final du Modèle (sur Hugging Face)",
373
+ placeholder="mon-super-slm-aricate",
374
+ info="Sera publié comme 'utilisateur/nom-final'."
375
+ )
376
+ num_epochs_input = gr.Slider(
377
+ label="7. Nombre d'Époques d'Entraînement",
378
+ minimum=1,
379
+ maximum=50,
380
+ step=1,
381
+ value=10,
382
+ info="Plus d'époques = plus long, mais peut donner de meilleurs résultats (attention à l'overfitting)."
383
+ )
384
+
385
+ # --- Bouton et Sortie ---
386
+ train_button = gr.Button("🚀 Entraîner mon propre SLM avec Aricate v4", variant="primary")
387
+
388
+ # Zone de sortie pour afficher la progression et les messages
389
+ output_log = gr.Textbox(
390
+ label="Console d'Entraînement et Log de Publication",
391
+ lines=15,
392
+ autoscroll=True,
393
+ interactive=False
394
+ )
395
+
396
+ # Lien entre le bouton et la fonction Python
397
+ train_button.click(
398
+ fn=train_aricate_model,
399
+ inputs=[
400
+ hf_token_input,
401
+ hf_user_input,
402
+ dataset_name_input,
403
+ question_col_input,
404
+ response_col_input,
405
+ model_name_input,
406
+ num_epochs_input
407
+ ],
408
+ outputs=output_log
409
+ )
410
+
411
+ # Lancement de l'application Gradio
412
+ if __name__ == "__main__":
413
+ demo.launch()